From 646989cc5bb61f73335017243e0b008b149ba0ab Mon Sep 17 00:00:00 2001 From: Rich Coggins Date: Wed, 14 Jun 2023 10:19:18 -0400 Subject: [PATCH 001/242] Improve (#356) to support registration of wildcard URLs (#359) * Improve (#356) to support registration of wildcard URLs * Add TestAzureChatCompletions & TestAzureChatCompletionsWithCustomDeploymentName * Remove TestAzureChatCompletionsWithCustomDeploymentName --------- Co-authored-by: coggsflod --- chat_test.go | 18 ++++++++++++++++++ internal/test/server.go | 14 +++++++++----- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/chat_test.go b/chat_test.go index ebe29f9eb..a43bb4aa6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -67,6 +67,24 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestAzureChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error diff --git a/internal/test/server.go b/internal/test/server.go index 79d55c405..3813ff869 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -4,6 +4,7 @@ import ( "log" "net/http" "net/http/httptest" + "regexp" ) const testAPI = "this-is-my-secure-token-do-not-steal!!" @@ -36,11 +37,14 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { return } - handlerCall, ok := ts.handlers[r.URL.Path] - if !ok { - http.Error(w, "the resource path doesn't exist", http.StatusNotFound) - return + // Handle /path/* routes. + for route, handler := range ts.handlers { + pattern, _ := regexp.Compile(route) + if pattern.MatchString(r.URL.Path) { + handler(w, r) + return + } } - handlerCall(w, r) + http.Error(w, "the resource path doesn't exist", http.StatusNotFound) })) } From 7e76a682a949cf234c05896d2e2aa3f7d5c5d118 Mon Sep 17 00:00:00 2001 From: beichideyuwan <57309366+beichideyuwan@users.noreply.github.com> Date: Wed, 14 Jun 2023 22:23:03 +0800 Subject: [PATCH 002/242] Add 16k 0613 model (#365) * add 16k_0613 model * add 16k_0613 model * add model: --- completion.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/completion.go b/completion.go index e7bf75acb..efded208b 100644 --- a/completion.go +++ b/completion.go @@ -26,6 +26,7 @@ const ( GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3TextDavinci003 = "text-davinci-003" GPT3TextDavinci002 = "text-davinci-002" @@ -52,16 +53,17 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { - GPT3Dot5Turbo: true, - GPT3Dot5Turbo0301: true, - GPT3Dot5Turbo0613: true, - GPT3Dot5Turbo16K: true, - GPT4: true, - GPT40314: true, - GPT40613: true, - GPT432K: true, - GPT432K0314: true, - GPT432K0613: true, + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo16K: true, + GPT3Dot5Turbo16K0613: true, + GPT4: true, + GPT40314: true, + GPT40613: true, + GPT432K: true, + GPT432K0314: true, + GPT432K0613: true, }, "/chat/completions": { CodexCodeDavinci002: true, From 2bd65aa720926506c49ddf89d7e619b3b83512c4 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Thu, 15 Jun 2023 16:49:54 +0800 Subject: [PATCH 003/242] feat(chat): support function call api (#369) * feat(chat): support function call api * rename struct & add const ChatMessageRoleFunction --- chat.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++--- chat_stream.go | 2 +- completion.go | 2 +- 3 files changed, 75 insertions(+), 6 deletions(-) diff --git a/chat.go b/chat.go index a7ce5486a..c8cff319e 100644 --- a/chat.go +++ b/chat.go @@ -11,8 +11,11 @@ const ( ChatMessageRoleSystem = "system" ChatMessageRoleUser = "user" ChatMessageRoleAssistant = "assistant" + ChatMessageRoleFunction = "function" ) +const chatCompletionsSuffix = "/chat/completions" + var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll @@ -27,6 +30,14 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-python/blob/main/chatml.md // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + // call function with arguments in JSON format + Arguments string `json:"arguments,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -43,12 +54,70 @@ type ChatCompletionRequest struct { FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` + Functions []*FunctionDefine `json:"functions,omitempty"` + FunctionCall string `json:"function_call,omitempty"` } +type FunctionDefine struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + // it's required in function call + Parameters *FunctionParams `json:"parameters"` +} + +type FunctionParams struct { + // the Type must be JSONSchemaTypeObject + Type JSONSchemaType `json:"type"` + Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type JSONSchemaType string + +const ( + JSONSchemaTypeObject JSONSchemaType = "object" + JSONSchemaTypeNumber JSONSchemaType = "number" + JSONSchemaTypeString JSONSchemaType = "string" + JSONSchemaTypeArray JSONSchemaType = "array" + JSONSchemaTypeNull JSONSchemaType = "null" + JSONSchemaTypeBoolean JSONSchemaType = "boolean" +) + +// JSONSchemaDefine is a struct for JSON Schema. +type JSONSchemaDefine struct { + // Type is a type of JSON Schema. + Type JSONSchemaType `json:"type,omitempty"` + // Description is a description of JSON Schema. + Description string `json:"description,omitempty"` + // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. + Enum []string `json:"enum,omitempty"` + // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. + Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. + Required []string `json:"required,omitempty"` +} + +type FinishReason string + +const ( + FinishReasonStop FinishReason = "stop" + FinishReasonLength FinishReason = "length" + FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonContentFilter FinishReason = "content_filter" + FinishReasonNull FinishReason = "null" +) + type ChatCompletionChoice struct { - Index int `json:"index"` - Message ChatCompletionMessage `json:"message"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + // FinishReason + // stop: API returned complete message, + // or a message terminated by one of the stop sequences provided via the stop parameter + // length: Incomplete model output due to max_tokens parameter or token limit + // function_call: The model decided to call a function + // content_filter: Omitted content due to a flag from our content filters + // null: API response still in progress or incomplete + FinishReason FinishReason `json:"finish_reason"` } // ChatCompletionResponse represents a response structure for chat completion API. @@ -71,7 +140,7 @@ func (c *Client) CreateChatCompletion( return } - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return diff --git a/chat_stream.go b/chat_stream.go index 625d436cb..c7341feac 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -40,7 +40,7 @@ func (c *Client) CreateChatCompletionStream( ctx context.Context, request ChatCompletionRequest, ) (stream *ChatCompletionStream, err error) { - urlSuffix := "/chat/completions" + urlSuffix := chatCompletionsSuffix if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrChatCompletionInvalidModel return diff --git a/completion.go b/completion.go index efded208b..e0571b007 100644 --- a/completion.go +++ b/completion.go @@ -65,7 +65,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT432K0314: true, GPT432K0613: true, }, - "/chat/completions": { + chatCompletionsSuffix: { CodexCodeDavinci002: true, CodexCodeCushman001: true, CodexCodeDavinci001: true, From 43de77162f7f6a1f391efce7a56b75d0b63042a9 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 15 Jun 2023 12:53:52 +0400 Subject: [PATCH 004/242] Create FUNDING.yml (#371) --- .github/FUNDING.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 000000000..d9fd885a9 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [sashabaranov] From 0bd14f9584baf8b47dd9251b674c26aed9c5a723 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 15 Jun 2023 17:58:26 +0800 Subject: [PATCH 005/242] refactor: ChatCompletionStreamChoice.FinishReason from string to FinishReason (#372) --- chat_stream.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat_stream.go b/chat_stream.go index c7341feac..9093bde9e 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -15,7 +15,7 @@ type ChatCompletionStreamChoiceDelta struct { type ChatCompletionStreamChoice struct { Index int `json:"index"` Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason string `json:"finish_reason"` + FinishReason FinishReason `json:"finish_reason"` } type ChatCompletionStreamResponse struct { From ac25f318ba29e1461ceec19f40bf5fd7765b7225 Mon Sep 17 00:00:00 2001 From: Alex Wormuth Date: Fri, 16 Jun 2023 08:11:50 -0500 Subject: [PATCH 006/242] add items, which is required for array type (#373) * add items, which is required for array type * use JSONSchemaDefine directly --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index c8cff319e..4764e36ba 100644 --- a/chat.go +++ b/chat.go @@ -95,6 +95,8 @@ type JSONSchemaDefine struct { Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. Required []string `json:"required,omitempty"` + // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. + Items *JSONSchemaDefine `json:"items,omitempty"` } type FinishReason string From f0770cfe1d5094d5d40a878658abf535bbdcec4c Mon Sep 17 00:00:00 2001 From: romazu Date: Fri, 16 Jun 2023 17:13:26 +0400 Subject: [PATCH 007/242] audio: add items to AudioResponseFormat enum (#382) * audio: add items to AudioResponseFormat enum * audio: expand AudioResponse struct to accommodate verbose json response --------- Co-authored-by: Roman Zubov --- audio.go | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/audio.go b/audio.go index 20e865f11..adfc52766 100644 --- a/audio.go +++ b/audio.go @@ -20,9 +20,11 @@ const ( type AudioResponseFormat string const ( - AudioResponseFormatJSON AudioResponseFormat = "json" - AudioResponseFormatSRT AudioResponseFormat = "srt" - AudioResponseFormatVTT AudioResponseFormat = "vtt" + AudioResponseFormatJSON AudioResponseFormat = "json" + AudioResponseFormatText AudioResponseFormat = "text" + AudioResponseFormatSRT AudioResponseFormat = "srt" + AudioResponseFormatVerboseJSON AudioResponseFormat = "verbose_json" + AudioResponseFormatVTT AudioResponseFormat = "vtt" ) // AudioRequest represents a request structure for audio API. @@ -44,6 +46,22 @@ type AudioRequest struct { // AudioResponse represents a response structure for audio API. type AudioResponse struct { + Task string `json:"task"` + Language string `json:"language"` + Duration float64 `json:"duration"` + Segments []struct { + ID int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` + Transient bool `json:"transient"` + } `json:"segments"` Text string `json:"text"` } @@ -96,7 +114,7 @@ func (c *Client) callAudioAPI( // HasJSONResponse returns true if the response format is JSON. func (r AudioRequest) HasJSONResponse() bool { - return r.Format == "" || r.Format == AudioResponseFormatJSON + return r.Format == "" || r.Format == AudioResponseFormatJSON || r.Format == AudioResponseFormatVerboseJSON } // audioMultipartForm creates a form with audio file contents and the name of the model to use for From e49d771fff3bc699bca7cf22c9d93b67316047e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sat, 17 Jun 2023 22:57:29 +0900 Subject: [PATCH 008/242] support for parsing error response message fields even if they are arrays (#381) (#384) --- api_test.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++++---- error.go | 10 ++++- 2 files changed, 111 insertions(+), 8 deletions(-) diff --git a/api_test.go b/api_test.go index 083b67412..34173708f 100644 --- a/api_test.go +++ b/api_test.go @@ -137,6 +137,108 @@ func TestAPIError(t *testing.T) { } } +func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFn func(t *testing.T, apiErr APIError) + } + testCases := []testCase{ + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + expected := "foo, bar, baz" + if apiErr.Message != expected { + t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) + } + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + if apiErr.Message != "" { + t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) + } + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFn: func(t *testing.T, apiErr APIError) { + if apiErr.Message != "" { + t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) + } + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr APIError + err := json.Unmarshal([]byte(tc.response), &apiErr) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + return + } + if tc.checkFn != nil { + tc.checkFn(t, apiErr) + } + }) + } +} + func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { var apiErr APIError response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` @@ -217,13 +319,6 @@ func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { checks.HasError(t, err, "Type should be a string") } -func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":false,"param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Message should be a string") -} - func TestRequestError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/error.go b/error.go index b789ed7d5..f68e92875 100644 --- a/error.go +++ b/error.go @@ -3,6 +3,7 @@ package openai import ( "encoding/json" "fmt" + "strings" ) // APIError provides error information returned by the OpenAI API. @@ -41,7 +42,14 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { err = json.Unmarshal(rawMap["message"], &e.Message) if err != nil { - return + // If the parameter field of a function call is invalid as a JSON schema + // refs: https://github.com/sashabaranov/go-openai/issues/381 + var messages []string + err = json.Unmarshal(rawMap["message"], &messages) + if err != nil { + return + } + e.Message = strings.Join(messages, ", ") } // optional fields for azure openai From b0959382c8fc01bf12de71a843d961f0d579f6f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sun, 18 Jun 2023 19:51:20 +0900 Subject: [PATCH 009/242] extract and split integration tests (#389) --- api_integration_test.go | 136 ++++++++++++++++ api_test.go | 353 ---------------------------------------- engines_test.go | 11 ++ error_test.go | 201 +++++++++++++++++++++++ openai_test.go | 9 + 5 files changed, 357 insertions(+), 353 deletions(-) create mode 100644 api_integration_test.go delete mode 100644 api_test.go create mode 100644 error_test.go diff --git a/api_integration_test.go b/api_integration_test.go new file mode 100644 index 000000000..3cafa24b4 --- /dev/null +++ b/api_integration_test.go @@ -0,0 +1,136 @@ +package openai_test + +import ( + "context" + "errors" + "io" + "os" + "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestAPI(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken) + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.NoError(t, err, "ListEngines error") + + _, err = c.GetEngine(ctx, "davinci") + checks.NoError(t, err, "GetEngine error") + + fileRes, err := c.ListFiles(ctx) + checks.NoError(t, err, "ListFiles error") + + if len(fileRes.Files) > 0 { + _, err = c.GetFile(ctx, fileRes.Files[0].ID) + checks.NoError(t, err, "GetFile error") + } // else skip + + embeddingReq := EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: AdaSearchQuery, + } + _, err = c.CreateEmbeddings(ctx, embeddingReq) + checks.NoError(t, err, "Embedding error") + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + checks.NoError(t, err, "CreateChatCompletion (without name) returned error") + + _, err = c.CreateChatCompletion( + ctx, + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Name: "John_Doe", + Content: "Hello!", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with name) returned error") + + stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: GPT3Ada, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + +func TestAPIError(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := NewClient(apiToken + "_invalid") + ctx := context.Background() + _, err = c.ListEngines(ctx) + checks.HasError(t, err, "ListEngines should fail with an invalid key") + + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Error is not an APIError: %+v", err) + } + + if apiErr.HTTPStatusCode != 401 { + t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) + } + + switch v := apiErr.Code.(type) { + case string: + if v != "invalid_api_key" { + t.Fatalf("Unexpected API error code: %s", v) + } + default: + t.Fatalf("Unexpected API error code type: %T", v) + } + + if apiErr.Error() == "" { + t.Fatal("Empty error message occurred") + } +} diff --git a/api_test.go b/api_test.go deleted file mode 100644 index 34173708f..000000000 --- a/api_test.go +++ /dev/null @@ -1,353 +0,0 @@ -package openai_test - -import ( - "context" - "encoding/json" - "errors" - "io" - "net/http" - "os" - "testing" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" -) - -func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.NoError(t, err, "ListEngines error") - - _, err = c.GetEngine(ctx, "davinci") - checks.NoError(t, err, "GetEngine error") - - fileRes, err := c.ListFiles(ctx) - checks.NoError(t, err, "ListFiles error") - - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) - checks.NoError(t, err, "GetFile error") - } // else skip - - embeddingReq := EmbeddingRequest{ - Input: []string{ - "The food was delicious and the waiter", - "Other examples of embedding request", - }, - Model: AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) - checks.NoError(t, err, "Embedding error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Content: "Hello!", - }, - }, - }, - ) - - checks.NoError(t, err, "CreateChatCompletion (without name) returned error") - - _, err = c.CreateChatCompletion( - ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ - { - Role: ChatMessageRoleUser, - Name: "John_Doe", - Content: "Hello!", - }, - }, - }, - ) - checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } -} - -func TestAPIError(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { - t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") - } - - var err error - c := NewClient(apiToken + "_invalid") - ctx := context.Background() - _, err = c.ListEngines(ctx) - checks.HasError(t, err, "ListEngines should fail with an invalid key") - - var apiErr *APIError - if !errors.As(err, &apiErr) { - t.Fatalf("Error is not an APIError: %+v", err) - } - - if apiErr.HTTPStatusCode != 401 { - t.Fatalf("Unexpected API error status code: %d", apiErr.HTTPStatusCode) - } - - switch v := apiErr.Code.(type) { - case string: - if v != "invalid_api_key" { - t.Fatalf("Unexpected API error code: %s", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } - - if apiErr.Error() == "" { - t.Fatal("Empty error message occurred") - } -} - -func TestAPIErrorUnmarshalJSONMessageField(t *testing.T) { - type testCase struct { - name string - response string - hasError bool - checkFn func(t *testing.T, apiErr APIError) - } - testCases := []testCase{ - { - name: "parse succeeds when the message is string", - response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is array with single item", - response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is array with multiple items", - response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - expected := "foo, bar, baz" - if apiErr.Message != expected { - t.Fatalf("Unexpected API message: %v; expected: %s", apiErr, expected) - } - }, - }, - { - name: "parse succeeds when the message is empty array", - response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - if apiErr.Message != "" { - t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) - } - }, - }, - { - name: "parse succeeds when the message is null", - response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, - hasError: false, - checkFn: func(t *testing.T, apiErr APIError) { - if apiErr.Message != "" { - t.Fatalf("Unexpected API message: %v; expected: empty", apiErr) - } - }, - }, - { - name: "parse failed when the message is object", - response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is int", - response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is float", - response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is bool", - response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - { - name: "parse failed when the message is not exists", - response: `{"type":"invalid_request_error","param":null,"code":null}`, - hasError: true, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - var apiErr APIError - err := json.Unmarshal([]byte(tc.response), &apiErr) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) - return - } - if tc.checkFn != nil { - tc.checkFn(t, apiErr) - } - }) - } -} - -func TestAPIErrorUnmarshalJSONInteger(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case int: - if v != 418 { - t.Fatalf("Unexpected API code integer: %d; expected 418", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONString(t *testing.T) { - var apiErr APIError - response := `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case string: - if v != "teapot" { - t.Fatalf("Unexpected API code string: %s; expected `teapot`", v) - } - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalJSONNoCode(t *testing.T) { - // test integer code - response := `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}` - var apiErr APIError - err := json.Unmarshal([]byte(response), &apiErr) - checks.NoError(t, err, "Unexpected Unmarshal API response error") - - switch v := apiErr.Code.(type) { - case nil: - default: - t.Fatalf("Unexpected API error code type: %T", v) - } -} - -func TestAPIErrorUnmarshalInvalidData(t *testing.T) { - apiErr := APIError{} - data := []byte(`--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`) - err := apiErr.UnmarshalJSON(data) - checks.HasError(t, err, "Expected error when unmarshaling invalid data") - - if apiErr.Code != nil { - t.Fatalf("Expected nil code, got %q", apiErr.Code) - } - if apiErr.Message != "" { - t.Fatalf("Expected empty message, got %q", apiErr.Message) - } - if apiErr.Param != nil { - t.Fatalf("Expected nil param, got %q", *apiErr.Param) - } - if apiErr.Type != "" { - t.Fatalf("Expected empty type, got %q", apiErr.Type) - } -} - -func TestAPIErrorUnmarshalJSONInvalidParam(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Param should be a string") -} - -func TestAPIErrorUnmarshalJSONInvalidType(t *testing.T) { - var apiErr APIError - response := `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}` - err := json.Unmarshal([]byte(response), &apiErr) - checks.HasError(t, err, "Type should be a string") -} - -func TestRequestError(t *testing.T) { - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusTeapot) - }) - - _, err := client.ListEngines(context.Background()) - checks.HasError(t, err, "ListEngines did not fail") - - var reqErr *RequestError - if !errors.As(err, &reqErr) { - t.Fatalf("Error is not a RequestError: %+v", err) - } - - if reqErr.HTTPStatusCode != 418 { - t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) - } - - if reqErr.Unwrap() == nil { - t.Fatalf("Empty request error occurred") - } -} - -// numTokens Returns the number of GPT-3 encoded tokens in the given text. -// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer -// -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) -func numTokens(s string) int { - return int(float32(len(s)) / 4) -} diff --git a/engines_test.go b/engines_test.go index 2beb333b3..31e7ec8be 100644 --- a/engines_test.go +++ b/engines_test.go @@ -34,3 +34,14 @@ func TestListEngines(t *testing.T) { _, err := client.ListEngines(context.Background()) checks.NoError(t, err, "ListEngines error") } + +func TestListEnginesReturnError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + }) + + _, err := client.ListEngines(context.Background()) + checks.HasError(t, err, "ListEngines did not fail") +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 000000000..e2309abd7 --- /dev/null +++ b/error_test.go @@ -0,0 +1,201 @@ +package openai_test + +import ( + "errors" + "net/http" + "testing" + + . "github.com/sashabaranov/go-openai" +) + +func TestAPIErrorUnmarshalJSON(t *testing.T) { + type testCase struct { + name string + response string + hasError bool + checkFunc func(t *testing.T, apiErr APIError) + } + testCases := []testCase{ + // testcase for message field + { + name: "parse succeeds when the message is string", + response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with single item", + response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo") + }, + }, + { + name: "parse succeeds when the message is array with multiple items", + response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "foo, bar, baz") + }, + }, + { + name: "parse succeeds when the message is empty array", + response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse succeeds when the message is null", + response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorMessage(t, apiErr, "") + }, + }, + { + name: "parse failed when the message is object", + response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is int", + response: `{"message":1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is float", + response: `{"message":0.1,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is bool", + response: `{"message":true,"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + { + name: "parse failed when the message is not exists", + response: `{"type":"invalid_request_error","param":null,"code":null}`, + hasError: true, + }, + // testcase for code field + { + name: "parse succeeds when the code is int", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, 418) + }, + }, + { + name: "parse succeeds when the code is string", + response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, "teapot") + }, + }, + { + name: "parse succeeds when the code is not exists", + response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, nil) + }, + }, + // testcase for param field + { + name: "parse failed when the param is bool", + response: `{"code":418,"message":"I'm a teapot","param":true,"type":"teapot_error"}`, + hasError: true, + }, + // testcase for type field + { + name: "parse failed when the type is bool", + response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":true}`, + hasError: true, + }, + // testcase for error response + { + name: "parse failed when the response is invalid json", + response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorCode(t, apiErr, nil) + assertAPIErrorMessage(t, apiErr, "") + assertAPIErrorParam(t, apiErr, nil) + assertAPIErrorType(t, apiErr, "") + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var apiErr APIError + err := apiErr.UnmarshalJSON([]byte(tc.response)) + if (err != nil) != tc.hasError { + t.Errorf("Unexpected error: %v", err) + } + if tc.checkFunc != nil { + tc.checkFunc(t, apiErr) + } + }) + } +} + +func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { + if apiErr.Message != expected { + t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) + } +} + +func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { + switch v := apiErr.Code.(type) { + case int: + if v != expected { + t.Errorf("Unexpected APIError code integer: %d; expected %d", v, expected) + } + case string: + if v != expected { + t.Errorf("Unexpected APIError code string: %s; expected %s", v, expected) + } + case nil: + default: + t.Errorf("Unexpected APIError error code type: %T", v) + } +} + +func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { + if apiErr.Param != expected { + t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) + } +} + +func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { + if apiErr.Type != typ { + t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) + } +} + +func TestRequestError(t *testing.T) { + var err error = &RequestError{ + HTTPStatusCode: http.StatusTeapot, + Err: errors.New("i am a teapot"), + } + + var reqErr *RequestError + if !errors.As(err, &reqErr) { + t.Fatalf("Error is not a RequestError: %+v", err) + } + + if reqErr.HTTPStatusCode != 418 { + t.Fatalf("Unexpected request error status code: %d", reqErr.HTTPStatusCode) + } + + if reqErr.Unwrap() == nil { + t.Fatalf("Empty request error occurred") + } +} diff --git a/openai_test.go b/openai_test.go index a5e7b64ee..4fc41ecc0 100644 --- a/openai_test.go +++ b/openai_test.go @@ -26,3 +26,12 @@ func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown f client = NewClientWithConfig(config) return } + +// numTokens Returns the number of GPT-3 encoded tokens in the given text. +// This function approximates based on the rule of thumb stated by OpenAI: +// https://beta.openai.com/tokenizer +// +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +func numTokens(s string) int { + return int(float32(len(s)) / 4) +} From 68f9ef92beeb368eb77ea1bb206abedb5066501b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 19 Jun 2023 17:12:38 +0900 Subject: [PATCH 010/242] split integration test using go build tag (#392) --- README.md | 13 +++++++++++++ api_integration_test.go | 2 ++ 2 files changed, 15 insertions(+) diff --git a/README.md b/README.md index 7562694df..9a7262332 100644 --- a/README.md +++ b/README.md @@ -542,3 +542,16 @@ if errors.As(err, &e) { See the `examples/` folder for more. +### Integration tests: + +Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. + +**Notes:** +These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. + +**Run tests using:** +``` +OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go +``` + +If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. \ No newline at end of file diff --git a/api_integration_test.go b/api_integration_test.go index 3cafa24b4..d4e7328a2 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -1,3 +1,5 @@ +//go:build integration + package openai_test import ( From 720377087fae943d15000d47c7c9ea0a214798b1 Mon Sep 17 00:00:00 2001 From: cem-unuvar <87916654+cem-unuvar@users.noreply.github.com> Date: Tue, 20 Jun 2023 18:33:53 +0300 Subject: [PATCH 011/242] feat: added function call info to chat completions (#390) --- chat_stream.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 9093bde9e..75aa6858a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -8,8 +8,9 @@ import ( ) type ChatCompletionStreamChoiceDelta struct { - Content string `json:"content,omitempty"` - Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` } type ChatCompletionStreamChoice struct { From e948150829ac980f3aea86ed1d73aa2fc5a7f12b Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 20 Jun 2023 23:39:19 +0800 Subject: [PATCH 012/242] fix: chat stream returns an error response with a 'data: ' prefix (#396) * fix: chat stream resp has 'data: ' prefix * fix: lint error * fix: lint error * fix: lint error --- chat_stream_test.go | 39 +++++++++++++++++++++++++++++++++++++++ stream_reader.go | 22 ++++++++++++++++++---- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index c3cb9f3f7..5fc70b032 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -178,6 +178,45 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + _, streamErr := stream.Recv() + checks.HasError(t, streamErr, "stream.Recv() did not return error") + + var apiErr *APIError + if !errors.As(streamErr, &apiErr) { + t.Errorf("stream.Recv() did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/stream_reader.go b/stream_reader.go index 34161986e..87e59e0ca 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -10,6 +10,11 @@ import ( utils "github.com/sashabaranov/go-openai/internal" ) +var ( + headerData = []byte("data: ") + errorPrefix = []byte(`data: {"error":`) +) + type streamable interface { ChatCompletionStreamResponse | CompletionResponse } @@ -34,12 +39,16 @@ func (stream *streamReader[T]) Recv() (response T, err error) { return } +//nolint:gocognit func (stream *streamReader[T]) processLines() (T, error) { - var emptyMessagesCount uint + var ( + emptyMessagesCount uint + hasErrorPrefix bool + ) for { rawLine, readErr := stream.reader.ReadBytes('\n') - if readErr != nil { + if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { return *new(T), fmt.Errorf("error, %w", respErr.Error) @@ -47,9 +56,14 @@ func (stream *streamReader[T]) processLines() (T, error) { return *new(T), readErr } - var headerData = []byte("data: ") noSpaceLine := bytes.TrimSpace(rawLine) - if !bytes.HasPrefix(noSpaceLine, headerData) { + if bytes.HasPrefix(noSpaceLine, errorPrefix) { + hasErrorPrefix = true + } + if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if hasErrorPrefix { + noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { return *new(T), writeErr From f22da8a7ed896d19661dfcce3e330e4b209b2eb3 Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Wed, 21 Jun 2023 08:58:27 -0400 Subject: [PATCH 013/242] feat: allow more input types to functions, fix tests (#377) * feat: use json.rawMessage, test functions * chore: lint * fix: tests the ChatCompletion mock server doesn't actually run otherwise. N=0 is the default request but the server will treat it as n=1 * fix: tests should default to n=1 completions * chore: add back removed interfaces, custom marshal * chore: lint * chore: lint * chore: add some tests * chore: appease lint * clean up JSON schema + tests * chore: lint * feat: remove backwards compatible functions for illustrative purposes * fix: revert params change * chore: use interface{} * chore: add test * chore: add back FunctionDefine * chore: /s/interface{}/any * chore: add back jsonschemadefinition * chore: testcov * chore: lint * chore: remove pointers * chore: update comment * chore: address CR added test for compatibility as well --------- Co-authored-by: James --- chat.go | 34 +++++----- chat_test.go | 157 ++++++++++++++++++++++++++++++++++++++++++++- completion_test.go | 10 ++- 3 files changed, 180 insertions(+), 21 deletions(-) diff --git a/chat.go b/chat.go index 4764e36ba..f99af2735 100644 --- a/chat.go +++ b/chat.go @@ -54,23 +54,23 @@ type ChatCompletionRequest struct { FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` LogitBias map[string]int `json:"logit_bias,omitempty"` User string `json:"user,omitempty"` - Functions []*FunctionDefine `json:"functions,omitempty"` - FunctionCall string `json:"function_call,omitempty"` + Functions []FunctionDefinition `json:"functions,omitempty"` + FunctionCall any `json:"function_call,omitempty"` } -type FunctionDefine struct { +type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` - // it's required in function call - Parameters *FunctionParams `json:"parameters"` + // Parameters is an object describing the function. + // You can pass a raw byte array describing the schema, + // or you can pass in a struct which serializes to the proper JSONSchema. + // The JSONSchemaDefinition struct is provided for convenience, but you should + // consider another specialized library for more complex schemas. + Parameters any `json:"parameters"` } -type FunctionParams struct { - // the Type must be JSONSchemaTypeObject - Type JSONSchemaType `json:"type"` - Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` -} +// Deprecated: use FunctionDefinition instead. +type FunctionDefine = FunctionDefinition type JSONSchemaType string @@ -83,8 +83,9 @@ const ( JSONSchemaTypeBoolean JSONSchemaType = "boolean" ) -// JSONSchemaDefine is a struct for JSON Schema. -type JSONSchemaDefine struct { +// JSONSchemaDefinition is a struct for JSON Schema. +// It is fairly limited and you may have better luck using a third-party library. +type JSONSchemaDefinition struct { // Type is a type of JSON Schema. Type JSONSchemaType `json:"type,omitempty"` // Description is a description of JSON Schema. @@ -92,13 +93,16 @@ type JSONSchemaDefine struct { // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. Enum []string `json:"enum,omitempty"` // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. - Properties map[string]*JSONSchemaDefine `json:"properties,omitempty"` + Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"` // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. Required []string `json:"required,omitempty"` // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. - Items *JSONSchemaDefine `json:"items,omitempty"` + Items *JSONSchemaDefinition `json:"items,omitempty"` } +// Deprecated: use JSONSchemaDefinition instead. +type JSONSchemaDefine = JSONSchemaDefinition + type FinishReason string const ( diff --git a/chat_test.go b/chat_test.go index a43bb4aa6..3c759b310 100644 --- a/chat_test.go +++ b/chat_test.go @@ -67,6 +67,130 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestChatCompletionsFunctions tests including a function call. +func TestChatCompletionsFunctions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + t.Run("bytes", func(t *testing.T) { + //nolint:lll + msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefine{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("struct", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefinition{{ + Name: "test", + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefine", func(t *testing.T) { + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefinition{{ + Name: "test", + Parameters: &JSONSchemaDefinition{ + Type: JSONSchemaTypeObject, + Properties: map[string]JSONSchemaDefinition{ + "count": { + Type: JSONSchemaTypeNumber, + Description: "total number of words in sentence", + }, + "words": { + Type: JSONSchemaTypeArray, + Description: "list of words in sentence", + Items: &JSONSchemaDefinition{ + Type: JSONSchemaTypeString, + }, + }, + "enumTest": { + Type: JSONSchemaTypeString, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) + t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) { + // this is a compatibility check + _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo0613, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []FunctionDefine{{ + Name: "test", + Parameters: &JSONSchemaDefine{ + Type: JSONSchemaTypeObject, + Properties: map[string]JSONSchemaDefine{ + "count": { + Type: JSONSchemaTypeNumber, + Description: "total number of words in sentence", + }, + "words": { + Type: JSONSchemaTypeArray, + Description: "list of words in sentence", + Items: &JSONSchemaDefine{ + Type: JSONSchemaTypeString, + }, + }, + "enumTest": { + Type: JSONSchemaTypeString, + Enum: []string{"hello", "world"}, + }, + }, + }, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) +} + func TestAzureChatCompletions(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() @@ -109,7 +233,34 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { + // if there are functions, include them + if len(completionReq.Functions) > 0 { + var fcb []byte + b := completionReq.Functions[0].Parameters + fcb, err = json.Marshal(b) + if err != nil { + http.Error(w, "could not marshal function parameters", http.StatusInternalServerError) + return + } + + res.Choices = append(res.Choices, ChatCompletionChoice{ + Message: ChatCompletionMessage{ + Role: ChatMessageRoleFunction, + // this is valid json so it should be fine + FunctionCall: &FunctionCall{ + Name: completionReq.Functions[0].Name, + Arguments: string(fcb), + }, + }, + Index: i, + }) + continue + } // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) @@ -121,8 +272,8 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Index: i, }) } - inputTokens := numTokens(completionReq.Messages[0].Content) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n res.Usage = Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, diff --git a/completion_test.go b/completion_test.go index aeddcfca1..844ef484f 100644 --- a/completion_test.go +++ b/completion_test.go @@ -83,7 +83,11 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) if completionReq.Echo { @@ -94,8 +98,8 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Index: i, }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N + inputTokens := numTokens(completionReq.Prompt.(string)) * n + completionTokens := completionReq.MaxTokens * n res.Usage = Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, From e19b074a114a5add5f005911668d0cda8476a908 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 21 Jun 2023 23:53:15 +0900 Subject: [PATCH 014/242] docs: add requires go version in README.md (#397) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9a7262332..5f166dc31 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op ``` go get github.com/sashabaranov/go-openai ``` - +Currently, go-openai requires Go version 1.18 or greater. ### ChatGPT example usage: @@ -554,4 +554,4 @@ These tests send real network traffic to the OpenAI API and may reach rate limit OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go ``` -If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. \ No newline at end of file +If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. From ffa7abc050b22b068ed16680de3b96ef26211651 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 21 Jun 2023 18:54:10 +0400 Subject: [PATCH 015/242] Update README.md (#399) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5f166dc31..522a85e78 100644 --- a/README.md +++ b/README.md @@ -554,4 +554,4 @@ These tests send real network traffic to the OpenAI API and may reach rate limit OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go ``` -If `OPENAI_TOKEN` environment variables are not available, integration tests will be skipped. +If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. From 157de0680f39f7c521cdd79bf69fb66390380c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 22 Jun 2023 18:49:46 +0900 Subject: [PATCH 016/242] add vvatanabe to FUNDING.yml (#402) --- .github/FUNDING.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index d9fd885a9..e36c38239 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,3 +1,3 @@ # These are supported funding model platforms -github: [sashabaranov] +github: [sashabaranov, vvatanabe] From f1b66967a426c3dfaf5e652b118d807cf1e7473f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 22 Jun 2023 18:57:52 +0900 Subject: [PATCH 017/242] refactor: refactoring http request creation and sending (#395) * refactoring http request creation and sending * fix lint error * increase the test coverage of client.go * refactor: Change the style of HTTPRequestBuilder.Build func to one-argument-per-line. --- api_internal_test.go | 2 +- audio.go | 4 +- chat.go | 2 +- chat_stream.go | 22 ++------ client.go | 94 ++++++++++++++++++++++++-------- client_test.go | 25 +++++++-- completion.go | 2 +- edits.go | 2 +- embeddings.go | 2 +- engines.go | 4 +- files.go | 28 +++------- fine_tunes.go | 12 ++-- image.go | 13 ++--- internal/request_builder.go | 42 +++++++++----- internal/request_builder_test.go | 6 +- models.go | 4 +- models_test.go | 22 ++++++++ moderation.go | 2 +- stream.go | 21 ++----- stream_test.go | 26 +++++++++ 20 files changed, 209 insertions(+), 126 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 214b627bf..0fb0f8993 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -94,7 +94,7 @@ func TestRequestAuthHeader(t *testing.T) { az.OrgID = c.OrgID cli := NewClientWithConfig(az) - req, err := cli.newStreamRequest(context.Background(), "POST", "/chat/completions", nil, "") + req, err := cli.newRequest(context.Background(), "POST", "/chat/completions") if err != nil { t.Errorf("Failed to create request: %v", err) } diff --git a/audio.go b/audio.go index adfc52766..9f469159d 100644 --- a/audio.go +++ b/audio.go @@ -95,11 +95,11 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), &formBody) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), + withBody(&formBody), withContentType(builder.FormDataContentType())) if err != nil { return AudioResponse{}, err } - req.Header.Add("Content-Type", builder.FormDataContentType()) if request.HasJSONResponse() { err = c.sendRequest(req, &response) diff --git a/chat.go b/chat.go index f99af2735..b74720d38 100644 --- a/chat.go +++ b/chat.go @@ -152,7 +152,7 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 75aa6858a..9f4e80cff 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -1,10 +1,8 @@ package openai import ( - "bufio" "context" - - utils "github.com/sashabaranov/go-openai/internal" + "net/http" ) type ChatCompletionStreamChoiceDelta struct { @@ -48,27 +46,17 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[ChatCompletionStreamResponse](c, req) if err != nil { return } - if isFailureStatusCode(resp) { - return nil, c.handleErrorResp(resp) - } - stream = &ChatCompletionStream{ - streamReader: &streamReader[ChatCompletionStreamResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: utils.NewErrorAccumulator(), - unmarshaler: &utils.JSONUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/client.go b/client.go index f38c1dfc3..5779a8e1c 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package openai import ( + "bufio" "context" "encoding/json" "fmt" @@ -45,6 +46,42 @@ func NewOrgClient(authToken, org string) *Client { return NewClientWithConfig(config) } +type requestOptions struct { + body any + header http.Header +} + +type requestOption func(*requestOptions) + +func withBody(body any) requestOption { + return func(args *requestOptions) { + args.body = body + } +} + +func withContentType(contentType string) requestOption { + return func(args *requestOptions) { + args.header.Set("Content-Type", contentType) + } +} + +func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { + // Default Options + args := &requestOptions{ + body: nil, + header: make(http.Header), + } + for _, setter := range setters { + setter(args) + } + req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header) + if err != nil { + return nil, err + } + c.setCommonHeaders(req) + return req, nil +} + func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") @@ -55,8 +92,6 @@ func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Content-Type", "application/json; charset=utf-8") } - c.setCommonHeaders(req) - res, err := c.config.HTTPClient.Do(req) if err != nil { return err @@ -71,6 +106,41 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return decodeResponse(res.Body, v) } +func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { + resp, err := c.config.HTTPClient.Do(req) + if err != nil { + return + } + + if isFailureStatusCode(resp) { + err = c.handleErrorResp(resp) + return + } + return resp.Body, nil +} + +func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + if err != nil { + return new(streamReader[T]), err + } + if isFailureStatusCode(resp) { + return new(streamReader[T]), client.handleErrorResp(resp) + } + return &streamReader[T]{ + emptyMessagesLimit: client.config.EmptyMessagesLimit, + reader: bufio.NewReader(resp.Body), + response: resp, + errAccumulator: utils.NewErrorAccumulator(), + unmarshaler: &utils.JSONUnmarshaler{}, + }, nil +} + func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication @@ -138,26 +208,6 @@ func (c *Client) fullURL(suffix string, args ...any) string { return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } -func (c *Client) newStreamRequest( - ctx context.Context, - method string, - urlSuffix string, - body any, - model string) (*http.Request, error) { - req, err := c.requestBuilder.Build(ctx, method, c.fullURL(urlSuffix, model), body) - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - - c.setCommonHeaders(req) - return req, nil -} - func (c *Client) handleErrorResp(resp *http.Response) error { var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) diff --git a/client_test.go b/client_test.go index 00b66feae..29d84edfa 100644 --- a/client_test.go +++ b/client_test.go @@ -16,7 +16,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed") type failingRequestBuilder struct{} -func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any) (*http.Request, error) { +func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) { return nil, errTestRequestBuilderFailed } @@ -41,9 +41,10 @@ func TestDecodeResponse(t *testing.T) { stringInput := "" testCases := []struct { - name string - value interface{} - body io.Reader + name string + value interface{} + body io.Reader + hasError bool }{ { name: "nil input", @@ -60,18 +61,32 @@ func TestDecodeResponse(t *testing.T) { value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), }, + { + name: "reader return error", + value: &stringInput, + body: &errorReader{err: errors.New("dummy")}, + hasError: true, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if err != nil { + if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) } }) } } +type errorReader struct { + err error +} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, e.err +} + func TestHandleErrorResp(t *testing.T) { // var errRes *ErrorResponse var errRes ErrorResponse diff --git a/completion.go b/completion.go index e0571b007..b3b3abd1c 100644 --- a/completion.go +++ b/completion.go @@ -165,7 +165,7 @@ func (c *Client) CreateCompletion( return } - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } diff --git a/edits.go b/edits.go index 23b1a64f0..3d3fc8950 100644 --- a/edits.go +++ b/edits.go @@ -32,7 +32,7 @@ type EditsResponse struct { // Perform an API call to the Edits endpoint. func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index 942f3ea3a..ba327ce77 100644 --- a/embeddings.go +++ b/embeddings.go @@ -132,7 +132,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request)) if err != nil { return } diff --git a/engines.go b/engines.go index ac01a00ed..adf6025c2 100644 --- a/engines.go +++ b/engines.go @@ -22,7 +22,7 @@ type EnginesList struct { // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/engines"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/engines")) if err != nil { return } @@ -38,7 +38,7 @@ func (c *Client) GetEngine( engineID string, ) (engine Engine, err error) { urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/files.go b/files.go index fb9937bea..ea1f50a73 100644 --- a/files.go +++ b/files.go @@ -57,21 +57,19 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File return } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL("/files"), &b) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) - err = c.sendRequest(req, &file) - return } // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/files/"+fileID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/files/"+fileID)) if err != nil { return } @@ -83,7 +81,7 @@ func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/files"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/files")) if err != nil { return } @@ -96,7 +94,7 @@ func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -107,23 +105,11 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) - if err != nil { - return - } - - c.setCommonHeaders(req) - - res, err := c.config.HTTPClient.Do(req) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - if isFailureStatusCode(res) { - err = c.handleErrorResp(res) - return - } - - content = res.Body + content, err = c.sendRequestRaw(req) return } diff --git a/fine_tunes.go b/fine_tunes.go index 069ddccfd..96e731d51 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -68,7 +68,7 @@ type FineTuneDeleteResponse struct { func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -79,7 +79,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // CancelFineTune cancel a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel"), nil) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return } @@ -89,7 +89,7 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { return } @@ -100,7 +100,7 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } @@ -110,7 +110,7 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F } func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID), nil) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { return } @@ -120,7 +120,7 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons } func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { return } diff --git a/image.go b/image.go index df7363865..cb96f4f5e 100644 --- a/image.go +++ b/image.go @@ -44,7 +44,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL(urlSuffix), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) if err != nil { return } @@ -107,13 +107,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - urlSuffix := "/images/edits" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"), + withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } @@ -158,14 +157,12 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - //https://platform.openai.com/docs/api-reference/images/create-variation - urlSuffix := "/images/variations" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"), + withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return } - req.Header.Set("Content-Type", builder.FormDataContentType()) err = c.sendRequest(req, &response) return } diff --git a/internal/request_builder.go b/internal/request_builder.go index 0a9eabfde..5699f6b18 100644 --- a/internal/request_builder.go +++ b/internal/request_builder.go @@ -3,11 +3,12 @@ package openai import ( "bytes" "context" + "io" "net/http" ) type RequestBuilder interface { - Build(ctx context.Context, method, url string, request any) (*http.Request, error) + Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) } type HTTPRequestBuilder struct { @@ -20,21 +21,32 @@ func NewRequestBuilder() *HTTPRequestBuilder { } } -func (b *HTTPRequestBuilder) Build(ctx context.Context, method, url string, request any) (*http.Request, error) { - if request == nil { - return http.NewRequestWithContext(ctx, method, url, nil) +func (b *HTTPRequestBuilder) Build( + ctx context.Context, + method string, + url string, + body any, + header http.Header, +) (req *http.Request, err error) { + var bodyReader io.Reader + if body != nil { + if v, ok := body.(io.Reader); ok { + bodyReader = v + } else { + var reqBytes []byte + reqBytes, err = b.marshaller.Marshal(body) + if err != nil { + return + } + bodyReader = bytes.NewBuffer(reqBytes) + } } - - var reqBytes []byte - reqBytes, err := b.marshaller.Marshal(request) + req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) if err != nil { - return nil, err + return } - - return http.NewRequestWithContext( - ctx, - method, - url, - bytes.NewBuffer(reqBytes), - ) + if header != nil { + req.Header = header + } + return } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e47d0f6ca..e26022a6b 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -22,7 +22,7 @@ func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { marshaller: &failingMarshaller{}, } - _, err := builder.Build(context.Background(), "", "", struct{}{}) + _, err := builder.Build(context.Background(), "", "", struct{}{}, nil) if !errors.Is(err, errTestMarshallerFailed) { t.Fatalf("Did not return error when marshaller failed: %v", err) } @@ -38,7 +38,7 @@ func TestRequestBuilderReturnsRequest(t *testing.T) { reqBytes, _ = b.marshaller.Marshal(request) want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) ) - got, _ := b.Build(ctx, method, url, request) + got, _ := b.Build(ctx, method, url, request, nil) if !reflect.DeepEqual(got.Body, want.Body) || !reflect.DeepEqual(got.URL, want.URL) || !reflect.DeepEqual(got.Method, want.Method) { @@ -54,7 +54,7 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { want, _ = http.NewRequestWithContext(ctx, method, url, nil) ) b := NewRequestBuilder() - got, _ := b.Build(ctx, method, url, nil) + got, _ := b.Build(ctx, method, url, nil, nil) if !reflect.DeepEqual(got, want) { t.Errorf("Build() got = %v, want %v", got, want) } diff --git a/models.go b/models.go index b3d458366..560402e3f 100644 --- a/models.go +++ b/models.go @@ -41,7 +41,7 @@ type ModelsList struct { // ListModels Lists the currently available models, // and provides basic information about each model such as the model id and parent. func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL("/models"), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/models")) if err != nil { return } @@ -54,7 +54,7 @@ func (c *Client) ListModels(ctx context.Context) (models ModelsList, err error) // the model such as the owner and permissioning. func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err error) { urlSuffix := fmt.Sprintf("/models/%s", modelID) - req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } diff --git a/models_test.go b/models_test.go index 0b4daf4a8..59b4f5ef7 100644 --- a/models_test.go +++ b/models_test.go @@ -1,6 +1,9 @@ package openai_test import ( + "os" + "time" + . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -56,3 +59,22 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(Model{}) fmt.Fprintln(w, string(resBytes)) } + +func TestGetModelReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetModel(ctx, "text-davinci-003") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/moderation.go b/moderation.go index bae788035..a58d759c0 100644 --- a/moderation.go +++ b/moderation.go @@ -63,7 +63,7 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - req, err := c.requestBuilder.Build(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), request) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) if err != nil { return } diff --git a/stream.go b/stream.go index 94cc0a0a2..b277f3c29 100644 --- a/stream.go +++ b/stream.go @@ -1,11 +1,8 @@ package openai import ( - "bufio" "context" "errors" - - utils "github.com/sashabaranov/go-openai/internal" ) var ( @@ -36,27 +33,17 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newStreamRequest(ctx, "POST", urlSuffix, request, request.Model) + req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { - return + return nil, err } - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := sendRequestStream[CompletionResponse](c, req) if err != nil { return } - if isFailureStatusCode(resp) { - return nil, c.handleErrorResp(resp) - } - stream = &CompletionStream{ - streamReader: &streamReader[CompletionResponse]{ - emptyMessagesLimit: c.config.EmptyMessagesLimit, - reader: bufio.NewReader(resp.Body), - response: resp, - errAccumulator: utils.NewErrorAccumulator(), - unmarshaler: &utils.JSONUnmarshaler{}, - }, + streamReader: resp, } return } diff --git a/stream_test.go b/stream_test.go index 5997f27e8..f3f8f85cd 100644 --- a/stream_test.go +++ b/stream_test.go @@ -6,7 +6,9 @@ import ( "errors" "io" "net/http" + "os" "testing" + "time" . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" @@ -300,6 +302,30 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { } } +func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: "text-davinci-002", + MaxTokens: 10, + Stream: true, + }) + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} + // Helper funcs. func compareResponses(r1, r2 CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { From 5f4ef298e3d4d74784ac53d75d0d43379efa2efc Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 23 Jun 2023 13:07:43 +0400 Subject: [PATCH 018/242] Update README.md (#406) --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index 522a85e78..ef1db98cc 100644 --- a/README.md +++ b/README.md @@ -555,3 +555,10 @@ OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go ``` If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. + +## Thank you + +We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: +- [Carson Kahn](https://carsonkahn.com) of [Spindle AI](https://spindleai.com) + +To all of you: thank you. You've helped us achieve more than we ever imagined possible. Can't wait to see where we go next, together! From 0ca4ea48671c631fb15cc01d50b89c6c3658dafb Mon Sep 17 00:00:00 2001 From: James MacWhyte Date: Sat, 24 Jun 2023 18:22:11 +0200 Subject: [PATCH 019/242] move json schema to directory/package (#407) * move json schema to directory/package * added jsonschema to README --- README.md | 60 ++++++++++++++++++++++++++++++++++++++++++++++ chat.go | 39 ++++-------------------------- chat_test.go | 39 +++++++++++++++--------------- jsonschema/json.go | 35 +++++++++++++++++++++++++++ 4 files changed, 119 insertions(+), 54 deletions(-) create mode 100644 jsonschema/json.go diff --git a/README.md b/README.md index ef1db98cc..da1a2804d 100644 --- a/README.md +++ b/README.md @@ -516,6 +516,66 @@ func main() { ``` +
+JSON Schema for function calling + +It is now possible for chat completion to choose to call a function for more information ([see developer docs here](https://platform.openai.com/docs/guides/gpt/function-calling)). + +In order to describe the type of functions that can be called, a JSON schema must be provided. Many JSON schema libraries exist and are more advanced than what we can offer in this library, however we have included a simple `jsonschema` package for those who want to use this feature without formatting their own JSON schema payload. + +The developer documents give this JSON schema definition as an example: + +```json +{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and state, e.g. San Francisco, CA" + }, + "unit":{ + "type":"string", + "enum":[ + "celsius", + "fahrenheit" + ] + } + }, + "required":[ + "location" + ] + } +} +``` + +Using the `jsonschema` package, this schema could be created using structs as such: + +```go +FunctionDefinition{ + Name: "get_current_weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celcius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, +} +``` + +The `Parameters` field of a `FunctionDefinition` can accept either of the above styles, or even a nested struct from another library (as long as it can be marshalled into JSON). +
+
Error handling diff --git a/chat.go b/chat.go index b74720d38..e4f23df07 100644 --- a/chat.go +++ b/chat.go @@ -62,47 +62,16 @@ type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` // Parameters is an object describing the function. - // You can pass a raw byte array describing the schema, - // or you can pass in a struct which serializes to the proper JSONSchema. - // The JSONSchemaDefinition struct is provided for convenience, but you should - // consider another specialized library for more complex schemas. + // You can pass a []byte describing the schema, + // or you can pass in a struct which serializes to the proper JSON schema. + // The jsonschema package is provided for convenience, but you should + // consider another specialized library if you require more complex schemas. Parameters any `json:"parameters"` } // Deprecated: use FunctionDefinition instead. type FunctionDefine = FunctionDefinition -type JSONSchemaType string - -const ( - JSONSchemaTypeObject JSONSchemaType = "object" - JSONSchemaTypeNumber JSONSchemaType = "number" - JSONSchemaTypeString JSONSchemaType = "string" - JSONSchemaTypeArray JSONSchemaType = "array" - JSONSchemaTypeNull JSONSchemaType = "null" - JSONSchemaTypeBoolean JSONSchemaType = "boolean" -) - -// JSONSchemaDefinition is a struct for JSON Schema. -// It is fairly limited and you may have better luck using a third-party library. -type JSONSchemaDefinition struct { - // Type is a type of JSON Schema. - Type JSONSchemaType `json:"type,omitempty"` - // Description is a description of JSON Schema. - Description string `json:"description,omitempty"` - // Enum is a enum of JSON Schema. It used if Type is JSONSchemaTypeString. - Enum []string `json:"enum,omitempty"` - // Properties is a properties of JSON Schema. It used if Type is JSONSchemaTypeObject. - Properties map[string]JSONSchemaDefinition `json:"properties,omitempty"` - // Required is a required of JSON Schema. It used if Type is JSONSchemaTypeObject. - Required []string `json:"required,omitempty"` - // Items is a property of JSON Schema. It used if Type is JSONSchemaTypeArray. - Items *JSONSchemaDefinition `json:"items,omitempty"` -} - -// Deprecated: use JSONSchemaDefinition instead. -type JSONSchemaDefine = JSONSchemaDefinition - type FinishReason string const ( diff --git a/chat_test.go b/chat_test.go index 3c759b310..d5879e60f 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -13,6 +10,10 @@ import ( "strings" "testing" "time" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestChatCompletionsWrongModel(t *testing.T) { @@ -128,22 +129,22 @@ func TestChatCompletionsFunctions(t *testing.T) { }, Functions: []FunctionDefinition{{ Name: "test", - Parameters: &JSONSchemaDefinition{ - Type: JSONSchemaTypeObject, - Properties: map[string]JSONSchemaDefinition{ + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "count": { - Type: JSONSchemaTypeNumber, + Type: jsonschema.Number, Description: "total number of words in sentence", }, "words": { - Type: JSONSchemaTypeArray, + Type: jsonschema.Array, Description: "list of words in sentence", - Items: &JSONSchemaDefinition{ - Type: JSONSchemaTypeString, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, }, "enumTest": { - Type: JSONSchemaTypeString, + Type: jsonschema.String, Enum: []string{"hello", "world"}, }, }, @@ -165,22 +166,22 @@ func TestChatCompletionsFunctions(t *testing.T) { }, Functions: []FunctionDefine{{ Name: "test", - Parameters: &JSONSchemaDefine{ - Type: JSONSchemaTypeObject, - Properties: map[string]JSONSchemaDefine{ + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "count": { - Type: JSONSchemaTypeNumber, + Type: jsonschema.Number, Description: "total number of words in sentence", }, "words": { - Type: JSONSchemaTypeArray, + Type: jsonschema.Array, Description: "list of words in sentence", - Items: &JSONSchemaDefine{ - Type: JSONSchemaTypeString, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, }, "enumTest": { - Type: JSONSchemaTypeString, + Type: jsonschema.String, Enum: []string{"hello", "world"}, }, }, diff --git a/jsonschema/json.go b/jsonschema/json.go new file mode 100644 index 000000000..24af8584e --- /dev/null +++ b/jsonschema/json.go @@ -0,0 +1,35 @@ +// Package jsonschema provides very simple functionality for representing a JSON schema as a +// (nested) struct. This struct can be used with the chat completion "function call" feature. +// For more complicated schemas, it is recommended to use a dedicated JSON schema library +// and/or pass in the schema in []byte format. +package jsonschema + +type DataType string + +const ( + Object DataType = "object" + Number DataType = "number" + Integer DataType = "integer" + String DataType = "string" + Array DataType = "array" + Null DataType = "null" + Boolean DataType = "boolean" +) + +// Definition is a struct for describing a JSON Schema. +// It is fairly limited and you may have better luck using a third-party library. +type Definition struct { + // Type specifies the data type of the schema. + Type DataType `json:"type,omitempty"` + // Description is the description of the schema. + Description string `json:"description,omitempty"` + // Enum is used to restrict a value to a fixed set of values. It must be an array with at least + // one element, where each element is unique. You will probably only use this with strings. + Enum []string `json:"enum,omitempty"` + // Properties describes the properties of an object, if the schema type is Object. + Properties map[string]Definition `json:"properties,omitempty"` + // Required specifies which properties are required, if the schema type is Object. + Required []string `json:"required,omitempty"` + // Items specifies which data type an array contains, if the schema type is Array. + Items *Definition `json:"items,omitempty"` +} From a3c0b36b35dac5168c9ef07dacb4c1ad55efc51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 26 Jun 2023 23:32:57 +0900 Subject: [PATCH 020/242] chore: add an issue template for feature request (#410) --- .github/ISSUE_TEMPLATE/feature_request.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..2359e5c00 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,23 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +Your issue may already be reported! +Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. From 581f70b102d7443aeea4f19cf04d570150e8cc42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 26 Jun 2023 23:33:32 +0900 Subject: [PATCH 021/242] chore: add an issue template for bug report (#408) --- .github/ISSUE_TEMPLATE/bug_report.md | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..536a2ee29 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,32 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +Your issue may already be reported! +Please search on the [issue tracker](https://github.com/sashabaranov/go-openai/issues) before creating one. + +**Describe the bug** +A clear and concise description of what the bug is. If it's an API-related bug, please provide relevant endpoint(s). + +**To Reproduce** +Steps to reproduce the behavior, including any relevant code snippets. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots/Logs** +If applicable, add screenshots to help explain your problem. For non-graphical issues, please provide any relevant logs or stack traces. + +**Environment (please complete the following information):** + - go-openai version: [e.g. v1.12.0] + - Go version: [e.g. 1.18] + - OpenAI API version: [e.g. v1] + - OS: [e.g. Ubuntu 20.04] + +**Additional context** +Add any other context about the problem here. From 86d0f48d2ddd88fed5ed4036ab32218e90d2ee4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 29 Jun 2023 02:18:34 +0900 Subject: [PATCH 022/242] chore: add a pull request template (#412) --- .../pull_request_template.md | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/pull_request_template.md diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md new file mode 100644 index 000000000..b078d1964 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md @@ -0,0 +1,27 @@ +--- +name: Pull Request +about: Propose changes to the codebase +title: '' +labels: '' +assignees: '' + +--- + +A similar PR may already be submitted! +Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. + +Thanks for submitting a pull request! Please provide enough information so that others can review your pull request. + +**Describe the change** +Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds. + +**Describe your solution** +Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using. + +**Tests** +Briefly describe how you have tested these changes. + +**Additional context** +Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it. + +Issue: #XXXX From 9c99f3626f1d80382e187df8adc38f7e8e929a75 Mon Sep 17 00:00:00 2001 From: ryomak <21288308+ryomak@users.noreply.github.com> Date: Thu, 29 Jun 2023 09:41:22 +0900 Subject: [PATCH 023/242] replace deprecated FunctionDefine in chat_test.go (#416) --- chat_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chat_test.go b/chat_test.go index d5879e60f..5723d6ccf 100644 --- a/chat_test.go +++ b/chat_test.go @@ -85,7 +85,7 @@ func TestChatCompletionsFunctions(t *testing.T) { Content: "Hello!", }, }, - Functions: []FunctionDefine{{ + Functions: []FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -117,7 +117,7 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) - t.Run("JSONSchemaDefine", func(t *testing.T) { + t.Run("JSONSchemaDefinition", func(t *testing.T) { _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ MaxTokens: 5, Model: GPT3Dot5Turbo0613, @@ -153,7 +153,7 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) - t.Run("JSONSchemaDefineWithFunctionDefine", func(t *testing.T) { + t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { // this is a compatibility check _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ MaxTokens: 5, From 1efcf2d23de7866701bce946c65f271fff2f05e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Fri, 30 Jun 2023 19:49:36 +0900 Subject: [PATCH 024/242] fix: move pull request template (#420) --- .../pull_request_template.md => PULL_REQUEST_TEMPLATE.md} | 2 ++ 1 file changed, 2 insertions(+) rename .github/{PULL_REQUEST_TEMPLATE/pull_request_template.md => PULL_REQUEST_TEMPLATE.md} (79%) diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE.md similarity index 79% rename from .github/PULL_REQUEST_TEMPLATE/pull_request_template.md rename to .github/PULL_REQUEST_TEMPLATE.md index b078d1964..f7e45401b 100644 --- a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -10,6 +10,8 @@ assignees: '' A similar PR may already be submitted! Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. +If your changes introduce breaking changes, please prefix the title of your pull request with "[BREAKING_CHANGES]". This allows for clear identification of such changes in the 'What's Changed' section on the release page, making it developer-friendly. + Thanks for submitting a pull request! Please provide enough information so that others can review your pull request. **Describe the change** From 177c143be7c373a9b5d33e6c7c64ad3e6670a32c Mon Sep 17 00:00:00 2001 From: Rick Date: Sat, 1 Jul 2023 06:38:22 +0800 Subject: [PATCH 025/242] Fix OpenAI error when properties is empty in function call : object schema missing properties (#419) Co-authored-by: Rick --- jsonschema/json.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index 24af8584e..c02d250aa 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -27,7 +27,7 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. - Properties map[string]Definition `json:"properties,omitempty"` + Properties map[string]Definition `json:"properties"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. From 204260818e9987c4a84f520d63d9c8758c986cbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 3 Jul 2023 19:46:38 +0900 Subject: [PATCH 026/242] docs: remove medatada in PULL_REQUEST_TEMPLATE.md (#423) --- .github/PULL_REQUEST_TEMPLATE.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index f7e45401b..44bf697ed 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,12 +1,3 @@ ---- -name: Pull Request -about: Propose changes to the codebase -title: '' -labels: '' -assignees: '' - ---- - A similar PR may already be submitted! Please search among the [Pull request](https://github.com/sashabaranov/go-openai/pulls) before creating one. From 5c7d88212f6e73fdac89723d42b9e3a1b113931c Mon Sep 17 00:00:00 2001 From: Jackson Stone Date: Wed, 5 Jul 2023 16:53:53 -0500 Subject: [PATCH 027/242] Allow embeddings requests to be tokens or strings (#417) * Allow raw tokens to be used as embedding input * fix linting issues (lines too long) * add endpoint test for embedding from tokens * remove redundant comments * fix comment to match new param name * change interface to any * Rename methods and implement convert for base req * add comments to CreateEmbeddings * update tests * shorten line length * rename parameter --- embeddings.go | 62 +++++++++++++++++++++++++++++++++++++++++----- embeddings_test.go | 38 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 6 deletions(-) diff --git a/embeddings.go b/embeddings.go index ba327ce77..41af50b4b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -113,10 +113,25 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } -// EmbeddingRequest is the input to a Create embeddings request. +type EmbeddingRequestConverter interface { + // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens + Convert() EmbeddingRequest +} + type EmbeddingRequest struct { + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user"` +} + +func (r EmbeddingRequest) Convert() EmbeddingRequest { + return r +} + +// EmbeddingRequestStrings is the input to a create embeddings request with a slice of strings. +type EmbeddingRequestStrings struct { // Input is a slice of strings for which you want to generate an Embedding vector. - // Each input must not exceed 2048 tokens in length. + // Each input must not exceed 8192 tokens in length. // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they // have observed inferior results when newlines are present. // E.g. @@ -129,15 +144,50 @@ type EmbeddingRequest struct { User string `json:"user"` } -// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. +func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { + return EmbeddingRequest{ + Input: r.Input, + Model: r.Model, + User: r.User, + } +} + +type EmbeddingRequestTokens struct { + // Input is a slice of slices of ints ([][]int) for which you want to generate an Embedding vector. + // Each input must not exceed 8192 tokens in length. + // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they + // have observed inferior results when newlines are present. + // E.g. + // "The food was delicious and the waiter..." + Input [][]int `json:"input"` + // ID of the model to use. You can use the List models API to see all of your available models, + // or see our Model overview for descriptions of them. + Model EmbeddingModel `json:"model"` + // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. + User string `json:"user"` +} + +func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { + return EmbeddingRequest{ + Input: r.Input, + Model: r.Model, + User: r.User, + } +} + +// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |body.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create -func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", request.Model.String()), withBody(request)) +// +// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens +// for embedding groups of text already converted to tokens. +func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll + baseReq := conv.Convert() + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(req, &resp) + err = c.sendRequest(req, &res) return } diff --git a/embeddings_test.go b/embeddings_test.go index d7892cd5d..47c4f5108 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -32,6 +32,7 @@ func TestEmbedding(t *testing.T) { BabbageCodeSearchText, } for _, model := range embeddedModels { + // test embedding request with strings (simple embedding request) embeddingReq := EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", @@ -46,6 +47,34 @@ func TestEmbedding(t *testing.T) { if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { t.Fatalf("Expected embedding request to contain model field") } + + // test embedding request with strings + embeddingReqStrings := EmbeddingRequestStrings{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqStrings) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + + // test embedding request with tokens + embeddingReqTokens := EmbeddingRequestTokens{ + Input: [][]int{ + {464, 2057, 373, 12625, 290, 262, 46612}, + {6395, 6096, 286, 11525, 12083, 2581}, + }, + Model: model, + } + marshaled, err = json.Marshal(embeddingReqTokens) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } } } @@ -75,6 +104,15 @@ func TestEmbeddingEndpoint(t *testing.T) { fmt.Fprintln(w, string(resBytes)) }, ) + // test create embeddings with strings (simple embedding request) _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") + + // test create embeddings with strings + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + checks.NoError(t, err, "CreateEmbeddings strings error") + + // test create embeddings with tokens + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + checks.NoError(t, err, "CreateEmbeddings tokens error") } From 619ad717353d8b9d5f4d9049b1ce1b168c5851b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Thu, 6 Jul 2023 06:54:27 +0900 Subject: [PATCH 028/242] docs: added instructions for obtaining OpenAI API key to README (#421) * docs: added instructions for obtaining OpenAI API key to README * docs: move 'Getting an OpenAI API key' before 'Other examples' --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index da1a2804d..1f708af70 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,17 @@ func main() { ``` +### Getting an OpenAI API Key: + +1. Visit the OpenAI website at [https://platform.openai.com/account/api-keys](https://platform.openai.com/account/api-keys). +2. If you don't have an account, click on "Sign Up" to create one. If you do, click "Log In". +3. Once logged in, navigate to your API key management page. +4. Click on "Create new secret key". +5. Enter a name for your new key, then click "Create secret key". +6. Your new API key will be displayed. Use this key to interact with the OpenAI API. + +**Note:** Your API key is sensitive information. Do not share it with anyone. + ### Other examples:
From 7b22898f5d3fd86232057ed61e83adb47bf24cb0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Sun, 9 Jul 2023 17:09:50 +0800 Subject: [PATCH 029/242] Implement OpenAI July 2023 Updates (#427) * Implement OpenAI July 2023 Updates * fix: golangci-lint * add comment * fix: remove some model Deprecated --- completion.go | 55 +++++++++++++++++++++++++++++++-------------------- edits.go | 6 +++++- embeddings.go | 16 +++++++++++++++ 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/completion.go b/completion.go index b3b3abd1c..61bfed654 100644 --- a/completion.go +++ b/completion.go @@ -17,29 +17,42 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( - GPT432K0613 = "gpt-4-32k-0613" - GPT432K0314 = "gpt-4-32k-0314" - GPT432K = "gpt-4-32k" - GPT40613 = "gpt-4-0613" - GPT40314 = "gpt-4-0314" - GPT4 = "gpt-4" - GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" - GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" - GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" - GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3TextDavinci003 = "text-davinci-003" - GPT3TextDavinci002 = "text-davinci-002" - GPT3TextCurie001 = "text-curie-001" - GPT3TextBabbage001 = "text-babbage-001" - GPT3TextAda001 = "text-ada-001" - GPT3TextDavinci001 = "text-davinci-001" + GPT432K0613 = "gpt-4-32k-0613" + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" + GPT40314 = "gpt-4-0314" + GPT4 = "gpt-4" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci003 = "text-davinci-003" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci002 = "text-davinci-002" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextCurie001 = "text-curie-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextBabbage001 = "text-babbage-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextAda001 = "text-ada-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3TextDavinci001 = "text-davinci-001" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" GPT3Davinci = "davinci" - GPT3CurieInstructBeta = "curie-instruct-beta" - GPT3Curie = "curie" - GPT3Ada = "ada" - GPT3Babbage = "babbage" + GPT3Davinci002 = "davinci-002" + // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + GPT3CurieInstructBeta = "curie-instruct-beta" + GPT3Curie = "curie" + GPT3Curie002 = "curie-002" + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/edits.go b/edits.go index 3d3fc8950..831aade2f 100644 --- a/edits.go +++ b/edits.go @@ -30,7 +30,11 @@ type EditsResponse struct { Choices []EditsChoice `json:"choices"` } -// Perform an API call to the Edits endpoint. +// Edits Perform an API call to the Edits endpoint. +/* Deprecated: Users of the Edits API and its associated models (e.g., text-davinci-edit-001 or code-davinci-edit-001) +will need to migrate to GPT-3.5 Turbo by January 4, 2024. +You can use CreateChatCompletion or CreateChatCompletionStream instead. +*/ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) if err != nil { diff --git a/embeddings.go b/embeddings.go index 41af50b4b..1d3199597 100644 --- a/embeddings.go +++ b/embeddings.go @@ -34,21 +34,37 @@ func (e *EmbeddingModel) UnmarshalText(b []byte) error { const ( Unknown EmbeddingModel = iota + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaSimilarity + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageSimilarity + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. CurieSimilarity + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. DavinciSimilarity + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaSearchDocument + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaSearchQuery + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageSearchDocument + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageSearchQuery + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. CurieSearchDocument + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. CurieSearchQuery + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. DavinciSearchDocument + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. DavinciSearchQuery + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaCodeSearchCode + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. AdaCodeSearchText + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageCodeSearchCode + // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. BabbageCodeSearchText AdaEmbeddingV2 ) From 181fc2ade904c7d6a0910cfebdd6c90e7a4d80ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Sun, 9 Jul 2023 18:11:39 +0900 Subject: [PATCH 030/242] docs: explanation about LogitBias. (129) (#426) --- chat.go | 11 +++++++---- completion.go | 35 +++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/chat.go b/chat.go index e4f23df07..17d7cd574 100644 --- a/chat.go +++ b/chat.go @@ -52,10 +52,13 @@ type ChatCompletionRequest struct { Stop []string `json:"stop,omitempty"` PresencePenalty float32 `json:"presence_penalty,omitempty"` FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` - Functions []FunctionDefinition `json:"functions,omitempty"` - FunctionCall any `json:"function_call,omitempty"` + // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. + // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` + Functions []FunctionDefinition `json:"functions,omitempty"` + FunctionCall any `json:"function_call,omitempty"` } type FunctionDefinition struct { diff --git a/completion.go b/completion.go index 61bfed654..7b9ae89e7 100644 --- a/completion.go +++ b/completion.go @@ -109,22 +109,25 @@ func checkPromptType(prompt any) bool { // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { - Model string `json:"model"` - Prompt any `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model"` + Prompt any `json:"prompt,omitempty"` + Suffix string `json:"suffix,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + Echo bool `json:"echo,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + BestOf int `json:"best_of,omitempty"` + // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. + // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` + // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias + LogitBias map[string]int `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. From f028c289d2e2ae7562d97594d122447fd23a632d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 10 Jul 2023 02:07:01 +0900 Subject: [PATCH 031/242] fix: function call error due to nil properties (429) (#431) * fix: fix function call error due to nil properties (429) * refactor: refactoring initializeProperties func in jsonschema pkg (429) --- jsonschema/json.go | 25 ++++- jsonschema/json_test.go | 201 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 jsonschema/json_test.go diff --git a/jsonschema/json.go b/jsonschema/json.go index c02d250aa..e4eef98e7 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,6 +4,8 @@ // and/or pass in the schema in []byte format. package jsonschema +import "encoding/json" + type DataType string const ( @@ -17,7 +19,7 @@ const ( ) // Definition is a struct for describing a JSON Schema. -// It is fairly limited and you may have better luck using a third-party library. +// It is fairly limited, and you may have better luck using a third-party library. type Definition struct { // Type specifies the data type of the schema. Type DataType `json:"type,omitempty"` @@ -33,3 +35,24 @@ type Definition struct { // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` } + +func (d *Definition) MarshalJSON() ([]byte, error) { + d.initializeProperties() + return json.Marshal(*d) +} + +func (d *Definition) initializeProperties() { + if d.Properties == nil { + d.Properties = make(map[string]Definition) + return + } + + for k, v := range d.Properties { + if v.Properties == nil { + v.Properties = make(map[string]Definition) + } else { + v.initializeProperties() + } + d.Properties[k] = v + } +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go new file mode 100644 index 000000000..0dc31a58a --- /dev/null +++ b/jsonschema/json_test.go @@ -0,0 +1,201 @@ +package jsonschema_test + +import ( + "encoding/json" + "reflect" + "testing" + + . "github.com/sashabaranov/go-openai/jsonschema" +) + +func TestDefinition_MarshalJSON(t *testing.T) { + tests := []struct { + name string + def Definition + want string + }{ + { + name: "Test with empty Definition", + def: Definition{}, + want: `{"properties":{}}`, + }, + { + name: "Test with Definition properties set", + def: Definition{ + Type: String, + Description: "A string type", + Properties: map[string]Definition{ + "name": { + Type: String, + }, + }, + }, + want: `{ + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string", + "properties":{} + } + } +}`, + }, + { + name: "Test with nested Definition properties", + def: Definition{ + Type: Object, + Properties: map[string]Definition{ + "user": { + Type: Object, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + "age": { + Type: Integer, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "properties":{} + }, + "age":{ + "type":"integer", + "properties":{} + } + } + } + } +}`, + }, + { + name: "Test with complex nested Definition", + def: Definition{ + Type: Object, + Properties: map[string]Definition{ + "user": { + Type: Object, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + "age": { + Type: Integer, + }, + "address": { + Type: Object, + Properties: map[string]Definition{ + "city": { + Type: String, + }, + "country": { + Type: String, + }, + }, + }, + }, + }, + }, + }, + want: `{ + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "properties":{} + }, + "age":{ + "type":"integer", + "properties":{} + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string", + "properties":{} + }, + "country":{ + "type":"string", + "properties":{} + } + } + } + } + } + } +}`, + }, + { + name: "Test with Array type Definition", + def: Definition{ + Type: Array, + Items: &Definition{ + Type: String, + }, + Properties: map[string]Definition{ + "name": { + Type: String, + }, + }, + }, + want: `{ + "type":"array", + "items":{ + "type":"string", + "properties":{ + + } + }, + "properties":{ + "name":{ + "type":"string", + "properties":{} + } + } +}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBytes, err := json.Marshal(&tt.def) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return + } + + var got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + wantBytes := []byte(tt.want) + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + }) + } +} From c3b2451f7c7dc477d98e1baa10993ac55392c7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Tue, 11 Jul 2023 20:48:15 +0900 Subject: [PATCH 032/242] fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) (#434) * fix: invalid schema for function 'func_name': None is not of type 'object' (#429)(#432) * test: add integration test for function call (#429)(#432) * style: remove duplicate import (#429)(#432) --- api_integration_test.go | 32 ++++++++++++++++++++++++++++++++ jsonschema/json.go | 23 +++++++---------------- jsonschema/json_test.go | 38 ++++++++++++++++++++++++-------------- 3 files changed, 63 insertions(+), 30 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index d4e7328a2..254fbeb03 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -11,6 +11,7 @@ import ( . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestAPI(t *testing.T) { @@ -100,6 +101,37 @@ func TestAPI(t *testing.T) { if counter == 0 { t.Error("Stream did not return any responses") } + + _, err = c.CreateChatCompletion( + context.Background(), + ChatCompletionRequest{ + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "What is the weather like in Boston?", + }, + }, + Functions: []FunctionDefinition{{ + Name: "get_current_weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }}, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } func TestAPIError(t *testing.T) { diff --git a/jsonschema/json.go b/jsonschema/json.go index e4eef98e7..cb941eb75 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -36,23 +36,14 @@ type Definition struct { Items *Definition `json:"items,omitempty"` } -func (d *Definition) MarshalJSON() ([]byte, error) { - d.initializeProperties() - return json.Marshal(*d) -} - -func (d *Definition) initializeProperties() { +func (d Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) - return - } - - for k, v := range d.Properties { - if v.Properties == nil { - v.Properties = make(map[string]Definition) - } else { - v.initializeProperties() - } - d.Properties[k] = v } + type Alias Definition + return json.Marshal(struct { + Alias + }{ + Alias: (Alias)(d), + }) } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 0dc31a58a..c8d0c1d9e 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -172,30 +172,40 @@ func TestDefinition_MarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotBytes, err := json.Marshal(&tt.def) - if err != nil { - t.Errorf("Failed to Marshal JSON: error = %v", err) - return - } - - var got map[string]interface{} - err = json.Unmarshal(gotBytes, &got) - if err != nil { - t.Errorf("Failed to Unmarshal JSON: error = %v", err) - return - } - wantBytes := []byte(tt.want) var want map[string]interface{} - err = json.Unmarshal(wantBytes, &want) + err := json.Unmarshal(wantBytes, &want) if err != nil { t.Errorf("Failed to Unmarshal JSON: error = %v", err) return } + got := structToMap(t, tt.def) + gotPtr := structToMap(t, &tt.def) + if !reflect.DeepEqual(got, want) { t.Errorf("MarshalJSON() got = %v, want %v", got, want) } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } }) } } + +func structToMap(t *testing.T, v any) map[string]any { + t.Helper() + gotBytes, err := json.Marshal(v) + if err != nil { + t.Errorf("Failed to Marshal JSON: error = %v", err) + return nil + } + + var got map[string]interface{} + err = json.Unmarshal(gotBytes, &got) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return nil + } + return got +} From 39b2acb5c93c3ee12020cda8d1a5cc0cf2bea1a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 12 Jul 2023 23:15:39 +0900 Subject: [PATCH 033/242] ci: set up closing-inactive-issues in GitHub Action (129) (#428) --- .github/workflows/close-inactive-issues.yml | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/close-inactive-issues.yml diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml new file mode 100644 index 000000000..bfe9b5c96 --- /dev/null +++ b/.github/workflows/close-inactive-issues.yml @@ -0,0 +1,23 @@ +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" + +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - uses: actions/stale@v5 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + exempt-issue-labels: 'bug,enhancement' + stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." + close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + days-before-pr-stale: -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file From e22a29d84ebb8c5c911937669f27ac3265f3c982 Mon Sep 17 00:00:00 2001 From: Munar <118156704+MunaerYesiyan@users.noreply.github.com> Date: Thu, 13 Jul 2023 13:30:58 +0900 Subject: [PATCH 034/242] Check if the model param is valid for moderations endpoint (#437) * chore: check for models before sending moderation requets to openai endpoint * chore: table driven tests to include more model cases for moderations endpoint --- moderation.go | 17 ++++++++++++++++- moderation_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/moderation.go b/moderation.go index a58d759c0..a32f123f3 100644 --- a/moderation.go +++ b/moderation.go @@ -2,6 +2,7 @@ package openai import ( "context" + "errors" "net/http" ) @@ -15,9 +16,19 @@ import ( const ( ModerationTextStable = "text-moderation-stable" ModerationTextLatest = "text-moderation-latest" - ModerationText001 = "text-moderation-001" + // Deprecated: use ModerationTextStable and ModerationTextLatest instead. + ModerationText001 = "text-moderation-001" ) +var ( + ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll +) + +var validModerationModel = map[string]struct{}{ + ModerationTextStable: {}, + ModerationTextLatest: {}, +} + // ModerationRequest represents a request structure for moderation API. type ModerationRequest struct { Input string `json:"input,omitempty"` @@ -63,6 +74,10 @@ type ModerationResponse struct { // Moderations — perform a moderation api call over a string. // Input can be an array or slice but a string will reduce the complexity. func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { + if _, ok := validModerationModel[request.Model]; len(request.Model) > 0 && !ok { + err = ErrModerationInvalidModel + return + } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) if err != nil { return diff --git a/moderation_test.go b/moderation_test.go index 4e756137e..68f9565e1 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -27,6 +27,41 @@ func TestModerations(t *testing.T) { checks.NoError(t, err, "Moderation error") } +// TestModerationsWithIncorrectModel Tests passing valid and invalid models to moderations endpoint. +func TestModerationsWithDifferentModelOptions(t *testing.T) { + var modelOptions []struct { + model string + expect error + } + modelOptions = append(modelOptions, + getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), + getModerationModelTestOption(ModerationTextStable, nil), + getModerationModelTestOption(ModerationTextLatest, nil), + getModerationModelTestOption("", nil), + ) + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/moderations", handleModerationEndpoint) + for _, modelTest := range modelOptions { + _, err := client.Moderations(context.Background(), ModerationRequest{ + Model: modelTest.model, + Input: "I want to kill them.", + }) + checks.ErrorIs(t, err, modelTest.expect, + fmt.Sprintf("Moderations(..) expects err: %v, actual err:%v", modelTest.expect, err)) + } +} + +func getModerationModelTestOption(model string, expect error) struct { + model string + expect error +} { + return struct { + model string + expect error + }{model: model, expect: expect} +} + // handleModerationEndpoint Handles the moderation endpoint by the test server. func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { var err error From 0234c1e0c2769c9599f0799259fb8db5c4e3e011 Mon Sep 17 00:00:00 2001 From: Mehul Gohil Date: Sat, 15 Jul 2023 03:43:05 +0530 Subject: [PATCH 035/242] add example: fine tune (#438) * add example for fine tune * update example for fine tune * fix comments --- README.md | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/README.md b/README.md index 1f708af70..19aadde2a 100644 --- a/README.md +++ b/README.md @@ -611,6 +611,73 @@ if errors.As(err, &e) { ```
+
+Fine Tune Model + +```go +package main + +import ( + "context" + "fmt" + "github.com/sashabaranov/go-openai" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + // create a .jsonl file with your training data + // {"prompt": "", "completion": ""} + // {"prompt": "", "completion": ""} + // {"prompt": "", "completion": ""} + + // you can use openai cli tool to validate the data + // For more info - https://platform.openai.com/docs/guides/fine-tuning + + file, err := client.CreateFile(ctx, openai.FileRequest{ + FilePath: "training_prepared.jsonl", + Purpose: "fine-tune", + }) + if err != nil { + fmt.Printf("Upload JSONL file error: %v\n", err) + return + } + + // create a fine tune job + // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) + // use below get method to know the status of your model + tune, err := client.CreateFineTune(ctx, openai.FineTuneRequest{ + TrainingFile: file.ID, + Model: "ada", // babbage, curie, davinci, or a fine-tuned model created after 2022-04-21. + }) + if err != nil { + fmt.Printf("Creating new fine tune model error: %v\n", err) + return + } + + getTune, err := client.GetFineTune(ctx, tune.ID) + if err != nil { + fmt.Printf("Getting fine tune model error: %v\n", err) + return + } + fmt.Println(getTune.FineTunedModel) + + // once the status of getTune is `succeeded`, you can use your fine tune model in Completion Request + + // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ + // Model: getTune.FineTunedModel, + // Prompt: "your prompt", + // }) + // if err != nil { + // fmt.Printf("Create completion error %v\n", err) + // return + // } + // + // fmt.Println(resp.Choices[0].Text) +} +``` +
See the `examples/` folder for more. ### Integration tests: From 1876e0c20716afc4d012688bc393ccd5f28def79 Mon Sep 17 00:00:00 2001 From: Savannah Ostrowski Date: Fri, 14 Jul 2023 21:33:55 -0700 Subject: [PATCH 036/242] update to json.RawMessage (#441) --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 17d7cd574..7a6438e7f 100644 --- a/chat.go +++ b/chat.go @@ -65,7 +65,7 @@ type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` // Parameters is an object describing the function. - // You can pass a []byte describing the schema, + // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. // The jsonschema package is provided for convenience, but you should // consider another specialized library if you require more complex schemas. From 1153eb2595d1529927757dd6df4de71faaafde02 Mon Sep 17 00:00:00 2001 From: ZeroDeng Date: Fri, 21 Jul 2023 00:25:58 +0800 Subject: [PATCH 037/242] Add support for azure openai new version API (2023-07-01-preview) (#451) --- chat.go | 29 +++++++++++++++++++++++++++++ chat_stream.go | 18 ++++++++++-------- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/chat.go b/chat.go index 7a6438e7f..514aaee75 100644 --- a/chat.go +++ b/chat.go @@ -21,6 +21,35 @@ var ( ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll ) +type Hate struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type SelfHarm struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type Sexual struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} +type Violence struct { + Filtered bool `json:"filtered"` + Severity string `json:"severity,omitempty"` +} + +type ContentFilterResults struct { + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` +} + +type PromptAnnotation struct { + PromptIndex int `json:"prompt_index,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` diff --git a/chat_stream.go b/chat_stream.go index 9f4e80cff..f1faa3964 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -12,17 +12,19 @@ type ChatCompletionStreamChoiceDelta struct { } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` } // ChatCompletionStream From 62dc817b395d16fb0e65be490d49e294ac8c40b0 Mon Sep 17 00:00:00 2001 From: Yu <1095780+yuikns@users.noreply.github.com> Date: Fri, 28 Jul 2023 12:06:48 +0800 Subject: [PATCH 038/242] feat: make finish reason nullable in json marshal (#449) --- chat.go | 7 +++++++ chat_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/chat.go b/chat.go index 514aaee75..8d29b3237 100644 --- a/chat.go +++ b/chat.go @@ -114,6 +114,13 @@ const ( FinishReasonNull FinishReason = "null" ) +func (r FinishReason) MarshalJSON() ([]byte, error) { + if r == FinishReasonNull || r == "" { + return []byte("null"), nil + } + return []byte(`"` + string(r) + `"`), nil // best effort to not break future API changes +} + type ChatCompletionChoice struct { Index int `json:"index"` Message ChatCompletionMessage `json:"message"` diff --git a/chat_test.go b/chat_test.go index 5723d6ccf..38d66fa64 100644 --- a/chat_test.go +++ b/chat_test.go @@ -298,3 +298,34 @@ func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { } return completion, nil } + +func TestFinishReason(t *testing.T) { + c := &ChatCompletionChoice{ + FinishReason: FinishReasonNull, + } + resBytes, _ := json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + c.FinishReason = "" + + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), `"finish_reason":null`) { + t.Error("null should not be quoted") + } + + otherReasons := []FinishReason{ + FinishReasonStop, + FinishReasonLength, + FinishReasonFunctionCall, + FinishReasonContentFilter, + } + for _, r := range otherReasons { + c.FinishReason = r + resBytes, _ = json.Marshal(c) + if !strings.Contains(string(resBytes), fmt.Sprintf(`"finish_reason":"%s"`, r)) { + t.Errorf("%s should be quoted", r) + } + } +} From 71a24931dbc5b7029901ff963dc4d0d2509aa7ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Mon, 31 Jul 2023 04:58:49 +0900 Subject: [PATCH 039/242] docs: add Frequently Asked Questions to README.md (#462) * docs: add Frequently Asked Questions to README.md * Update README.md Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- README.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/README.md b/README.md index 19aadde2a..d627a19ce 100644 --- a/README.md +++ b/README.md @@ -694,6 +694,37 @@ OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. +## Frequently Asked Questions + +### Why don't we get the same answer when specifying a temperature field of 0 and asking the same question? + +Even when specifying a temperature field of 0, it doesn't guarantee that you'll always get the same response. Several factors come into play. + +1. Go OpenAI Behavior: When you specify a temperature field of 0 in Go OpenAI, the omitempty tag causes that field to be removed from the request. Consequently, the OpenAI API applies the default value of 1. +2. Token Count for Input/Output: If there's a large number of tokens in the input and output, setting the temperature to 0 can still result in non-deterministic behavior. In particular, when using around 32k tokens, the likelihood of non-deterministic behavior becomes highest even with a temperature of 0. + +Due to the factors mentioned above, different answers may be returned even for the same question. + +**Workarounds:** +1. Using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +2. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. + +By adopting these strategies, you can expect more consistent results. + +**Related Issues:** +[omitempty option of request struct will generate incorrect request when parameter is 0.](https://github.com/sashabaranov/go-openai/issues/9) + +### Does Go OpenAI provide a method to count tokens? + +No, Go OpenAI does not offer a feature to count tokens, and there are no plans to provide such a feature in the future. However, if there's a way to implement a token counting feature with zero dependencies, it might be possible to merge that feature into Go OpenAI. Otherwise, it would be more appropriate to implement it in a dedicated library or repository. + +For counting tokens, you might find the following links helpful: +- [Counting Tokens For Chat API Calls](https://github.com/pkoukk/tiktoken-go#counting-tokens-for-chat-api-calls) +- [How to count tokens with tiktoken](https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb) + +**Related Issues:** +[Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) + ## Thank you We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: From 34569895f6a0ab4a3ccc497573a8343ee33dc3b1 Mon Sep 17 00:00:00 2001 From: ZeroDeng Date: Wed, 9 Aug 2023 12:05:39 +0800 Subject: [PATCH 040/242] Compatible with the 2023-07-01-preview API interface of Azure Openai, when content interception is triggered, the error message will contain innererror (#460) * Compatible with Azure Openai's 2023-07-01-preview version API interface about the error information returned by the intercepted interface * Compatible with the 2023-07-01-preview API interface of Azure Openai, when content interception is triggered, the error message will contain innererror.InnerError struct is only valid for Azure OpenAI Service. --- error.go | 25 +++++++++++++---- error_test.go | 78 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 5 deletions(-) diff --git a/error.go b/error.go index f68e92875..b2d01e22e 100644 --- a/error.go +++ b/error.go @@ -7,12 +7,20 @@ import ( ) // APIError provides error information returned by the OpenAI API. +// InnerError struct is only valid for Azure OpenAI Service. type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - HTTPStatusCode int `json:"-"` + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` + HTTPStatusCode int `json:"-"` + InnerError *InnerError `json:"innererror,omitempty"` +} + +// InnerError Azure Content filtering. Only valid for Azure OpenAI Service. +type InnerError struct { + Code string `json:"code,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } // RequestError provides informations about generic request errors. @@ -61,6 +69,13 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } } + if _, ok := rawMap["innererror"]; ok { + err = json.Unmarshal(rawMap["innererror"], &e.InnerError) + if err != nil { + return + } + } + // optional fields if _, ok := rawMap["param"]; ok { err = json.Unmarshal(rawMap["param"], &e.Param) diff --git a/error_test.go b/error_test.go index e2309abd7..a0806b7ed 100644 --- a/error_test.go +++ b/error_test.go @@ -3,6 +3,7 @@ package openai_test import ( "errors" "net/http" + "reflect" "testing" . "github.com/sashabaranov/go-openai" @@ -57,6 +58,77 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { assertAPIErrorMessage(t, apiErr, "") }, }, + { + name: "parse succeeds when the innerError is not exists (Azure Openai)", + response: `{ + "message": "test message", + "type": null, + "param": "prompt", + "code": "content_filter", + "status": 400, + "innererror": { + "code": "ResponsibleAIPolicyViolation", + "content_filter_result": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": true, + "severity": "medium" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + } + } + }`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{ + Code: "ResponsibleAIPolicyViolation", + ContentFilterResults: ContentFilterResults{ + Hate: Hate{ + Filtered: false, + Severity: "safe", + }, + SelfHarm: SelfHarm{ + Filtered: false, + Severity: "safe", + }, + Sexual: Sexual{ + Filtered: true, + Severity: "medium", + }, + Violence: Violence{ + Filtered: false, + Severity: "safe", + }, + }, + }) + }, + }, + { + name: "parse succeeds when the innerError is empty (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, + hasError: false, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{}) + }, + }, + { + name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", + response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, + hasError: true, + checkFunc: func(t *testing.T, apiErr APIError) { + assertAPIErrorInnerError(t, apiErr, &InnerError{}) + }, + }, { name: "parse failed when the message is object", response: `{"message":{},"type":"invalid_request_error","param":null,"code":null}`, @@ -152,6 +224,12 @@ func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { } } +func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { + if !reflect.DeepEqual(apiErr.InnerError, expected) { + t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) + } +} + func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: From a14bc103f4bc2b3ac40c844079fdf59dfdf62b0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 9 Aug 2023 13:07:14 +0900 Subject: [PATCH 041/242] docs: Add Contributing Guidelines (#463) --- CONTRIBUTING.md | 88 +++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 24 +++++--------- 2 files changed, 97 insertions(+), 15 deletions(-) create mode 100644 CONTRIBUTING.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..4dd184042 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,88 @@ +# Contributing Guidelines + +## Overview +Thank you for your interest in contributing to the "Go OpenAI" project! By following this guideline, we hope to ensure that your contributions are made smoothly and efficiently. The Go OpenAI project is licensed under the [Apache 2.0 License](https://github.com/sashabaranov/go-openai/blob/master/LICENSE), and we welcome contributions through GitHub pull requests. + +## Reporting Bugs +If you discover a bug, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to see if the issue has already been reported. If you're reporting a new issue, please use the "Bug report" template and provide detailed information about the problem, including steps to reproduce it. + +## Suggesting Features +If you want to suggest a new feature or improvement, first check the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to ensure a similar suggestion hasn't already been made. Use the "Feature request" template to provide a detailed description of your suggestion. + +## Reporting Vulnerabilities +If you identify a security concern, please use the "Report a security vulnerability" template on the [GitHub Issues page](https://github.com/sashabaranov/go-openai/issues) to share the details. This report will only be viewable to repository maintainers. You will be credited if the advisory is published. + +## Questions for Users +If you have questions, please utilize [StackOverflow](https://stackoverflow.com/) or the [GitHub Discussions page](https://github.com/sashabaranov/go-openai/discussions). + +## Contributing Code +There might already be a similar pull requests submitted! Please search for [pull requests](https://github.com/sashabaranov/go-openai/pulls) before creating one. + +### Requirements for Merging a Pull Request + +The requirements to accept a pull request are as follows: + +- Features not provided by the OpenAI API will not be accepted. +- The functionality of the feature must match that of the official OpenAI API. +- All pull requests should be written in Go according to common conventions, formatted with `goimports`, and free of warnings from tools like `golangci-lint`. +- Include tests and ensure all tests pass. +- Maintain test coverage without any reduction. +- All pull requests require approval from at least one Go OpenAI maintainer. + +**Note:** +The merging method for pull requests in this repository is squash merge. + +### Creating a Pull Request +- Fork the repository. +- Create a new branch and commit your changes. +- Push that branch to GitHub. +- Start a new Pull Request on GitHub. (Please use the pull request template to provide detailed information.) + +**Note:** +If your changes introduce breaking changes, please prefix your pull request title with "[BREAKING_CHANGES]". + +### Code Style +In this project, we adhere to the standard coding style of Go. Your code should maintain consistency with the rest of the codebase. To achieve this, please format your code using tools like `goimports` and resolve any syntax or style issues with `golangci-lint`. + +**Run goimports:** +``` +go install golang.org/x/tools/cmd/goimports@latest +``` + +``` +goimports -w . +``` + +**Run golangci-lint:** +``` +go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest +``` + +``` +golangci-lint run --out-format=github-actions +``` + +### Unit Test +Please create or update tests relevant to your changes. Ensure all tests run successfully to verify that your modifications do not adversely affect other functionalities. + +**Run test:** +``` +go test -v ./... +``` + +### Integration Test +Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. + +**Notes:** +These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. + +**Run integration test:** +``` +OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go +``` + +If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. + +--- + +We wholeheartedly welcome your active participation. Let's build an amazing project together! diff --git a/README.md b/README.md index d627a19ce..9714d89fe 100644 --- a/README.md +++ b/README.md @@ -10,12 +10,16 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * DALL·E 2 * Whisper -### Installation: +## Installation + ``` go get github.com/sashabaranov/go-openai ``` Currently, go-openai requires Go version 1.18 or greater. + +## Usage + ### ChatGPT example usage: ```go @@ -680,20 +684,6 @@ func main() {
See the `examples/` folder for more. -### Integration tests: - -Integration tests are requested against the production version of the OpenAI API. These tests will verify that the library is properly coded against the actual behavior of the API, and will fail upon any incompatible change in the API. - -**Notes:** -These tests send real network traffic to the OpenAI API and may reach rate limits. Temporary network problems may also cause the test to fail. - -**Run tests using:** -``` -OPENAI_TOKEN=XXX go test -v -tags=integration ./api_integration_test.go -``` - -If the `OPENAI_TOKEN` environment variable is not available, integration tests will be skipped. - ## Frequently Asked Questions ### Why don't we get the same answer when specifying a temperature field of 0 and asking the same question? @@ -725,6 +715,10 @@ For counting tokens, you might find the following links helpful: **Related Issues:** [Is it possible to join the implementation of GPT3 Tokenizer](https://github.com/sashabaranov/go-openai/issues/62) +## Contributing + +By following [Contributing Guidelines](https://github.com/sashabaranov/go-openai/blob/master/CONTRIBUTING.md), we hope to ensure that your contributions are made smoothly and efficiently. + ## Thank you We want to take a moment to express our deepest gratitude to the [contributors](https://github.com/sashabaranov/go-openai/graphs/contributors) and sponsors of this project: From a2ca01bb6dae1a7d58860a5b2d5d5273667e089e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 29 Aug 2023 14:04:27 +0200 Subject: [PATCH 042/242] feat: implement new fine tuning job API (#479) * feat: implement new fine tuning job API * fix: export ListFineTuningJobEventsParameter * fix: lint errors * fix: test errors * fix: code test coverage * fix: code test coverage * fix: use any * chore: use url.Values --- client_test.go | 12 ++++ fine_tuning_job.go | 153 ++++++++++++++++++++++++++++++++++++++++ fine_tuning_job_test.go | 90 +++++++++++++++++++++++ 3 files changed, 255 insertions(+) create mode 100644 fine_tuning_job.go create mode 100644 fine_tuning_job_test.go diff --git a/client_test.go b/client_test.go index 29d84edfa..9b5046899 100644 --- a/client_test.go +++ b/client_test.go @@ -223,6 +223,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListFineTuneEvents", func() (any, error) { return client.ListFineTuneEvents(ctx, "") }}, + {"CreateFineTuningJob", func() (any, error) { + return client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + }}, + {"CancelFineTuningJob", func() (any, error) { + return client.CancelFineTuningJob(ctx, "") + }}, + {"RetrieveFineTuningJob", func() (any, error) { + return client.RetrieveFineTuningJob(ctx, "") + }}, + {"ListFineTuningJobEvents", func() (any, error) { + return client.ListFineTuningJobEvents(ctx, "") + }}, {"Moderations", func() (any, error) { return client.Moderations(ctx, ModerationRequest{}) }}, diff --git a/fine_tuning_job.go b/fine_tuning_job.go new file mode 100644 index 000000000..a840b7ec3 --- /dev/null +++ b/fine_tuning_job.go @@ -0,0 +1,153 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type FineTuningJob struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + FinishedAt int64 `json:"finished_at"` + Model string `json:"model"` + FineTunedModel string `json:"fine_tuned_model,omitempty"` + OrganizationID string `json:"organization_id"` + Status string `json:"status"` + Hyperparameters Hyperparameters `json:"hyperparameters"` + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + ResultFiles []string `json:"result_files"` + TrainedTokens int `json:"trained_tokens"` +} + +type Hyperparameters struct { + Epochs int `json:"n_epochs"` +} + +type FineTuningJobRequest struct { + TrainingFile string `json:"training_file"` + ValidationFile string `json:"validation_file,omitempty"` + Model string `json:"model,omitempty"` + Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"` + Suffix string `json:"suffix,omitempty"` +} + +type FineTuningJobEventList struct { + Object string `json:"object"` + Data []FineTuneEvent `json:"data"` + HasMore bool `json:"has_more"` +} + +type FineTuningJobEvent struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt int `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + Data any `json:"data"` + Type string `json:"type"` +} + +// CreateFineTuningJob create a fine tuning job. +func (c *Client) CreateFineTuningJob( + ctx context.Context, + request FineTuningJobRequest, +) (response FineTuningJob, err error) { + urlSuffix := "/fine_tuning/jobs" + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelFineTuningJob cancel a fine tuning job. +func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/cancel")) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveFineTuningJob retrieve a fine tuning job. +func (c *Client) RetrieveFineTuningJob( + ctx context.Context, + fineTuningJobID string, +) (response FineTuningJob, err error) { + urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", fineTuningJobID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type listFineTuningJobEventsParameters struct { + after *string + limit *int +} + +type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters) + +func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.after = &after + } +} + +func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter { + return func(args *listFineTuningJobEventsParameters) { + args.limit = &limit + } +} + +// ListFineTuningJobs list fine tuning jobs events. +func (c *Client) ListFineTuningJobEvents( + ctx context.Context, + fineTuningJobID string, + setters ...ListFineTuningJobEventsParameter, +) (response FineTuningJobEventList, err error) { + parameters := &listFineTuningJobEventsParameters{ + after: nil, + limit: nil, + } + + for _, setter := range setters { + setter(parameters) + } + + urlValues := url.Values{} + if parameters.after != nil { + urlValues.Add("after", *parameters.after) + } + if parameters.limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit)) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL("/fine_tuning/jobs/"+fineTuningJobID+"/events"+encodedValues), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go new file mode 100644 index 000000000..519c6cd2d --- /dev/null +++ b/fine_tuning_job_test.go @@ -0,0 +1,90 @@ +package openai_test + +import ( + "context" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +const testFineTuninigJobID = "fine-tuning-job-id" + +// TestFineTuningJob Tests the fine tuning job endpoint of the API using the mocked server. +func TestFineTuningJob(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler( + "/v1/fine_tuning/jobs", + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID, + func(w http.ResponseWriter, r *http.Request) { + var resBytes []byte + resBytes, _ = json.Marshal(FineTuningJob{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + server.RegisterHandler( + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(FineTuningJobEventList{}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + + ctx := context.Background() + + _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + checks.NoError(t, err, "CreateFineTuningJob error") + + _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "CancelFineTuningJob error") + + _, err = client.RetrieveFineTuningJob(ctx, testFineTuninigJobID) + checks.NoError(t, err, "RetrieveFineTuningJob error") + + _, err = client.ListFineTuningJobEvents(ctx, testFineTuninigJobID) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") + + _, err = client.ListFineTuningJobEvents( + ctx, + testFineTuninigJobID, + ListFineTuningJobEventsWithAfter("last-event-id"), + ListFineTuningJobEventsWithLimit(10), + ) + checks.NoError(t, err, "ListFineTuningJobEvents error") +} From 25da859c189c62c2454717fb2214da079017ff8e Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 31 Aug 2023 12:14:39 +0200 Subject: [PATCH 043/242] Chore Deprecate legacy fine tunes API (#484) * chore: add deprecation message * chore: use new fine tuning API in README example --- README.md | 21 +++++++++++++-------- fine_tunes.go | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9714d89fe..440c40968 100644 --- a/README.md +++ b/README.md @@ -631,11 +631,16 @@ func main() { client := openai.NewClient("your token") ctx := context.Background() - // create a .jsonl file with your training data + // create a .jsonl file with your training data for conversational model // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} // {"prompt": "", "completion": ""} + // chat models are trained using the following file format: + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "Oh, just some guy named William Shakespeare. Ever heard of him?"}]} + // {"messages": [{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "Around 384,400 kilometers. Give or take a few, like that really matters."}]} + // you can use openai cli tool to validate the data // For more info - https://platform.openai.com/docs/guides/fine-tuning @@ -648,29 +653,29 @@ func main() { return } - // create a fine tune job + // create a fine tuning job // Streams events until the job is done (this often takes minutes, but can take hours if there are many jobs in the queue or your dataset is large) // use below get method to know the status of your model - tune, err := client.CreateFineTune(ctx, openai.FineTuneRequest{ + fineTuningJob, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{ TrainingFile: file.ID, - Model: "ada", // babbage, curie, davinci, or a fine-tuned model created after 2022-04-21. + Model: "davinci-002", // gpt-3.5-turbo-0613, babbage-002. }) if err != nil { fmt.Printf("Creating new fine tune model error: %v\n", err) return } - getTune, err := client.GetFineTune(ctx, tune.ID) + fineTuningJob, err = client.RetrieveFineTuningJob(ctx, fineTuningJob.ID) if err != nil { fmt.Printf("Getting fine tune model error: %v\n", err) return } - fmt.Println(getTune.FineTunedModel) + fmt.Println(fineTuningJob.FineTunedModel) - // once the status of getTune is `succeeded`, you can use your fine tune model in Completion Request + // once the status of fineTuningJob is `succeeded`, you can use your fine tune model in Completion Request or Chat Completion Request // resp, err := client.CreateCompletion(ctx, openai.CompletionRequest{ - // Model: getTune.FineTunedModel, + // Model: fineTuningJob.FineTunedModel, // Prompt: "your prompt", // }) // if err != nil { diff --git a/fine_tunes.go b/fine_tunes.go index 96e731d51..7d3b59dbd 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -6,6 +6,9 @@ import ( "net/http" ) +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneRequest struct { TrainingFile string `json:"training_file"` ValidationFile string `json:"validation_file,omitempty"` @@ -21,6 +24,9 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTune struct { ID string `json:"id"` Object string `json:"object"` @@ -37,6 +43,9 @@ type FineTune struct { UpdatedAt int64 `json:"updated_at"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEvent struct { Object string `json:"object"` CreatedAt int64 `json:"created_at"` @@ -44,6 +53,9 @@ type FineTuneEvent struct { Message string `json:"message"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneHyperParams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` @@ -51,21 +63,34 @@ type FineTuneHyperParams struct { PromptLossWeight float64 `json:"prompt_loss_weight"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` } + +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (response FineTune, err error) { urlSuffix := "/fine-tunes" req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) @@ -78,6 +103,9 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r } // CancelFineTune cancel a fine-tune job. +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { @@ -88,6 +116,9 @@ func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes")) if err != nil { @@ -98,6 +129,9 @@ func (c *Client) ListFineTunes(ctx context.Context) (response FineTuneList, err return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { urlSuffix := fmt.Sprintf("/fine-tunes/%s", fineTuneID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) @@ -109,6 +143,9 @@ func (c *Client) GetFineTune(ctx context.Context, fineTuneID string) (response F return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (response FineTuneDeleteResponse, err error) { req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/fine-tunes/"+fineTuneID)) if err != nil { @@ -119,6 +156,9 @@ func (c *Client) DeleteFineTune(ctx context.Context, fineTuneID string) (respons return } +// Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. +// This API will be officially deprecated on January 4th, 2024. +// OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) ListFineTuneEvents(ctx context.Context, fineTuneID string) (response FineTuneEventList, err error) { req, err := c.newRequest(ctx, http.MethodGet, c.fullURL("/fine-tunes/"+fineTuneID+"/events")) if err != nil { From 3589837b229aeace205f312aa839bf73154e2820 Mon Sep 17 00:00:00 2001 From: NullpointerW <58949721+NullpointerW@users.noreply.github.com> Date: Thu, 7 Sep 2023 18:52:47 +0800 Subject: [PATCH 044/242] Update OpenAPI file return struct (#486) * completionBatchingRequestSupport * lint fix * fix Run test fail * fix TestClientReturnsRequestBuilderErrors fail * fix Codecov check * ignore TestClientReturnsRequestBuilderErrors lint * fix lint again * lint again*2 * replace checkPromptType implementation * remove nil check * update file return struct --------- Co-authored-by: W <825708370@qq.com> --- files.go | 15 ++++++++------- files_api_test.go | 1 - 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/files.go b/files.go index ea1f50a73..8b933c362 100644 --- a/files.go +++ b/files.go @@ -17,13 +17,14 @@ type FileRequest struct { // File struct represents an OpenAPI file. type File struct { - Bytes int `json:"bytes"` - CreatedAt int64 `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + Bytes int `json:"bytes"` + CreatedAt int64 `json:"created_at"` + ID string `json:"id"` + FileName string `json:"filename"` + Object string `json:"object"` + Status string `json:"status"` + Purpose string `json:"purpose"` + StatusDetails string `json:"status_details"` } // FilesList is a list of files that belong to the user or organization. diff --git a/files_api_test.go b/files_api_test.go index f0a08764d..1cbc72894 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -64,7 +64,6 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { Purpose: purpose, CreatedAt: time.Now().Unix(), Object: "test-objecct", - Owner: "test-owner", } resBytes, _ = json.Marshal(fileReq) From 8e4b7963a3f378332bd512a5040d75d8504505c8 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 11 Sep 2023 15:44:46 +0200 Subject: [PATCH 045/242] Chore Support base64 embedding format (#485) * chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix --- embeddings.go | 116 +++++++++++++++++++++++++++++++++++---- embeddings_test.go | 131 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 229 insertions(+), 18 deletions(-) diff --git a/embeddings.go b/embeddings.go index 1d3199597..5ba91f235 100644 --- a/embeddings.go +++ b/embeddings.go @@ -2,6 +2,9 @@ package openai import ( "context" + "encoding/base64" + "encoding/binary" + "math" "net/http" ) @@ -129,15 +132,83 @@ type EmbeddingResponse struct { Usage Usage `json:"usage"` } +type base64String string + +func (b base64String) Decode() ([]float32, error) { + decodedData, err := base64.StdEncoding.DecodeString(string(b)) + if err != nil { + return nil, err + } + + const sizeOfFloat32 = 4 + floats := make([]float32, len(decodedData)/sizeOfFloat32) + for i := 0; i < len(floats); i++ { + floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4])) + } + + return floats, nil +} + +// Base64Embedding is a container for base64 encoded embeddings. +type Base64Embedding struct { + Object string `json:"object"` + Embedding base64String `json:"embedding"` + Index int `json:"index"` +} + +// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format. +type EmbeddingResponseBase64 struct { + Object string `json:"object"` + Data []Base64Embedding `json:"data"` + Model EmbeddingModel `json:"model"` + Usage Usage `json:"usage"` +} + +// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. +func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) { + data := make([]Embedding, len(r.Data)) + + for i, base64Embedding := range r.Data { + embedding, err := base64Embedding.Embedding.Decode() + if err != nil { + return EmbeddingResponse{}, err + } + + data[i] = Embedding{ + Object: base64Embedding.Object, + Embedding: embedding, + Index: base64Embedding.Index, + } + } + + return EmbeddingResponse{ + Object: r.Object, + Model: r.Model, + Data: data, + Usage: r.Usage, + }, nil +} + type EmbeddingRequestConverter interface { // Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens Convert() EmbeddingRequest } +// EmbeddingEncodingFormat is the format of the embeddings data. +// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. +// If not specified OpenAI will use "float". +type EmbeddingEncodingFormat string + +const ( + EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float" + EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64" +) + type EmbeddingRequest struct { - Input any `json:"input"` - Model EmbeddingModel `json:"model"` - User string `json:"user"` + Input any `json:"input"` + Model EmbeddingModel `json:"model"` + User string `json:"user"` + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct { Model EmbeddingModel `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` + // EmbeddingEncodingFormat is the format of the embeddings data. + // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. + // If not specified OpenAI will use "float". + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { return EmbeddingRequest{ - Input: r.Input, - Model: r.Model, - User: r.User, + Input: r.Input, + Model: r.Model, + User: r.User, + EncodingFormat: r.EncodingFormat, } } @@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { // // Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens // for embedding groups of text already converted to tokens. -func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll +func (c *Client) CreateEmbeddings( + ctx context.Context, + conv EmbeddingRequestConverter, +) (res EmbeddingResponse, err error) { baseReq := conv.Convert() req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) if err != nil { return } - err = c.sendRequest(req, &res) + if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 { + err = c.sendRequest(req, &res) + return + } + + base64Response := &EmbeddingResponseBase64{} + err = c.sendRequest(req, base64Response) + if err != nil { + return + } + res, err = base64Response.ToEmbeddingResponse() return } diff --git a/embeddings_test.go b/embeddings_test.go index 47c4f5108..9c48c5b8f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,15 +1,16 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "bytes" "context" "encoding/json" "fmt" "net/http" + "reflect" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { @@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) { func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() + + sampleEmbeddings := []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + sampleBase64Embeddings := []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + } + server.RegisterHandler( "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EmbeddingResponse{}) + var req struct { + EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + + var resBytes []byte + switch { + case req.User == "invalid": + w.WriteHeader(http.StatusBadRequest) + return + case req.EncodingFormat == EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + default: + resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - _, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (simple embedding request) + res, err = client.CreateEmbeddings( + context.Background(), + EmbeddingRequest{ + EncodingFormat: EmbeddingEncodingFormatBase64, + }, + ) checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with strings - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } // test create embeddings with tokens - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test failed sendRequest + _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + User: "invalid", + EncodingFormat: EmbeddingEncodingFormatBase64, + }) + checks.HasError(t, err, "CreateEmbeddings error") +} + +func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { + type fields struct { + Object string + Data []Base64Embedding + Model EmbeddingModel + Usage Usage + } + tests := []struct { + name string + fields fields + want EmbeddingResponse + wantErr bool + }{ + { + name: "test embedding response base64 to embedding response", + fields: fields{ + Data: []Base64Embedding{ + {Embedding: "pHCdP4XrkUDhevxA"}, + {Embedding: "/1jku0G/rLvA/EI8"}, + }, + }, + want: EmbeddingResponse{ + Data: []Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + }, + }, + wantErr: false, + }, + { + name: "Invalid embedding", + fields: fields{ + Data: []Base64Embedding{ + { + Embedding: "----", + }, + }, + }, + want: EmbeddingResponse{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &EmbeddingResponseBase64{ + Object: tt.fields.Object, + Data: tt.fields.Data, + Model: tt.fields.Model, + Usage: tt.fields.Usage, + } + got, err := r.ToEmbeddingResponse() + if (err != nil) != tt.wantErr { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want) + } + }) + } } From 0d5256fb820a34a95b8944b9410a1e562087cd8f Mon Sep 17 00:00:00 2001 From: Brendan Martin Date: Mon, 25 Sep 2023 04:08:45 -0400 Subject: [PATCH 046/242] added delete fine tune model endpoint (#497) --- client_test.go | 3 +++ models.go | 20 ++++++++++++++++++++ models_test.go | 15 +++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/client_test.go b/client_test.go index 9b5046899..2c1d749ed 100644 --- a/client_test.go +++ b/client_test.go @@ -271,6 +271,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"GetModel", func() (any, error) { return client.GetModel(ctx, "text-davinci-003") }}, + {"DeleteFineTuneModel", func() (any, error) { + return client.DeleteFineTuneModel(ctx, "") + }}, } for _, testCase := range testCases { diff --git a/models.go b/models.go index 560402e3f..c207f0a86 100644 --- a/models.go +++ b/models.go @@ -33,6 +33,13 @@ type Permission struct { IsBlocking bool `json:"is_blocking"` } +// FineTuneModelDeleteResponse represents the deletion status of a fine-tuned model. +type FineTuneModelDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` +} + // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` @@ -62,3 +69,16 @@ func (c *Client) GetModel(ctx context.Context, modelID string) (model Model, err err = c.sendRequest(req, &model) return } + +// DeleteFineTuneModel Deletes a fine-tune model. You must have the Owner +// role in your organization to delete a model. +func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( + response FineTuneModelDeleteResponse, err error) { + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL("/models/"+modelID)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/models_test.go b/models_test.go index 59b4f5ef7..9ff73042a 100644 --- a/models_test.go +++ b/models_test.go @@ -14,6 +14,8 @@ import ( "testing" ) +const testFineTuneModelID = "fine-tune-model-id" + // TestListModels Tests the list models endpoint of the API using the mocked server. func TestListModels(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -78,3 +80,16 @@ func TestGetModelReturnTimeoutError(t *testing.T) { t.Fatal("Did not return timeout error") } } + +func TestDeleteFineTuneModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/"+testFineTuneModelID, handleDeleteFineTuneModelEndpoint) + _, err := client.DeleteFineTuneModel(context.Background(), testFineTuneModelID) + checks.NoError(t, err, "DeleteFineTuneModel error") +} + +func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + fmt.Fprintln(w, string(resBytes)) +} From 84f77a0acda6eb541f3312ed8f7711c89e661443 Mon Sep 17 00:00:00 2001 From: "e. alvarez" <55966724+ealvar3z@users.noreply.github.com> Date: Mon, 2 Oct 2023 07:39:10 -0700 Subject: [PATCH 047/242] Add DotProduct Method and README Example for Embedding Similarity Search (#492) * Add DotProduct Method and README Example for Embedding Similarity Search - Implement a DotProduct() method for the Embedding struct to calculate the dot product between two embeddings. - Add a custom error type for vector length mismatch. - Update README.md with a complete example demonstrating how to perform an embedding similarity search for user queries. - Add unit tests to validate the new DotProduct() method and error handling. * Update README to focus on Embedding Semantic Similarity --- README.md | 56 ++++++++++++++++++++++++++++++++++++++++++++++ embeddings.go | 20 +++++++++++++++++ embeddings_test.go | 38 +++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+) diff --git a/README.md b/README.md index 440c40968..c618cd7fa 100644 --- a/README.md +++ b/README.md @@ -483,6 +483,62 @@ func main() { ``` + +Embedding Semantic Similarity + +```go +package main + +import ( + "context" + "log" + openai "github.com/sashabaranov/go-openai" + +) + +func main() { + client := openai.NewClient("your-token") + + // Create an EmbeddingRequest for the user query + queryReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck"}, + Model: openai.AdaEmbeddingv2, + } + + // Create an embedding for the user query + queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq) + if err != nil { + log.Fatal("Error creating query embedding:", err) + } + + // Create an EmbeddingRequest for the target text + targetReq := openai.EmbeddingRequest{ + Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, + Model: openai.AdaEmbeddingv2, + } + + // Create an embedding for the target text + targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq) + if err != nil { + log.Fatal("Error creating target embedding:", err) + } + + // Now that we have the embeddings for the user query and the target text, we + // can calculate their similarity. + queryEmbedding := queryResponse.Data[0] + targetEmbedding := targetResponse.Data[0] + + similarity, err := queryEmbedding.DotProduct(&targetEmbedding) + if err != nil { + log.Fatal("Error calculating dot product:", err) + } + + log.Printf("The similarity score between the query and the target is %f", similarity) +} + +``` + +
Azure OpenAI Embeddings diff --git a/embeddings.go b/embeddings.go index 5ba91f235..660bc24c3 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,10 +4,13 @@ import ( "context" "encoding/base64" "encoding/binary" + "errors" "math" "net/http" ) +var ErrVectorLengthMismatch = errors.New("vector length mismatch") + // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. type EmbeddingModel int @@ -124,6 +127,23 @@ type Embedding struct { Index int `json:"index"` } +// DotProduct calculates the dot product of the embedding vector with another +// embedding vector. Both vectors must have the same length; otherwise, an +// ErrVectorLengthMismatch is returned. The method returns the calculated dot +// product as a float32 value. +func (e *Embedding) DotProduct(other *Embedding) (float32, error) { + if len(e.Embedding) != len(other.Embedding) { + return 0, ErrVectorLengthMismatch + } + + var dotProduct float32 + for i := range e.Embedding { + dotProduct += e.Embedding[i] * other.Embedding[i] + } + + return dotProduct, nil +} + // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { Object string `json:"object"` diff --git a/embeddings_test.go b/embeddings_test.go index 9c48c5b8f..72e8c245f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -4,7 +4,9 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" + "math" "net/http" "reflect" "testing" @@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { }) } } + +func TestDotProduct(t *testing.T) { + v1 := &Embedding{Embedding: []float32{1, 2, 3}} + v2 := &Embedding{Embedding: []float32{2, 4, 6}} + expected := float32(28.0) + + result, err := v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1, 0}} + expected = float32(0.0) + + result, err = v1.DotProduct(v2) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if math.Abs(float64(result-expected)) > 1e-12 { + t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) + } + + // Test for VectorLengthMismatchError + v1 = &Embedding{Embedding: []float32{1, 0, 0}} + v2 = &Embedding{Embedding: []float32{0, 1}} + _, err = v1.DotProduct(v2) + if !errors.Is(err, ErrVectorLengthMismatch) { + t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) + } +} From 533935e4fc31f2542ef77d3e545a527c756b641c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 6 Oct 2023 11:32:21 +0200 Subject: [PATCH 048/242] fix: use any for n_epochs (#499) * fix: use custom marshaler for n_epochs * chore: use any for n_epochs --- fine_tuning_job.go | 2 +- fine_tuning_job_test.go | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index a840b7ec3..07b0c337c 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -24,7 +24,7 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs int `json:"n_epochs"` + Epochs any `json:"n_epochs,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index 519c6cd2d..f6d41c33d 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -21,8 +21,23 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/v1/fine_tuning/jobs", func(w http.ResponseWriter, r *http.Request) { - var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ := json.Marshal(FineTuningJob{ + Object: "fine_tuning.job", + ID: testFineTuninigJobID, + Model: "davinci-002", + CreatedAt: 1692661014, + FinishedAt: 1692661190, + FineTunedModel: "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + OrganizationID: "org-123", + ResultFiles: []string{"file-abc123"}, + Status: "succeeded", + ValidationFile: "", + TrainingFile: "file-abc123", + Hyperparameters: Hyperparameters{ + Epochs: "auto", + }, + TrainedTokens: 5768, + }) fmt.Fprintln(w, string(resBytes)) }, ) From 8e165dc9aadc9f7045b91dd1b02d6404940dc023 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Mon, 9 Oct 2023 17:41:54 +0200 Subject: [PATCH 049/242] Feat Add headers to openai responses (#506) * feat: add headers to http response * chore: add test * fix: rename to httpHeader --- audio.go | 19 ++++++++++++++++++- chat.go | 2 ++ chat_test.go | 30 ++++++++++++++++++++++++++++++ client.go | 20 +++++++++++++++++++- completion.go | 2 ++ edits.go | 2 ++ embeddings.go | 4 ++++ engines.go | 4 ++++ files.go | 4 ++++ fine_tunes.go | 8 ++++++++ fine_tuning_job.go | 4 ++++ image.go | 2 ++ models.go | 6 ++++++ moderation.go | 2 ++ 14 files changed, 107 insertions(+), 2 deletions(-) diff --git a/audio.go b/audio.go index 9f469159d..4cbe4fe64 100644 --- a/audio.go +++ b/audio.go @@ -63,6 +63,21 @@ type AudioResponse struct { Transient bool `json:"transient"` } `json:"segments"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. @@ -104,7 +119,9 @@ func (c *Client) callAudioAPI( if request.HasJSONResponse() { err = c.sendRequest(req, &response) } else { - err = c.sendRequest(req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 8d29b3237..df0e5f970 100644 --- a/chat.go +++ b/chat.go @@ -142,6 +142,8 @@ type ChatCompletionResponse struct { Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. diff --git a/chat_test.go b/chat_test.go index 38d66fa64..52cd0bdef 100644 --- a/chat_test.go +++ b/chat_test.go @@ -16,6 +16,11 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 5779a8e1c..19902285b 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,20 @@ type Client struct { createFormBuilder func(io.Writer) utils.FormBuilder } +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h httpHeader) Header() http.Header { + return http.Header(h) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... return req, nil } -func (c *Client) sendRequest(req *http.Request, v any) error { +func (c *Client) sendRequest(req *http.Request, v Response) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Check whether Content-Type is already set, Upload Files API requires @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return c.handleErrorResp(res) } + if v != nil { + v.SetHeader(res.Header) + } + return decodeResponse(res.Body, v) } diff --git a/completion.go b/completion.go index 7b9ae89e7..c7ff94afc 100644 --- a/completion.go +++ b/completion.go @@ -154,6 +154,8 @@ type CompletionResponse struct { Model string `json:"model"` Choices []CompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well diff --git a/edits.go b/edits.go index 831aade2f..97d026029 100644 --- a/edits.go +++ b/edits.go @@ -28,6 +28,8 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } // Edits Perform an API call to the Edits endpoint. diff --git a/embeddings.go b/embeddings.go index 660bc24c3..7e2aa7eb0 100644 --- a/embeddings.go +++ b/embeddings.go @@ -150,6 +150,8 @@ type EmbeddingResponse struct { Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } type base64String string @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct { Data []Base64Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } // ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. diff --git a/engines.go b/engines.go index adf6025c2..5a0dba858 100644 --- a/engines.go +++ b/engines.go @@ -12,11 +12,15 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic diff --git a/files.go b/files.go index 8b933c362..9e521fbbe 100644 --- a/files.go +++ b/files.go @@ -25,11 +25,15 @@ type File struct { Status string `json:"status"` Purpose string `json:"purpose"` StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader } // CreateFile uploads a jsonl file to GPT3 diff --git a/fine_tunes.go b/fine_tunes.go index 7d3b59dbd..ca840781c 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -41,6 +41,8 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -69,6 +71,8 @@ type FineTuneHyperParams struct { type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -77,6 +81,8 @@ type FineTuneList struct { type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 07b0c337c..9dcb49de1 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -21,6 +21,8 @@ type FineTuningJob struct { ValidationFile string `json:"validation_file,omitempty"` ResultFiles []string `json:"result_files"` TrainedTokens int `json:"trained_tokens"` + + httpHeader } type Hyperparameters struct { @@ -39,6 +41,8 @@ type FineTuningJobEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` HasMore bool `json:"has_more"` + + httpHeader } type FineTuningJobEvent struct { diff --git a/image.go b/image.go index cb96f4f5e..4addcdb1e 100644 --- a/image.go +++ b/image.go @@ -33,6 +33,8 @@ type ImageRequest struct { type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + + httpHeader } // ImageResponseDataInner represents a response data structure for image API. diff --git a/models.go b/models.go index c207f0a86..d94f98836 100644 --- a/models.go +++ b/models.go @@ -15,6 +15,8 @@ type Model struct { Permission []Permission `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` + + httpHeader } // Permission struct represents an OpenAPI permission. @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` + + httpHeader } // ListModels Lists the currently available models, diff --git a/moderation.go b/moderation.go index a32f123f3..f8d20ee51 100644 --- a/moderation.go +++ b/moderation.go @@ -69,6 +69,8 @@ type ModerationResponse struct { ID string `json:"id"` Model string `json:"model"` Results []Result `json:"results"` + + httpHeader } // Moderations — perform a moderation api call over a string. From b77d01edca43500f267c4b43333f645b84a4fcf0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 10 Oct 2023 10:29:41 -0500 Subject: [PATCH 050/242] Support get http header and x-ratelimit-* headers (#507) * feat: add headers to http response * feat: support rate limit headers * fix: go lint * fix: test coverage * refactor streamReader * refactor streamReader * refactor: NewRateLimitHeaders to newRateLimitHeaders * refactor: RateLimitHeaders Resets filed * refactor: move RateLimitHeaders struct --- chat_stream_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++-- chat_test.go | 53 +++++++++++++++++++++++++++ client.go | 9 ++++- ratelimit.go | 43 ++++++++++++++++++++++ stream_reader.go | 2 + 5 files changed, 191 insertions(+), 5 deletions(-) create mode 100644 ratelimit.go diff --git a/chat_stream_test.go b/chat_stream_test.go index 5fc70b032..2c109d454 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,15 +1,17 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" + "strconv" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := stream.GetRateLimitHeaders() + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 52cd0bdef..329b2b9cb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -21,6 +21,17 @@ const ( xCustomHeaderValue = "test" ) +var ( + rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", + } +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) { } } +// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + headers := resp.GetRateLimitHeaders() + resetRequests := headers.ResetRequests.String() + if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"]) + } + resetRequestsTime := headers.ResetRequests.Time() + if resetRequestsTime.Before(time.Now()) { + t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + } + + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } resBytes, _ = json.Marshal(res) w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 19902285b..65ece812f 100644 --- a/client.go +++ b/client.go @@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) { *h = httpHeader(header) } -func (h httpHeader) Header() http.Header { - return http.Header(h) +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return newRateLimitHeaders(h.Header()) } // NewClient creates new OpenAI API client. @@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, + httpHeader: httpHeader(resp.Header), }, nil } diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 000000000..e8953f716 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,43 @@ +package openai + +import ( + "net/http" + "strconv" + "time" +) + +// RateLimitHeaders struct represents Openai rate limits headers. +type RateLimitHeaders struct { + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) +} + +func newRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +} diff --git a/stream_reader.go b/stream_reader.go index 87e59e0ca..d17412591 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -27,6 +27,8 @@ type streamReader[T streamable] struct { response *http.Response errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler + + httpHeader } func (stream *streamReader[T]) Recv() (response T, err error) { From c47ddfc1a13b850115a80b03f3f9dd1822733bf7 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 10 Oct 2023 21:22:45 +0400 Subject: [PATCH 051/242] Update README.md (#511) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c618cd7fa..b41947be5 100644 --- a/README.md +++ b/README.md @@ -483,7 +483,7 @@ func main() { ```
- +
Embedding Semantic Similarity ```go @@ -537,7 +537,7 @@ func main() { } ``` - +
Azure OpenAI Embeddings From 6c52952b691ec294b7987689a5292a87a9acdbcb Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Mon, 6 Nov 2023 21:22:48 +0100 Subject: [PATCH 052/242] feat(completion): add constants for new GPT models (#520) Added constants for new GPT models including `gpt-4-1106-preview`, `gpt-4-vision-preview` and `gpt-3.5-turbo-1106`. The models were announced in the following blog post: https://openai.com/blog/new-models-and-developer-products-announced-at-devday --- completion.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/completion.go b/completion.go index c7ff94afc..2709c8b03 100644 --- a/completion.go +++ b/completion.go @@ -22,7 +22,10 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4TurboPreview = "gpt-4-1106-preview" + GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" @@ -69,9 +72,12 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, GPT40314: true, GPT40613: true, GPT432K: true, From 9e0232f941a0f2c1780bf20743effd051a39e4d3 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Mon, 6 Nov 2023 12:27:08 -0800 Subject: [PATCH 053/242] Fix typo in README: AdaEmbeddingV2 (#516) Copy-pasting the old sample caused compilation errors --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b41947be5..f0b609088 100644 --- a/README.md +++ b/README.md @@ -502,7 +502,7 @@ func main() { // Create an EmbeddingRequest for the user query queryReq := openai.EmbeddingRequest{ Input: []string{"How many chucks would a woodchuck chuck"}, - Model: openai.AdaEmbeddingv2, + Model: openai.AdaEmbeddingV2, } // Create an embedding for the user query @@ -514,7 +514,7 @@ func main() { // Create an EmbeddingRequest for the target text targetReq := openai.EmbeddingRequest{ Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"}, - Model: openai.AdaEmbeddingv2, + Model: openai.AdaEmbeddingV2, } // Create an embedding for the target text From 0664105387f52c99b13bb40fcbf966a8b8c8d838 Mon Sep 17 00:00:00 2001 From: Simon Klee Date: Tue, 7 Nov 2023 10:23:06 +0100 Subject: [PATCH 054/242] lint: fix linter warnings reported by golangci-lint (#522) - Fix #519 --- api_integration_test.go | 1 - audio_api_test.go | 14 ++-- audio_test.go | 2 +- chat_stream_test.go | 110 ++++++++++++++-------------- chat_test.go | 154 ++++++++++++++++++++-------------------- completion_test.go | 42 +++++------ config_test.go | 4 +- edits_test.go | 24 +++---- embeddings_test.go | 110 ++++++++++++++-------------- engines_test.go | 12 ++-- error_test.go | 60 ++++++++-------- example_test.go | 2 - files_api_test.go | 12 ++-- files_test.go | 6 +- fine_tunes.go | 1 + fine_tunes_test.go | 24 +++---- fine_tuning_job_test.go | 35 +++++---- image_api_test.go | 52 +++++++------- jsonschema/json_test.go | 62 ++++++++-------- models_test.go | 17 +++-- moderation_test.go | 52 +++++++------- openai_test.go | 14 ++-- stream_test.go | 46 ++++++------ 23 files changed, 425 insertions(+), 431 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 254fbeb03..6be188bc6 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -9,7 +9,6 @@ import ( "os" "testing" - . "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) diff --git a/audio_api_test.go b/audio_api_test.go index aad7a225a..a0efc7921 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -12,7 +12,7 @@ import ( "strings" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -26,7 +26,7 @@ func TestAudio(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -48,7 +48,7 @@ func TestAudio(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", } @@ -57,7 +57,7 @@ func TestAudio(t *testing.T) { }) t.Run(tc.name+" (with reader)", func(t *testing.T) { - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: "fake.webm", Reader: bytes.NewBuffer([]byte(`some webm binary data`)), Model: "whisper-3", @@ -76,7 +76,7 @@ func TestAudioWithOptionalArgs(t *testing.T) { testcases := []struct { name string - createFn func(context.Context, AudioRequest) (AudioResponse, error) + createFn func(context.Context, openai.AudioRequest) (openai.AudioResponse, error) }{ { "transcribe", @@ -98,13 +98,13 @@ func TestAudioWithOptionalArgs(t *testing.T) { path := filepath.Join(dir, "fake.mp3") test.CreateTestFile(t, path) - req := AudioRequest{ + req := openai.AudioRequest{ FilePath: path, Model: "whisper-3", Prompt: "用简体中文", Temperature: 0.5, Language: "zh", - Format: AudioResponseFormatSRT, + Format: openai.AudioResponseFormatSRT, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index e19a873f3..5346244c8 100644 --- a/audio_test.go +++ b/audio_test.go @@ -40,7 +40,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/chat_stream_test.go b/chat_stream_test.go index 2c109d454..bd571cb48 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -10,28 +10,28 @@ import ( "strconv" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, ErrChatCompletionInvalidModel) { + if !errors.Is(err, openai.ErrChatCompletionInvalidModel) { t.Fatalf("CreateChatCompletion should return ErrChatCompletionInvalidModel, but returned: %v", err) } } @@ -39,7 +39,7 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -61,12 +61,12 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -75,15 +75,15 @@ func TestCreateChatCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []ChatCompletionStreamResponse{ + expectedResponses := []openai.ChatCompletionStreamResponse{ { ID: "1", Object: "completion", Created: 1598069254, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response1", }, FinishReason: "max_tokens", @@ -94,10 +94,10 @@ func TestCreateChatCompletionStream(t *testing.T) { ID: "2", Object: "completion", Created: 1598069255, - Model: GPT3Dot5Turbo, - Choices: []ChatCompletionStreamChoice{ + Model: openai.GPT3Dot5Turbo, + Choices: []openai.ChatCompletionStreamChoice{ { - Delta: ChatCompletionStreamChoiceDelta{ + Delta: openai.ChatCompletionStreamChoiceDelta{ Content: "response2", }, FinishReason: "max_tokens", @@ -133,7 +133,7 @@ func TestCreateChatCompletionStream(t *testing.T) { func TestCreateChatCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -156,12 +156,12 @@ func TestCreateChatCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -173,7 +173,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -183,7 +183,7 @@ func TestCreateChatCompletionStreamError(t *testing.T) { func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set(xCustomHeader, xCustomHeaderValue) @@ -196,12 +196,12 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -219,7 +219,7 @@ func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") for k, v := range rateLimitHeaders { switch val := v.(type) { @@ -239,12 +239,12 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -264,7 +264,7 @@ func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -276,12 +276,12 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,7 +293,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -303,7 +303,7 @@ func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -317,18 +317,18 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { _, err := w.Write(dataBytes) checks.NoError(t, err, "Write error") }) - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, Stream: true, }) - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Errorf("TestCreateChatCompletionStreamRateLimitError did not return APIError") } @@ -345,7 +345,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() server.RegisterHandler("/openai/deployments/gpt-35-turbo/chat/completions", - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusTooManyRequests) // Send test responses @@ -355,13 +355,13 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - apiErr := &APIError{} - _, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + apiErr := &openai.APIError{} + _, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -387,7 +387,7 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } // Helper funcs. -func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { +func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -402,7 +402,7 @@ func compareChatResponses(r1, r2 ChatCompletionStreamResponse) bool { return true } -func compareChatStreamResponseChoices(c1, c2 ChatCompletionStreamChoice) bool { +func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false } diff --git a/chat_test.go b/chat_test.go index 329b2b9cb..5bf1eaf6c 100644 --- a/chat_test.go +++ b/chat_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -21,49 +21,47 @@ const ( xCustomHeaderValue = "test" ) -var ( - rateLimitHeaders = map[string]any{ - "x-ratelimit-limit-requests": 60, - "x-ratelimit-limit-tokens": 150000, - "x-ratelimit-remaining-requests": 59, - "x-ratelimit-remaining-tokens": 149984, - "x-ratelimit-reset-requests": "1s", - "x-ratelimit-reset-tokens": "6m0s", - } -) +var rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", +} func TestChatCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ MaxTokens: 5, Model: "ada", - Messages: []ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, } _, err := client.CreateChatCompletion(ctx, req) msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) - checks.ErrorIs(t, err, ErrChatCompletionInvalidModel, msg) + checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } func TestChatCompletionsWithStream(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := ChatCompletionRequest{ + req := openai.ChatCompletionRequest{ Stream: true, } _, err := client.CreateChatCompletion(ctx, req) - checks.ErrorIs(t, err, ErrChatCompletionStreamNotSupported, "unexpected error") + checks.ErrorIs(t, err, openai.ErrChatCompletionStreamNotSupported, "unexpected error") } // TestCompletions Tests the completions endpoint of the API using the mocked server. @@ -71,12 +69,12 @@ func TestChatCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -89,12 +87,12 @@ func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -113,12 +111,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) - resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -150,16 +148,16 @@ func TestChatCompletionsFunctions(t *testing.T) { t.Run("bytes", func(t *testing.T) { //nolint:lll msg := json.RawMessage(`{"properties":{"count":{"type":"integer","description":"total number of words in sentence"},"words":{"items":{"type":"string"},"type":"array","description":"list of words in sentence"}},"type":"object","required":["count","words"]}`) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -175,16 +173,16 @@ func TestChatCompletionsFunctions(t *testing.T) { Count: 2, Words: []string{"hello", "world"}, } - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &msg, }}, @@ -192,16 +190,16 @@ func TestChatCompletionsFunctions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion with functions error") }) t.Run("JSONSchemaDefinition", func(t *testing.T) { - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -229,16 +227,16 @@ func TestChatCompletionsFunctions(t *testing.T) { }) t.Run("JSONSchemaDefinitionWithFunctionDefine", func(t *testing.T) { // this is a compatibility check - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo0613, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, - Functions: []FunctionDefine{{ + Functions: []openai.FunctionDefine{{ Name: "test", Parameters: &jsonschema.Definition{ Type: jsonschema.Object, @@ -271,12 +269,12 @@ func TestAzureChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) - _, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -293,12 +291,12 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq ChatCompletionRequest + var completionReq openai.ChatCompletionRequest if completionReq, err = getChatCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ChatCompletionResponse{ + res := openai.ChatCompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -323,11 +321,11 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { return } - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleFunction, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleFunction, // this is valid json so it should be fine - FunctionCall: &FunctionCall{ + FunctionCall: &openai.FunctionCall{ Name: completionReq.Functions[0].Name, Arguments: string(fcb), }, @@ -339,9 +337,9 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // generate a random string of length completionReq.Length completionStr := strings.Repeat("a", completionReq.MaxTokens) - res.Choices = append(res.Choices, ChatCompletionChoice{ - Message: ChatCompletionMessage{ - Role: ChatMessageRoleAssistant, + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, Content: completionStr, }, Index: i, @@ -349,7 +347,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens := numTokens(completionReq.Messages[0].Content) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -368,23 +366,23 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getChatCompletionBody Returns the body of the request to create a completion. -func getChatCompletionBody(r *http.Request) (ChatCompletionRequest, error) { - completion := ChatCompletionRequest{} +func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { + completion := openai.ChatCompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return ChatCompletionRequest{}, err + return openai.ChatCompletionRequest{}, err } return completion, nil } func TestFinishReason(t *testing.T) { - c := &ChatCompletionChoice{ - FinishReason: FinishReasonNull, + c := &openai.ChatCompletionChoice{ + FinishReason: openai.FinishReasonNull, } resBytes, _ := json.Marshal(c) if !strings.Contains(string(resBytes), `"finish_reason":null`) { @@ -398,11 +396,11 @@ func TestFinishReason(t *testing.T) { t.Error("null should not be quoted") } - otherReasons := []FinishReason{ - FinishReasonStop, - FinishReasonLength, - FinishReasonFunctionCall, - FinishReasonContentFilter, + otherReasons := []openai.FinishReason{ + openai.FinishReasonStop, + openai.FinishReasonLength, + openai.FinishReasonFunctionCall, + openai.FinishReasonContentFilter, } for _, r := range otherReasons { c.FinishReason = r diff --git a/completion_test.go b/completion_test.go index 844ef484f..89950bf94 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" @@ -14,33 +11,36 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletion( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } func TestCompletionWithStream(t *testing.T) { - config := DefaultConfig("whatever") - client := NewClientWithConfig(config) + config := openai.DefaultConfig("whatever") + client := openai.NewClientWithConfig(config) ctx := context.Background() - req := CompletionRequest{Stream: true} + req := openai.CompletionRequest{Stream: true} _, err := client.CreateCompletion(ctx, req) - if !errors.Is(err, ErrCompletionStreamNotSupported) { + if !errors.Is(err, openai.ErrCompletionStreamNotSupported) { t.Fatalf("CreateCompletion didn't return ErrCompletionStreamNotSupported") } } @@ -50,7 +50,7 @@ func TestCompletions(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - req := CompletionRequest{ + req := openai.CompletionRequest{ MaxTokens: 5, Model: "ada", Prompt: "Lorem ipsum", @@ -68,12 +68,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := openai.CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: time.Now().Unix(), @@ -93,14 +93,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if completionReq.Echo { completionStr = completionReq.Prompt.(string) + completionStr } - res.Choices = append(res.Choices, CompletionChoice{ + res.Choices = append(res.Choices, openai.CompletionChoice{ Text: completionStr, Index: i, }) } inputTokens := numTokens(completionReq.Prompt.(string)) * n completionTokens := completionReq.MaxTokens * n - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -110,16 +110,16 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { + completion := openai.CompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } return completion, nil } diff --git a/config_test.go b/config_test.go index 488511b11..3e528c3e9 100644 --- a/config_test.go +++ b/config_test.go @@ -3,7 +3,7 @@ package openai_test import ( "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestGetAzureDeploymentByModel(t *testing.T) { @@ -49,7 +49,7 @@ func TestGetAzureDeploymentByModel(t *testing.T) { for _, c := range cases { t.Run(c.Model, func(t *testing.T) { - conf := DefaultAzureConfig("", "https://test.openai.azure.com/") + conf := openai.DefaultAzureConfig("", "https://test.openai.azure.com/") if c.AzureModelMapperFunc != nil { conf.AzureModelMapperFunc = c.AzureModelMapperFunc } diff --git a/edits_test.go b/edits_test.go index c0bb84392..d2a6db40d 100644 --- a/edits_test.go +++ b/edits_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -11,6 +8,9 @@ import ( "net/http" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestEdits Tests the edits endpoint of the API using the mocked server. @@ -20,7 +20,7 @@ func TestEdits(t *testing.T) { server.RegisterHandler("/v1/edits", handleEditEndpoint) // create an edit request model := "ada" - editReq := EditsRequest{ + editReq := openai.EditsRequest{ Model: &model, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + @@ -45,14 +45,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ + res := openai.EditsResponse{ Object: "test-object", Created: time.Now().Unix(), } @@ -62,12 +62,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -77,16 +77,16 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { } // getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (openai.EditsRequest, error) { + edit := openai.EditsRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } return edit, nil } diff --git a/embeddings_test.go b/embeddings_test.go index 72e8c245f..af04d96bf 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -11,32 +11,32 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []openai.EmbeddingModel{ + openai.AdaSimilarity, + openai.BabbageSimilarity, + openai.CurieSimilarity, + openai.DavinciSimilarity, + openai.AdaSearchDocument, + openai.AdaSearchQuery, + openai.BabbageSearchDocument, + openai.BabbageSearchQuery, + openai.CurieSearchDocument, + openai.CurieSearchQuery, + openai.DavinciSearchDocument, + openai.DavinciSearchQuery, + openai.AdaCodeSearchCode, + openai.AdaCodeSearchText, + openai.BabbageCodeSearchCode, + openai.BabbageCodeSearchText, } for _, model := range embeddedModels { // test embedding request with strings (simple embedding request) - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -52,7 +52,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with strings - embeddingReqStrings := EmbeddingRequestStrings{ + embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -66,7 +66,7 @@ func TestEmbedding(t *testing.T) { } // test embedding request with tokens - embeddingReqTokens := EmbeddingRequestTokens{ + embeddingReqTokens := openai.EmbeddingRequestTokens{ Input: [][]int{ {464, 2057, 373, 12625, 290, 262, 46612}, {6395, 6096, 286, 11525, 12083, 2581}, @@ -82,17 +82,17 @@ func TestEmbedding(t *testing.T) { } func TestEmbeddingModel(t *testing.T) { - var em EmbeddingModel + var em openai.EmbeddingModel err := em.UnmarshalText([]byte("text-similarity-ada-001")) checks.NoError(t, err, "Could not marshal embedding model") - if em != AdaSimilarity { + if em != openai.AdaSimilarity { t.Errorf("Model is not equal to AdaSimilarity") } err = em.UnmarshalText([]byte("some-non-existent-model")) checks.NoError(t, err, "Could not marshal embedding model") - if em != Unknown { + if em != openai.Unknown { t.Errorf("Model is not equal to Unknown") } } @@ -101,12 +101,12 @@ func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - sampleEmbeddings := []Embedding{ + sampleEmbeddings := []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, } - sampleBase64Embeddings := []Base64Embedding{ + sampleBase64Embeddings := []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, } @@ -115,8 +115,8 @@ func TestEmbeddingEndpoint(t *testing.T) { "/v1/embeddings", func(w http.ResponseWriter, r *http.Request) { var req struct { - EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"` - User string `json:"user"` + EncodingFormat openai.EmbeddingEncodingFormat `json:"encoding_format"` + User string `json:"user"` } _ = json.NewDecoder(r.Body).Decode(&req) @@ -125,16 +125,16 @@ func TestEmbeddingEndpoint(t *testing.T) { case req.User == "invalid": w.WriteHeader(http.StatusBadRequest) return - case req.EncodingFormat == EmbeddingEncodingFormatBase64: - resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings}) + case req.EncodingFormat == openai.EmbeddingEncodingFormatBase64: + resBytes, _ = json.Marshal(openai.EmbeddingResponseBase64{Data: sampleBase64Embeddings}) default: - resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings}) + resBytes, _ = json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) } fmt.Fprintln(w, string(resBytes)) }, ) // test create embeddings with strings (simple embedding request) - res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{}) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{}) checks.NoError(t, err, "CreateEmbeddings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) @@ -143,8 +143,8 @@ func TestEmbeddingEndpoint(t *testing.T) { // test create embeddings with strings (simple embedding request) res, err = client.CreateEmbeddings( context.Background(), - EmbeddingRequest{ - EncodingFormat: EmbeddingEncodingFormatBase64, + openai.EmbeddingRequest{ + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }, ) checks.NoError(t, err, "CreateEmbeddings error") @@ -153,23 +153,23 @@ func TestEmbeddingEndpoint(t *testing.T) { } // test create embeddings with strings - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestStrings{}) checks.NoError(t, err, "CreateEmbeddings strings error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test create embeddings with tokens - res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{}) + res, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequestTokens{}) checks.NoError(t, err, "CreateEmbeddings tokens error") if !reflect.DeepEqual(res.Data, sampleEmbeddings) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } // test failed sendRequest - _, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{ + _, err = client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ User: "invalid", - EncodingFormat: EmbeddingEncodingFormatBase64, + EncodingFormat: openai.EmbeddingEncodingFormatBase64, }) checks.HasError(t, err, "CreateEmbeddings error") } @@ -177,26 +177,26 @@ func TestEmbeddingEndpoint(t *testing.T) { func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string - Data []Base64Embedding - Model EmbeddingModel - Usage Usage + Data []openai.Base64Embedding + Model openai.EmbeddingModel + Usage openai.Usage } tests := []struct { name string fields fields - want EmbeddingResponse + want openai.EmbeddingResponse wantErr bool }{ { name: "test embedding response base64 to embedding response", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ {Embedding: "pHCdP4XrkUDhevxA"}, {Embedding: "/1jku0G/rLvA/EI8"}, }, }, - want: EmbeddingResponse{ - Data: []Embedding{ + want: openai.EmbeddingResponse{ + Data: []openai.Embedding{ {Embedding: []float32{1.23, 4.56, 7.89}}, {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, }, @@ -206,19 +206,19 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { { name: "Invalid embedding", fields: fields{ - Data: []Base64Embedding{ + Data: []openai.Base64Embedding{ { Embedding: "----", }, }, }, - want: EmbeddingResponse{}, + want: openai.EmbeddingResponse{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := &EmbeddingResponseBase64{ + r := &openai.EmbeddingResponseBase64{ Object: tt.fields.Object, Data: tt.fields.Data, Model: tt.fields.Model, @@ -237,8 +237,8 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { } func TestDotProduct(t *testing.T) { - v1 := &Embedding{Embedding: []float32{1, 2, 3}} - v2 := &Embedding{Embedding: []float32{2, 4, 6}} + v1 := &openai.Embedding{Embedding: []float32{1, 2, 3}} + v2 := &openai.Embedding{Embedding: []float32{2, 4, 6}} expected := float32(28.0) result, err := v1.DotProduct(v2) @@ -250,8 +250,8 @@ func TestDotProduct(t *testing.T) { t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result) } - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1, 0}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1, 0}} expected = float32(0.0) result, err = v1.DotProduct(v2) @@ -264,10 +264,10 @@ func TestDotProduct(t *testing.T) { } // Test for VectorLengthMismatchError - v1 = &Embedding{Embedding: []float32{1, 0, 0}} - v2 = &Embedding{Embedding: []float32{0, 1}} + v1 = &openai.Embedding{Embedding: []float32{1, 0, 0}} + v2 = &openai.Embedding{Embedding: []float32{0, 1}} _, err = v1.DotProduct(v2) - if !errors.Is(err, ErrVectorLengthMismatch) { + if !errors.Is(err, openai.ErrVectorLengthMismatch) { t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err) } } diff --git a/engines_test.go b/engines_test.go index 31e7ec8be..d26aa5541 100644 --- a/engines_test.go +++ b/engines_test.go @@ -7,7 +7,7 @@ import ( "net/http" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -15,8 +15,8 @@ import ( func TestGetEngine(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(Engine{}) + server.RegisterHandler("/v1/engines/text-davinci-003", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.Engine{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetEngine(context.Background(), "text-davinci-003") @@ -27,8 +27,8 @@ func TestGetEngine(t *testing.T) { func TestListEngines(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(EnginesList{}) + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.EnginesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListEngines(context.Background()) @@ -38,7 +38,7 @@ func TestListEngines(t *testing.T) { func TestListEnginesReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusTeapot) }) diff --git a/error_test.go b/error_test.go index a0806b7ed..48cbe4f29 100644 --- a/error_test.go +++ b/error_test.go @@ -6,7 +6,7 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" ) func TestAPIErrorUnmarshalJSON(t *testing.T) { @@ -14,7 +14,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name string response string hasError bool - checkFunc func(t *testing.T, apiErr APIError) + checkFunc func(t *testing.T, apiErr openai.APIError) } testCases := []testCase{ // testcase for message field @@ -22,7 +22,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is string", response: `{"message":"foo","type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -30,7 +30,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with single item", response: `{"message":["foo"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo") }, }, @@ -38,7 +38,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is array with multiple items", response: `{"message":["foo", "bar", "baz"],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "foo, bar, baz") }, }, @@ -46,7 +46,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is empty array", response: `{"message":[],"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -54,7 +54,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the message is null", response: `{"message":null,"type":"invalid_request_error","param":null,"code":null}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorMessage(t, apiErr, "") }, }, @@ -89,23 +89,23 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } }`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{ + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{ Code: "ResponsibleAIPolicyViolation", - ContentFilterResults: ContentFilterResults{ - Hate: Hate{ + ContentFilterResults: openai.ContentFilterResults{ + Hate: openai.Hate{ Filtered: false, Severity: "safe", }, - SelfHarm: SelfHarm{ + SelfHarm: openai.SelfHarm{ Filtered: false, Severity: "safe", }, - Sexual: Sexual{ + Sexual: openai.Sexual{ Filtered: true, Severity: "medium", }, - Violence: Violence{ + Violence: openai.Violence{ Filtered: false, Severity: "safe", }, @@ -117,16 +117,16 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the innerError is empty (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": {}}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { name: "parse succeeds when the innerError is not InnerError struct (Azure Openai)", response: `{"message": "","type": null,"param": "","code": "","status": 0,"innererror": "test"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { - assertAPIErrorInnerError(t, apiErr, &InnerError{}) + checkFunc: func(t *testing.T, apiErr openai.APIError) { + assertAPIErrorInnerError(t, apiErr, &openai.InnerError{}) }, }, { @@ -159,7 +159,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is int", response: `{"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, 418) }, }, @@ -167,7 +167,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is string", response: `{"code":"teapot","message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, "teapot") }, }, @@ -175,7 +175,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse succeeds when the code is not exists", response: `{"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: false, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) }, }, @@ -196,7 +196,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { name: "parse failed when the response is invalid json", response: `--- {"code":418,"message":"I'm a teapot","param":"prompt","type":"teapot_error"}`, hasError: true, - checkFunc: func(t *testing.T, apiErr APIError) { + checkFunc: func(t *testing.T, apiErr openai.APIError) { assertAPIErrorCode(t, apiErr, nil) assertAPIErrorMessage(t, apiErr, "") assertAPIErrorParam(t, apiErr, nil) @@ -206,7 +206,7 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - var apiErr APIError + var apiErr openai.APIError err := apiErr.UnmarshalJSON([]byte(tc.response)) if (err != nil) != tc.hasError { t.Errorf("Unexpected error: %v", err) @@ -218,19 +218,19 @@ func TestAPIErrorUnmarshalJSON(t *testing.T) { } } -func assertAPIErrorMessage(t *testing.T, apiErr APIError, expected string) { +func assertAPIErrorMessage(t *testing.T, apiErr openai.APIError, expected string) { if apiErr.Message != expected { t.Errorf("Unexpected APIError message: %v; expected: %s", apiErr, expected) } } -func assertAPIErrorInnerError(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorInnerError(t *testing.T, apiErr openai.APIError, expected interface{}) { if !reflect.DeepEqual(apiErr.InnerError, expected) { t.Errorf("Unexpected APIError InnerError: %v; expected: %v; ", apiErr, expected) } } -func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { +func assertAPIErrorCode(t *testing.T, apiErr openai.APIError, expected interface{}) { switch v := apiErr.Code.(type) { case int: if v != expected { @@ -246,25 +246,25 @@ func assertAPIErrorCode(t *testing.T, apiErr APIError, expected interface{}) { } } -func assertAPIErrorParam(t *testing.T, apiErr APIError, expected *string) { +func assertAPIErrorParam(t *testing.T, apiErr openai.APIError, expected *string) { if apiErr.Param != expected { t.Errorf("Unexpected APIError param: %v; expected: %s", apiErr, *expected) } } -func assertAPIErrorType(t *testing.T, apiErr APIError, typ string) { +func assertAPIErrorType(t *testing.T, apiErr openai.APIError, typ string) { if apiErr.Type != typ { t.Errorf("Unexpected API type: %v; expected: %s", apiErr, typ) } } func TestRequestError(t *testing.T) { - var err error = &RequestError{ + var err error = &openai.RequestError{ HTTPStatusCode: http.StatusTeapot, Err: errors.New("i am a teapot"), } - var reqErr *RequestError + var reqErr *openai.RequestError if !errors.As(err, &reqErr) { t.Fatalf("Error is not a RequestError: %+v", err) } diff --git a/example_test.go b/example_test.go index b5dfafea9..de67c57cd 100644 --- a/example_test.go +++ b/example_test.go @@ -28,7 +28,6 @@ func Example() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return @@ -319,7 +318,6 @@ func ExampleDefaultAzureConfig() { }, }, ) - if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return diff --git a/files_api_test.go b/files_api_test.go index 1cbc72894..330b88159 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -20,7 +20,7 @@ func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", handleCreateFile) - req := FileRequest{ + req := openai.FileRequest{ FileName: "test.go", FilePath: "client.go", Purpose: "fine-tune", @@ -57,7 +57,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { } defer file.Close() - var fileReq = File{ + fileReq := openai.File{ Bytes: int(header.Size), ID: strconv.Itoa(int(time.Now().Unix())), FileName: header.Filename, @@ -82,7 +82,7 @@ func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FilesList{}) + resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.ListFiles(context.Background()) @@ -93,7 +93,7 @@ func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(File{}) + resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) _, err := client.GetFile(context.Background(), "deadbeef") @@ -148,7 +148,7 @@ func TestGetFileContentReturnError(t *testing.T) { t.Fatal("Did not return error") } - apiErr := &APIError{} + apiErr := &openai.APIError{} if !errors.As(err, &apiErr) { t.Fatalf("Did not return APIError: %+v\n", apiErr) } diff --git a/files_test.go b/files_test.go index df6eaef7b..f588b30dc 100644 --- a/files_test.go +++ b/files_test.go @@ -1,14 +1,14 @@ package openai //nolint:testpackage // testing private field import ( - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "fmt" "io" "os" "testing" + + utils "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestFileUploadWithFailingFormBuilder(t *testing.T) { diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..46f89f165 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,6 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { + //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 67f681d97..2ab6817f7 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -1,14 +1,14 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneID = "fine-tune-id" @@ -22,9 +22,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodGet { - resBytes, _ = json.Marshal(FineTuneList{}) + resBytes, _ = json.Marshal(openai.FineTuneList{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -32,8 +32,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTune{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTune{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -43,9 +43,9 @@ func TestFineTunes(t *testing.T) { func(w http.ResponseWriter, r *http.Request) { var resBytes []byte if r.Method == http.MethodDelete { - resBytes, _ = json.Marshal(FineTuneDeleteResponse{}) + resBytes, _ = json.Marshal(openai.FineTuneDeleteResponse{}) } else { - resBytes, _ = json.Marshal(FineTune{}) + resBytes, _ = json.Marshal(openai.FineTune{}) } fmt.Fprintln(w, string(resBytes)) }, @@ -53,8 +53,8 @@ func TestFineTunes(t *testing.T) { server.RegisterHandler( "/v1/fine-tunes/"+testFineTuneID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuneEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuneEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) @@ -64,7 +64,7 @@ func TestFineTunes(t *testing.T) { _, err := client.ListFineTunes(ctx) checks.NoError(t, err, "ListFineTunes error") - _, err = client.CreateFineTune(ctx, FineTuneRequest{}) + _, err = client.CreateFineTune(ctx, openai.FineTuneRequest{}) checks.NoError(t, err, "CreateFineTune error") _, err = client.CancelFineTune(ctx, testFineTuneID) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index f6d41c33d..c892ef775 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -2,14 +2,13 @@ package openai_test import ( "context" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "encoding/json" "fmt" "net/http" "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuninigJobID = "fine-tuning-job-id" @@ -20,8 +19,8 @@ func TestFineTuningJob(t *testing.T) { defer teardown() server.RegisterHandler( "/v1/fine_tuning/jobs", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{ + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{ Object: "fine_tuning.job", ID: testFineTuninigJobID, Model: "davinci-002", @@ -33,7 +32,7 @@ func TestFineTuningJob(t *testing.T) { Status: "succeeded", ValidationFile: "", TrainingFile: "file-abc123", - Hyperparameters: Hyperparameters{ + Hyperparameters: openai.Hyperparameters{ Epochs: "auto", }, TrainedTokens: 5768, @@ -44,32 +43,32 @@ func TestFineTuningJob(t *testing.T) { server.RegisterHandler( "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJob{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID, - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { var resBytes []byte - resBytes, _ = json.Marshal(FineTuningJob{}) + resBytes, _ = json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) }, ) server.RegisterHandler( "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/events", - func(w http.ResponseWriter, r *http.Request) { - resBytes, _ := json.Marshal(FineTuningJobEventList{}) + func(w http.ResponseWriter, _ *http.Request) { + resBytes, _ := json.Marshal(openai.FineTuningJobEventList{}) fmt.Fprintln(w, string(resBytes)) }, ) ctx := context.Background() - _, err := client.CreateFineTuningJob(ctx, FineTuningJobRequest{}) + _, err := client.CreateFineTuningJob(ctx, openai.FineTuningJobRequest{}) checks.NoError(t, err, "CreateFineTuningJob error") _, err = client.CancelFineTuningJob(ctx, testFineTuninigJobID) @@ -84,22 +83,22 @@ func TestFineTuningJob(t *testing.T) { _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") _, err = client.ListFineTuningJobEvents( ctx, testFineTuninigJobID, - ListFineTuningJobEventsWithAfter("last-event-id"), - ListFineTuningJobEventsWithLimit(10), + openai.ListFineTuningJobEventsWithAfter("last-event-id"), + openai.ListFineTuningJobEventsWithLimit(10), ) checks.NoError(t, err, "ListFineTuningJobEvents error") } diff --git a/image_api_test.go b/image_api_test.go index b472eb04a..422f831fe 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -12,13 +9,16 @@ import ( "os" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestImages(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) - _, err := client.CreateImage(context.Background(), ImageRequest{ + _, err := client.CreateImage(context.Background(), openai.ImageRequest{ Prompt: "Lorem ipsum", }) checks.NoError(t, err, "CreateImage error") @@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq ImageRequest + var imageReq openai.ImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ImageResponse{ + res := openai.ImageResponse{ Created: time.Now().Unix(), } for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} + imageData := openai.ImageResponseDataInner{} switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": + case openai.CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: + case openai.CreateImageResponseFormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = "e30K" default: @@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +func getImageBody(r *http.Request) (openai.ImageRequest, error) { + image := openai.ImageRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } return image, nil } @@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Mask: mask, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateEditImage(context.Background(), ImageEditRequest{ + _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, Prompt: "There is a turtle in the pool", N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", @@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) { os.Remove("image.png") }() - _, err = client.CreateVariImage(context.Background(), ImageVariRequest{ + _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, N: 3, - Size: CreateImageSize1024x1024, - ResponseFormat: CreateImageResponseFormatURL, + Size: openai.CreateImageSize1024x1024, + ResponseFormat: openai.CreateImageResponseFormatURL, }) checks.NoError(t, err, "CreateImage error") } @@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - responses := ImageResponse{ + responses := openai.ImageResponse{ Created: time.Now().Unix(), - Data: []ImageResponseDataInner{ + Data: []openai.ImageResponseDataInner{ { URL: "test-url1", B64JSON: "", diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index c8d0c1d9e..744706082 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -5,28 +5,28 @@ import ( "reflect" "testing" - . "github.com/sashabaranov/go-openai/jsonschema" + "github.com/sashabaranov/go-openai/jsonschema" ) func TestDefinition_MarshalJSON(t *testing.T) { tests := []struct { name string - def Definition + def jsonschema.Definition want string }{ { name: "Test with empty Definition", - def: Definition{}, + def: jsonschema.Definition{}, want: `{"properties":{}}`, }, { name: "Test with Definition properties set", - def: Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.String, Description: "A string type", - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -43,17 +43,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with nested Definition properties", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, }, }, @@ -80,26 +80,26 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with complex nested Definition", - def: Definition{ - Type: Object, - Properties: map[string]Definition{ + def: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "user": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, "age": { - Type: Integer, + Type: jsonschema.Integer, }, "address": { - Type: Object, - Properties: map[string]Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ "city": { - Type: String, + Type: jsonschema.String, }, "country": { - Type: String, + Type: jsonschema.String, }, }, }, @@ -141,14 +141,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, { name: "Test with Array type Definition", - def: Definition{ - Type: Array, - Items: &Definition{ - Type: String, + def: jsonschema.Definition{ + Type: jsonschema.Array, + Items: &jsonschema.Definition{ + Type: jsonschema.String, }, - Properties: map[string]Definition{ + Properties: map[string]jsonschema.Definition{ "name": { - Type: String, + Type: jsonschema.String, }, }, }, diff --git a/models_test.go b/models_test.go index 9ff73042a..4a4c759dc 100644 --- a/models_test.go +++ b/models_test.go @@ -1,17 +1,16 @@ package openai_test import ( - "os" - "time" - - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" "net/http" + "os" "testing" + "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) const testFineTuneModelID = "fine-tune-model-id" @@ -35,7 +34,7 @@ func TestAzureListModels(t *testing.T) { // handleListModelsEndpoint Handles the list models endpoint by the test server. func handleListModelsEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(ModelsList{}) + resBytes, _ := json.Marshal(openai.ModelsList{}) fmt.Fprintln(w, string(resBytes)) } @@ -58,7 +57,7 @@ func TestAzureGetModel(t *testing.T) { // handleGetModelsEndpoint Handles the get model endpoint by the test server. func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(Model{}) + resBytes, _ := json.Marshal(openai.Model{}) fmt.Fprintln(w, string(resBytes)) } @@ -90,6 +89,6 @@ func TestDeleteFineTuneModel(t *testing.T) { } func handleDeleteFineTuneModelEndpoint(w http.ResponseWriter, _ *http.Request) { - resBytes, _ := json.Marshal(FineTuneModelDeleteResponse{}) + resBytes, _ := json.Marshal(openai.FineTuneModelDeleteResponse{}) fmt.Fprintln(w, string(resBytes)) } diff --git a/moderation_test.go b/moderation_test.go index 68f9565e1..059f0d1c7 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,9 +1,6 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "fmt" @@ -13,6 +10,9 @@ import ( "strings" "testing" "time" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) // TestModeration Tests the moderations endpoint of the API using the mocked server. @@ -20,8 +20,8 @@ func TestModerations(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) - _, err := client.Moderations(context.Background(), ModerationRequest{ - Model: ModerationTextStable, + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ + Model: openai.ModerationTextStable, Input: "I want to kill them.", }) checks.NoError(t, err, "Moderation error") @@ -34,16 +34,16 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { expect error } modelOptions = append(modelOptions, - getModerationModelTestOption(GPT3Dot5Turbo, ErrModerationInvalidModel), - getModerationModelTestOption(ModerationTextStable, nil), - getModerationModelTestOption(ModerationTextLatest, nil), + getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), + getModerationModelTestOption(openai.ModerationTextStable, nil), + getModerationModelTestOption(openai.ModerationTextLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/moderations", handleModerationEndpoint) for _, modelTest := range modelOptions { - _, err := client.Moderations(context.Background(), ModerationRequest{ + _, err := client.Moderations(context.Background(), openai.ModerationRequest{ Model: modelTest.model, Input: "I want to kill them.", }) @@ -71,32 +71,32 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var moderationReq ModerationRequest + var moderationReq openai.ModerationRequest if moderationReq, err = getModerationBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - resCat := ResultCategories{} - resCatScore := ResultCategoryScores{} + resCat := openai.ResultCategories{} + resCatScore := openai.ResultCategoryScores{} switch { case strings.Contains(moderationReq.Input, "kill"): - resCat = ResultCategories{Violence: true} - resCatScore = ResultCategoryScores{Violence: 1} + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): - resCat = ResultCategories{Hate: true} - resCatScore = ResultCategoryScores{Hate: 1} + resCat = openai.ResultCategories{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "suicide"): - resCat = ResultCategories{SelfHarm: true} - resCatScore = ResultCategoryScores{SelfHarm: 1} + resCat = openai.ResultCategories{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "porn"): - resCat = ResultCategories{Sexual: true} - resCatScore = ResultCategoryScores{Sexual: 1} + resCat = openai.ResultCategories{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} } - result := Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} + result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} - res := ModerationResponse{ + res := openai.ModerationResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Model: moderationReq.Model, } @@ -107,16 +107,16 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { } // getModerationBody Returns the body of the request to do a moderation. -func getModerationBody(r *http.Request) (ModerationRequest, error) { - moderation := ModerationRequest{} +func getModerationBody(r *http.Request) (openai.ModerationRequest, error) { + moderation := openai.ModerationRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } err = json.Unmarshal(reqBody, &moderation) if err != nil { - return ModerationRequest{}, err + return openai.ModerationRequest{}, err } return moderation, nil } diff --git a/openai_test.go b/openai_test.go index 4fc41ecc0..729d8880c 100644 --- a/openai_test.go +++ b/openai_test.go @@ -1,29 +1,29 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test" ) -func setupOpenAITestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupOpenAITestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultConfig(test.GetTestToken()) + config := openai.DefaultConfig(test.GetTestToken()) config.BaseURL = ts.URL + "/v1" - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } -func setupAzureTestServer() (client *Client, server *test.ServerTest, teardown func()) { +func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, teardown func()) { server = test.NewTestServer() ts := server.OpenAITestServer() ts.Start() teardown = ts.Close - config := DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") + config := openai.DefaultAzureConfig(test.GetTestToken(), "https://dummylab.openai.azure.com/") config.BaseURL = ts.URL - client = NewClientWithConfig(config) + client = openai.NewClientWithConfig(config) return } diff --git a/stream_test.go b/stream_test.go index f3f8f85cd..35c52ae3b 100644 --- a/stream_test.go +++ b/stream_test.go @@ -10,23 +10,23 @@ import ( "testing" "time" - . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestCompletionsStreamWrongModel(t *testing.T) { - config := DefaultConfig("whatever") + config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" - client := NewClientWithConfig(config) + client := openai.NewClientWithConfig(config) _, err := client.CreateCompletionStream( context.Background(), - CompletionRequest{ + openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Dot5Turbo, + Model: openai.GPT3Dot5Turbo, }, ) - if !errors.Is(err, ErrCompletionUnsupportedModel) { + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", err) } } @@ -56,7 +56,7 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -65,20 +65,20 @@ func TestCreateCompletionStream(t *testing.T) { checks.NoError(t, err, "CreateCompletionStream returned error") defer stream.Close() - expectedResponses := []CompletionResponse{ + expectedResponses := []openai.CompletionResponse{ { ID: "1", Object: "completion", Created: 1598069254, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response1", FinishReason: "max_tokens"}}, }, { ID: "2", Object: "completion", Created: 1598069255, Model: "text-davinci-002", - Choices: []CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, + Choices: []openai.CompletionChoice{{Text: "response2", FinishReason: "max_tokens"}}, }, } @@ -129,9 +129,9 @@ func TestCreateCompletionStreamError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3TextDavinci003, + Model: openai.GPT3TextDavinci003, Prompt: "Hello!", Stream: true, }) @@ -141,7 +141,7 @@ func TestCreateCompletionStreamError(t *testing.T) { _, streamErr := stream.Recv() checks.HasError(t, streamErr, "stream.Recv() did not return error") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(streamErr, &apiErr) { t.Errorf("stream.Recv() did not return APIError") } @@ -166,10 +166,10 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { checks.NoError(t, err, "Write error") }) - var apiErr *APIError - _, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + var apiErr *openai.APIError + _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: GPT3Ada, + Model: openai.GPT3Ada, Prompt: "Hello!", Stream: true, }) @@ -209,7 +209,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -220,7 +220,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { _, _ = stream.Recv() _, streamErr := stream.Recv() - if !errors.Is(streamErr, ErrTooManyEmptyStreamMessages) { + if !errors.Is(streamErr, openai.ErrTooManyEmptyStreamMessages) { t.Errorf("TestCreateCompletionStreamTooManyEmptyStreamMessagesError did not return ErrTooManyEmptyStreamMessages") } } @@ -244,7 +244,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -285,7 +285,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { checks.NoError(t, err, "Write error") }) - stream, err := client.CreateCompletionStream(context.Background(), CompletionRequest{ + stream, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -312,7 +312,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) defer cancel() - _, err := client.CreateCompletionStream(ctx, CompletionRequest{ + _, err := client.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", Model: "text-davinci-002", MaxTokens: 10, @@ -327,7 +327,7 @@ func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { } // Helper funcs. -func compareResponses(r1, r2 CompletionResponse) bool { +func compareResponses(r1, r2 openai.CompletionResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { return false } @@ -342,7 +342,7 @@ func compareResponses(r1, r2 CompletionResponse) bool { return true } -func compareResponseChoices(c1, c2 CompletionChoice) bool { +func compareResponseChoices(c1, c2 openai.CompletionChoice) bool { if c1.Text != c2.Text || c1.FinishReason != c2.FinishReason { return false } From d07833e19bfbb2f26011c8881f7fb61366c07e75 Mon Sep 17 00:00:00 2001 From: Carson Kahn Date: Tue, 7 Nov 2023 04:27:29 -0500 Subject: [PATCH 055/242] Doc ways to improve reproducability besides Temp (#532) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f0b609088..4cb77db6b 100644 --- a/README.md +++ b/README.md @@ -757,8 +757,9 @@ Even when specifying a temperature field of 0, it doesn't guarantee that you'll Due to the factors mentioned above, different answers may be returned even for the same question. **Workarounds:** -1. Using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. -2. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. +1. As of November 2023, use [the new `seed` parameter](https://platform.openai.com/docs/guides/text-generation/reproducible-outputs) in conjunction with the `system_fingerprint` response field, alongside Temperature management. +2. Try using `math.SmallestNonzeroFloat32`: By specifying `math.SmallestNonzeroFloat32` in the temperature field instead of 0, you can mimic the behavior of setting it to 0. +3. Limiting Token Count: By limiting the number of tokens in the input and output and especially avoiding large requests close to 32k tokens, you can reduce the risk of non-deterministic behavior. By adopting these strategies, you can expect more consistent results. From 6d9c3a6365643d02692ecc6f0b34a5fa3e7fea45 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 15:25:21 +0100 Subject: [PATCH 056/242] Feat Support chat completion response format and seed new fields (#525) * feat: support chat completion response format * fix linting error * fix * fix linting * Revert "fix linting" This reverts commit 015c6ad62aad561218b693225f58670b5619dba8. * Revert "fix" This reverts commit 7b2ffe28c3e586b629d23479ec1728bf52f0c66f. * Revert "fix linting error" This reverts commit 29960423784e296cb6d22c5db8f8ccf00cac59fd. * chore: add seed new parameter * fix --- chat.go | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/chat.go b/chat.go index df0e5f970..88db8cf1d 100644 --- a/chat.go +++ b/chat.go @@ -69,18 +69,31 @@ type FunctionCall struct { Arguments string `json:"arguments,omitempty"` } +type ChatCompletionResponseFormatType string + +const ( + ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" +) + +type ChatCompletionResponseFormat struct { + Type ChatCompletionResponseFormatType `json:"type"` +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias From 3063e676bf5932024d76be8e8d9e41df06d4e8cc Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 16:20:59 +0100 Subject: [PATCH 057/242] Feat Implement assistants API (#535) * chore: implement assistants API * fix * fix * chore: add tests * fix tests * fix linting --- assistant.go | 260 ++++++++++++++++++++++++++++++++++++++++++++++ assistant_test.go | 202 +++++++++++++++++++++++++++++++++++ client_test.go | 27 +++++ 3 files changed, 489 insertions(+) create mode 100644 assistant.go create mode 100644 assistant_test.go diff --git a/assistant.go b/assistant.go new file mode 100644 index 000000000..d75eebef3 --- /dev/null +++ b/assistant.go @@ -0,0 +1,260 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + assistantsSuffix = "/assistants" + assistantsFilesSuffix = "/files" +) + +type Assistant struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + + httpHeader +} + +type AssistantTool struct { + Type string `json:"type"` +} + +type AssistantToolCodeInterpreter struct { + AssistantTool +} + +type AssistantToolRetrieval struct { + AssistantTool +} + +type AssistantToolFunction struct { + AssistantTool + Function FunctionDefinition `json:"function"` +} + +type AssistantRequest struct { + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []any `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// AssistantsList is a list of assistants. +type AssistantsList struct { + Assistants []Assistant `json:"data"` + + httpHeader +} + +type AssistantFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + + httpHeader +} + +type AssistantFileRequest struct { + FileID string `json:"file_id"` +} + +type AssistantFilesList struct { + AssistantFiles []AssistantFile `json:"data"` + + httpHeader +} + +// CreateAssistant creates a new assistant. +func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistant retrieves an assistant. +func (c *Client) RetrieveAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyAssistant modifies an assistant. +func (c *Client) ModifyAssistant( + ctx context.Context, + assistantID string, + request AssistantRequest, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistant deletes an assistant. +func (c *Client) DeleteAssistant( + ctx context.Context, + assistantID string, +) (response Assistant, err error) { + urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListAssistants Lists the currently available assistants. +func (c *Client) ListAssistants( + ctx context.Context, + limit *int, + order *string, + after *string, + before *string, +) (reponse AssistantsList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &reponse) + return +} + +// CreateAssistantFile creates a new assistant file. +func (c *Client) CreateAssistantFile( + ctx context.Context, + assistantID string, + request AssistantFileRequest, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveAssistantFile retrieves an assistant file. +func (c *Client) RetrieveAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (response AssistantFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteAssistantFile deletes an existing file. +func (c *Client) DeleteAssistantFile( + ctx context.Context, + assistantID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, nil) + return +} + +// ListAssistantFiles Lists the currently available files for an assistant. +func (c *Client) ListAssistantFiles( + ctx context.Context, + assistantID string, + limit *int, + order *string, + after *string, + before *string, +) (response AssistantFilesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/assistant_test.go b/assistant_test.go new file mode 100644 index 000000000..eb6f42458 --- /dev/null +++ b/assistant_test.go @@ -0,0 +1,202 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assitantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assitantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assitantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client_test.go b/client_test.go index 2c1d749ed..bff2597c5 100644 --- a/client_test.go +++ b/client_test.go @@ -274,6 +274,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteFineTuneModel", func() (any, error) { return client.DeleteFineTuneModel(ctx, "") }}, + {"CreateAssistant", func() (any, error) { + return client.CreateAssistant(ctx, AssistantRequest{}) + }}, + {"RetrieveAssistant", func() (any, error) { + return client.RetrieveAssistant(ctx, "") + }}, + {"ModifyAssistant", func() (any, error) { + return client.ModifyAssistant(ctx, "", AssistantRequest{}) + }}, + {"DeleteAssistant", func() (any, error) { + return client.DeleteAssistant(ctx, "") + }}, + {"ListAssistants", func() (any, error) { + return client.ListAssistants(ctx, nil, nil, nil, nil) + }}, + {"CreateAssistantFile", func() (any, error) { + return client.CreateAssistantFile(ctx, "", AssistantFileRequest{}) + }}, + {"ListAssistantFiles", func() (any, error) { + return client.ListAssistantFiles(ctx, "", nil, nil, nil, nil) + }}, + {"RetrieveAssistantFile", func() (any, error) { + return client.RetrieveAssistantFile(ctx, "", "") + }}, + {"DeleteAssistantFile", func() (any, error) { + return nil, client.DeleteAssistantFile(ctx, "", "") + }}, } for _, testCase := range testCases { From 1ad6b6f53dcd9abfaf56e8adb02b5b599936580c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 16:53:24 +0100 Subject: [PATCH 058/242] Feat Support tools and tools choice new fileds (#526) * feat: support tools and tools choice new fileds * fix: use value not pointers --- chat.go | 41 +++++++++++++++++++++++++++++++++++++---- chat_stream.go | 1 + 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/chat.go b/chat.go index 88db8cf1d..04303184a 100644 --- a/chat.go +++ b/chat.go @@ -12,6 +12,7 @@ const ( ChatMessageRoleUser = "user" ChatMessageRoleAssistant = "assistant" ChatMessageRoleFunction = "function" + ChatMessageRoleTool = "tool" ) const chatCompletionsSuffix = "/chat/completions" @@ -61,6 +62,12 @@ type ChatCompletionMessage struct { Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +type ToolCall struct { + ID string `json:"id"` + Function FunctionCall `json:"function"` } type FunctionCall struct { @@ -97,10 +104,35 @@ type ChatCompletionRequest struct { // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` - Functions []FunctionDefinition `json:"functions,omitempty"` - FunctionCall any `json:"function_call,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + User string `json:"user,omitempty"` + // Deprecated: use Tools instead. + Functions []FunctionDefinition `json:"functions,omitempty"` + // Deprecated: use ToolChoice instead. + FunctionCall any `json:"function_call,omitempty"` + Tools []Tool `json:"tools,omitempty"` + // This can be either a string or an ToolChoice object. + ToolChoiche any `json:"tool_choice,omitempty"` +} + +type ToolType string + +const ( + ToolTypeFunction ToolType = "function" +) + +type Tool struct { + Type ToolType `json:"type"` + Function FunctionDefinition `json:"function,omitempty"` +} + +type ToolChoiche struct { + Type ToolType `json:"type"` + Function ToolFunction `json:"function,omitempty"` +} + +type ToolFunction struct { + Name string `json:"name"` } type FunctionDefinition struct { @@ -123,6 +155,7 @@ const ( FinishReasonStop FinishReason = "stop" FinishReasonLength FinishReason = "length" FinishReasonFunctionCall FinishReason = "function_call" + FinishReasonToolCalls FinishReason = "tool_calls" FinishReasonContentFilter FinishReason = "content_filter" FinishReasonNull FinishReason = "null" ) diff --git a/chat_stream.go b/chat_stream.go index f1faa3964..57cfa789f 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -9,6 +9,7 @@ type ChatCompletionStreamChoiceDelta struct { Content string `json:"content,omitempty"` Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type ChatCompletionStreamChoice struct { From a20eb08b79e5c34882888a401020b47c145357ff Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Tue, 7 Nov 2023 22:30:05 +0100 Subject: [PATCH 059/242] fix: use pointer for ChatCompletionResponseFormat (#544) --- chat.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index 04303184a..609e0c311 100644 --- a/chat.go +++ b/chat.go @@ -89,18 +89,18 @@ type ChatCompletionResponseFormat struct { // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias From a0159ad2b00e4f127222814694bec68863395543 Mon Sep 17 00:00:00 2001 From: Mike Cutalo Date: Tue, 7 Nov 2023 23:16:22 -0800 Subject: [PATCH 060/242] Support new fields for /v1/images/generation API (#530) * add support for new image/generation api * fix one lint * add revised_prompt to response * fix lints * add CreateImageQualityStandard --- image.go | 26 ++++++++++++++++++++++++-- image_api_test.go | 9 ++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/image.go b/image.go index 4addcdb1e..4fe8b3a32 100644 --- a/image.go +++ b/image.go @@ -13,6 +13,9 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. + CreateImageSize1792x1024 = "1792x1024" + CreateImageSize1024x1792 = "1024x1792" ) const ( @@ -20,11 +23,29 @@ const ( CreateImageResponseFormatB64JSON = "b64_json" ) +const ( + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" +) + +const ( + CreateImageQualityHD = "hd" + CreateImageQualityStandard = "standard" +) + +const ( + CreateImageStyleVivid = "vivid" + CreateImageStyleNatural = "natural" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` ResponseFormat string `json:"response_format,omitempty"` User string `json:"user,omitempty"` } @@ -39,8 +60,9 @@ type ImageResponse struct { // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` } // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. diff --git a/image_api_test.go b/image_api_test.go index 422f831fe..2eb46f2b4 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -19,7 +19,14 @@ func TestImages(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/generations", handleImageEndpoint) _, err := client.CreateImage(context.Background(), openai.ImageRequest{ - Prompt: "Lorem ipsum", + Prompt: "Lorem ipsum", + Model: openai.CreateImageModelDallE3, + N: 1, + Quality: openai.CreateImageQualityHD, + Size: openai.CreateImageSize1024x1024, + Style: openai.CreateImageStyleVivid, + ResponseFormat: openai.CreateImageResponseFormatURL, + User: "user", }) checks.NoError(t, err, "CreateImage error") } From a2d2bf685122fd51d768f2a828787cae587d9ad6 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 8 Nov 2023 10:20:20 +0100 Subject: [PATCH 061/242] Fix Refactor assistant api (#545) * fix: refactor assistant API * fix * trigger build * fix: use AssistantDeleteResponse --- assistant.go | 90 ++++++++++++++++++++++++++++++---------------------- client.go | 6 ++++ 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/assistant.go b/assistant.go index d75eebef3..de49be680 100644 --- a/assistant.go +++ b/assistant.go @@ -10,46 +10,43 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" + openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []any `json:"tools,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools,omitempty"` httpHeader } -type AssistantTool struct { - Type string `json:"type"` -} - -type AssistantToolCodeInterpreter struct { - AssistantTool -} +type AssistantToolType string -type AssistantToolRetrieval struct { - AssistantTool -} +const ( + AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" + AssistantToolTypeRetrieval AssistantToolType = "retrieval" + AssistantToolTypeFunction AssistantToolType = "function" +) -type AssistantToolFunction struct { - AssistantTool - Function FunctionDefinition `json:"function"` +type AssistantTool struct { + Type AssistantToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` } type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []any `json:"tools,omitempty"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } // AssistantsList is a list of assistants. @@ -59,6 +56,14 @@ type AssistantsList struct { httpHeader } +type AssistantDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + type AssistantFile struct { ID string `json:"id"` Object string `json:"object"` @@ -80,7 +85,8 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -95,7 +101,8 @@ func (c *Client) RetrieveAssistant( assistantID string, ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -111,7 +118,8 @@ func (c *Client) ModifyAssistant( request AssistantRequest, ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -124,9 +132,10 @@ func (c *Client) ModifyAssistant( func (c *Client) DeleteAssistant( ctx context.Context, assistantID string, -) (response Assistant, err error) { +) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) - req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -163,7 +172,8 @@ func (c *Client) ListAssistants( } urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -180,7 +190,8 @@ func (c *Client) CreateAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(request)) + withBody(request), + withBetaAssistantV1()) if err != nil { return } @@ -196,7 +207,8 @@ func (c *Client) RetrieveAssistantFile( fileID string, ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -212,7 +224,8 @@ func (c *Client) DeleteAssistantFile( fileID string, ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) - req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } @@ -250,7 +263,8 @@ func (c *Client) ListAssistantFiles( } urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) if err != nil { return } diff --git a/client.go b/client.go index 65ece812f..056226c61 100644 --- a/client.go +++ b/client.go @@ -83,6 +83,12 @@ func withContentType(contentType string) requestOption { } } +func withBetaAssistantV1() requestOption { + return func(args *requestOptions) { + args.header.Set("OpenAI-Beta", "assistants=v1") + } +} + func (c *Client) newRequest(ctx context.Context, method, url string, setters ...requestOption) (*http.Request, error) { // Default Options args := &requestOptions{ From 08c167fecf6953619d1905ab2959ed341bfb063d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=A1=E9=82=89=E7=A5=90=E4=B8=80=20/=20Yuichi=20Watana?= =?UTF-8?q?be?= Date: Wed, 8 Nov 2023 18:21:51 +0900 Subject: [PATCH 062/242] test: fix compile error in api integration test (#548) --- api_integration_test.go | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 6be188bc6..736040c50 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -9,6 +9,7 @@ import ( "os" "testing" + "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/internal/test/checks" "github.com/sashabaranov/go-openai/jsonschema" ) @@ -20,7 +21,7 @@ func TestAPI(t *testing.T) { } var err error - c := NewClient(apiToken) + c := openai.NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") @@ -36,23 +37,23 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "GetFile error") } // else skip - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: AdaSearchQuery, + Model: openai.AdaSearchQuery, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") _, err = c.CreateChatCompletion( ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "Hello!", }, }, @@ -63,11 +64,11 @@ func TestAPI(t *testing.T) { _, err = c.CreateChatCompletion( ctx, - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Name: "John_Doe", Content: "Hello!", }, @@ -76,9 +77,9 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, CompletionRequest{ + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ Prompt: "Ex falso quodlibet", - Model: GPT3Ada, + Model: openai.GPT3Ada, MaxTokens: 5, Stream: true, }) @@ -103,15 +104,15 @@ func TestAPI(t *testing.T) { _, err = c.CreateChatCompletion( context.Background(), - ChatCompletionRequest{ - Model: GPT3Dot5Turbo, - Messages: []ChatCompletionMessage{ + openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ { - Role: ChatMessageRoleUser, + Role: openai.ChatMessageRoleUser, Content: "What is the weather like in Boston?", }, }, - Functions: []FunctionDefinition{{ + Functions: []openai.FunctionDefinition{{ Name: "get_current_weather", Parameters: jsonschema.Definition{ Type: jsonschema.Object, @@ -140,12 +141,12 @@ func TestAPIError(t *testing.T) { } var err error - c := NewClient(apiToken + "_invalid") + c := openai.NewClient(apiToken + "_invalid") ctx := context.Background() _, err = c.ListEngines(ctx) checks.HasError(t, err, "ListEngines should fail with an invalid key") - var apiErr *APIError + var apiErr *openai.APIError if !errors.As(err, &apiErr) { t.Fatalf("Error is not an APIError: %+v", err) } From bc89139c1ddcc4f6d5b15b7e8d0491c69dda402c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 09:05:44 +0100 Subject: [PATCH 063/242] Feat Implement threads API (#536) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: implement threads API * fix * add tests * fix * trigger£ * trigger * chore: add beta header --- client_test.go | 12 ++++++ thread.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++ thread_test.go | 95 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 thread.go create mode 100644 thread_test.go diff --git a/client_test.go b/client_test.go index bff2597c5..b2f28f90a 100644 --- a/client_test.go +++ b/client_test.go @@ -301,6 +301,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteAssistantFile", func() (any, error) { return nil, client.DeleteAssistantFile(ctx, "", "") }}, + {"CreateThread", func() (any, error) { + return client.CreateThread(ctx, ThreadRequest{}) + }}, + {"RetrieveThread", func() (any, error) { + return client.RetrieveThread(ctx, "") + }}, + {"ModifyThread", func() (any, error) { + return client.ModifyThread(ctx, "", ModifyThreadRequest{}) + }}, + {"DeleteThread", func() (any, error) { + return client.DeleteThread(ctx, "") + }}, } for _, testCase := range testCases { diff --git a/thread.go b/thread.go new file mode 100644 index 000000000..291f3dcab --- /dev/null +++ b/thread.go @@ -0,0 +1,107 @@ +package openai + +import ( + "context" + "net/http" +) + +const ( + threadsSuffix = "/threads" +) + +type Thread struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type ThreadRequest struct { + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ModifyThreadRequest struct { + Metadata map[string]any `json:"metadata"` +} + +type ThreadMessageRole string + +const ( + ThreadMessageRoleUser ThreadMessageRole = "user" +) + +type ThreadMessage struct { + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +// CreateThread creates a new thread. +func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveThread retrieves a thread. +func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyThread modifies a thread. +func (c *Client) ModifyThread( + ctx context.Context, + threadID string, + request ModifyThreadRequest, +) (response Thread, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// DeleteThread deletes a thread. +func (c *Client) DeleteThread( + ctx context.Context, + threadID string, +) (response ThreadDeleteResponse, err error) { + urlSuffix := threadsSuffix + "/" + threadID + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/thread_test.go b/thread_test.go new file mode 100644 index 000000000..227ab6330 --- /dev/null +++ b/thread_test.go @@ -0,0 +1,95 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +// TestThread Tests the thread endpoint of the API using the mocked server. +func TestThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} From e3e065deb0a190e2d3c3bbf9caf54471b32f675e Mon Sep 17 00:00:00 2001 From: Gabriel Burt Date: Thu, 9 Nov 2023 03:08:43 -0500 Subject: [PATCH 064/242] Add SystemFingerprint and chatMsg.ToolCallID field (#543) * fix ToolChoiche typo * add tool_call_id to ChatCompletionMessage * add /chat system_fingerprint response field * check empty ToolCallID JSON marshaling and add omitempty for tool_call_id * messages also required; don't omitempty * add Type to ToolCall, required by the API * fix test, omitempty for response_format ptr * fix casing of role values in comments --- chat.go | 27 +++++++++++++++++---------- chat_test.go | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/chat.go b/chat.go index 609e0c311..9ad31c466 100644 --- a/chat.go +++ b/chat.go @@ -62,11 +62,17 @@ type ChatCompletionMessage struct { Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + // For Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool. + ToolCallID string `json:"tool_call_id,omitempty"` } type ToolCall struct { ID string `json:"id"` + Type ToolType `json:"type"` Function FunctionCall `json:"function"` } @@ -84,7 +90,7 @@ const ( ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -112,7 +118,7 @@ type ChatCompletionRequest struct { FunctionCall any `json:"function_call,omitempty"` Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. - ToolChoiche any `json:"tool_choice,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } type ToolType string @@ -126,7 +132,7 @@ type Tool struct { Function FunctionDefinition `json:"function,omitempty"` } -type ToolChoiche struct { +type ToolChoice struct { Type ToolType `json:"type"` Function ToolFunction `json:"function,omitempty"` } @@ -182,12 +188,13 @@ type ChatCompletionChoice struct { // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` httpHeader } diff --git a/chat_test.go b/chat_test.go index 5bf1eaf6c..a8155edf2 100644 --- a/chat_test.go +++ b/chat_test.go @@ -51,6 +51,20 @@ func TestChatCompletionsWrongModel(t *testing.T) { checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } +func TestChatRequestOmitEmpty(t *testing.T) { + data, err := json.Marshal(openai.ChatCompletionRequest{ + // We set model b/c it's required, so omitempty doesn't make sense + Model: "gpt-4", + }) + checks.NoError(t, err) + + // messages is also required so isn't omitted + const expected = `{"model":"gpt-4","messages":null}` + if string(data) != expected { + t.Errorf("expected JSON with all empty fields to be %v but was %v", expected, string(data)) + } +} + func TestChatCompletionsWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" From 81270725539980d202829528054f3fda346970db Mon Sep 17 00:00:00 2001 From: Urjit Singh Bhatia Date: Thu, 9 Nov 2023 00:20:39 -0800 Subject: [PATCH 065/242] fix test server setup: (#549) * fix test server setup: - go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration go through /foo/bar first since the regex match wasn't bound to start and end anchors - registering handlers now converts * in routes to .* for proper regex matching - test server route handling now tries to fully match the handler route * add missing /v1 prefix to fine-tuning job cancel test server handler --- fine_tuning_job_test.go | 2 +- internal/test/server.go | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index c892ef775..d2fbcd4c7 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -42,7 +42,7 @@ func TestFineTuningJob(t *testing.T) { ) server.RegisterHandler( - "/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", + "/v1/fine_tuning/jobs/"+testFineTuninigJobID+"/cancel", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.FineTuningJob{}) fmt.Fprintln(w, string(resBytes)) diff --git a/internal/test/server.go b/internal/test/server.go index 3813ff869..127d4c16f 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "regexp" + "strings" ) const testAPI = "this-is-my-secure-token-do-not-steal!!" @@ -23,13 +24,16 @@ func NewTestServer() *ServerTest { } func (ts *ServerTest) RegisterHandler(path string, handler handler) { + // to make the registered paths friendlier to a regex match in the route handler + // in OpenAITestServer + path = strings.ReplaceAll(path, "*", ".*") ts.handlers[path] = handler } // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing. func (ts *ServerTest) OpenAITestServer() *httptest.Server { return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Printf("received request at path %q\n", r.URL.Path) + log.Printf("received a %s request at path %q\n", r.Method, r.URL.Path) // check auth if r.Header.Get("Authorization") != "Bearer "+GetTestToken() && r.Header.Get("api-key") != GetTestToken() { @@ -38,8 +42,10 @@ func (ts *ServerTest) OpenAITestServer() *httptest.Server { } // Handle /path/* routes. + // Note: the * is converted to a .* in register handler for proper regex handling for route, handler := range ts.handlers { - pattern, _ := regexp.Compile(route) + // Adding ^ and $ to make path matching deterministic since go map iteration isn't ordered + pattern, _ := regexp.Compile("^" + route + "$") if pattern.MatchString(r.URL.Path) { handler(w, r) return From 78862a2798df46f6ca8bb73350b720f9c8d4a592 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 15:05:03 +0100 Subject: [PATCH 066/242] fix: add missing fields in tool_calls (#558) --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index 9ad31c466..ebdc0e24b 100644 --- a/chat.go +++ b/chat.go @@ -71,6 +71,8 @@ type ChatCompletionMessage struct { } type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` ID string `json:"id"` Type ToolType `json:"type"` Function FunctionCall `json:"function"` From d6f3bdcdac9172ab5248d6be8c3e1761446a434c Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Thu, 9 Nov 2023 20:17:30 +0100 Subject: [PATCH 067/242] Feat implement Run APIs (#560) * chore: first commit * add apis * chore: add tests * feat add apis * chore: add api and tests * chore: add tests * fix * trigger build * fix * chore: formatting code * chore: add pagination type --- client_test.go | 27 ++++ run.go | 399 +++++++++++++++++++++++++++++++++++++++++++++++++ run_test.go | 237 +++++++++++++++++++++++++++++ 3 files changed, 663 insertions(+) create mode 100644 run.go create mode 100644 run_test.go diff --git a/client_test.go b/client_test.go index b2f28f90a..d5d3e2644 100644 --- a/client_test.go +++ b/client_test.go @@ -313,6 +313,33 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteThread", func() (any, error) { return client.DeleteThread(ctx, "") }}, + {"CreateRun", func() (any, error) { + return client.CreateRun(ctx, "", RunRequest{}) + }}, + {"RetrieveRun", func() (any, error) { + return client.RetrieveRun(ctx, "", "") + }}, + {"ModifyRun", func() (any, error) { + return client.ModifyRun(ctx, "", "", RunModifyRequest{}) + }}, + {"ListRuns", func() (any, error) { + return client.ListRuns(ctx, "", Pagination{}) + }}, + {"SubmitToolOutputs", func() (any, error) { + return client.SubmitToolOutputs(ctx, "", "", SubmitToolOutputsRequest{}) + }}, + {"CancelRun", func() (any, error) { + return client.CancelRun(ctx, "", "") + }}, + {"CreateThreadAndRun", func() (any, error) { + return client.CreateThreadAndRun(ctx, CreateThreadAndRunRequest{}) + }}, + {"RetrieveRunStep", func() (any, error) { + return client.RetrieveRunStep(ctx, "", "", "") + }}, + {"ListRunSteps", func() (any, error) { + return client.ListRunSteps(ctx, "", "", Pagination{}) + }}, } for _, testCase := range testCases { diff --git a/run.go b/run.go new file mode 100644 index 000000000..5d6ea58db --- /dev/null +++ b/run.go @@ -0,0 +1,399 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +type Run struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + ThreadID string `json:"thread_id"` + AssistantID string `json:"assistant_id"` + Status RunStatus `json:"status"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools"` + FileIDS []string `json:"file_ids"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStatus string + +const ( + RunStatusQueued RunStatus = "queued" + RunStatusInProgress RunStatus = "in_progress" + RunStatusRequiresAction RunStatus = "requires_action" + RunStatusCancelling RunStatus = "cancelling" + RunStatusFailed RunStatus = "failed" + RunStatusCompleted RunStatus = "completed" + RunStatusExpired RunStatus = "expired" +) + +type RunRequiredAction struct { + Type RequiredActionType `json:"type"` + SubmitToolOutputs *SubmitToolOutputs `json:"submit_tool_outputs,omitempty"` +} + +type RequiredActionType string + +const ( + RequiredActionTypeSubmitToolOutputs RequiredActionType = "submit_tool_outputs" +) + +type SubmitToolOutputs struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +type RunLastError struct { + Code RunError `json:"code"` + Message string `json:"message"` +} + +type RunError string + +const ( + RunErrorServerError RunError = "server_error" + RunErrorRateLimitExceeded RunError = "rate_limit_exceeded" +) + +type RunRequest struct { + AssistantID string `json:"assistant_id"` + Model *string `json:"model,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any +} + +type RunModifyRequest struct { + Metadata map[string]any `json:"metadata,omitempty"` +} + +// RunList is a list of runs. +type RunList struct { + Runs []Run `json:"data"` + + httpHeader +} + +type SubmitToolOutputsRequest struct { + ToolOutputs []ToolOutput `json:"tool_outputs"` +} + +type ToolOutput struct { + ToolCallID string `json:"tool_call_id"` + Output any `json:"output"` +} + +type CreateThreadAndRunRequest struct { + RunRequest + Thread ThreadRequest `json:"thread"` +} + +type RunStep struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + AssistantID string `json:"assistant_id"` + ThreadID string `json:"thread_id"` + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + Status RunStepStatus `json:"status"` + StepDetails StepDetails `json:"step_details"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiredAt *int64 `json:"expired_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type RunStepStatus string + +const ( + RunStepStatusInProgress RunStepStatus = "in_progress" + RunStepStatusCancelling RunStepStatus = "cancelled" + RunStepStatusFailed RunStepStatus = "failed" + RunStepStatusCompleted RunStepStatus = "completed" + RunStepStatusExpired RunStepStatus = "expired" +) + +type RunStepType string + +const ( + RunStepTypeMessageCreation RunStepType = "message_creation" + RunStepTypeToolCalls RunStepType = "tool_calls" +) + +type StepDetails struct { + Type RunStepType `json:"type"` + MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` + ToolCalls *StepDetailsToolCalls `json:"tool_calls,omitempty"` +} + +type StepDetailsMessageCreation struct { + MessageID string `json:"message_id"` +} + +type StepDetailsToolCalls struct { + ToolCalls []ToolCall `json:"tool_calls"` +} + +// RunStepList is a list of steps. +type RunStepList struct { + RunSteps []RunStep `json:"data"` + + httpHeader +} + +type Pagination struct { + Limit *int + Order *string + After *string + Before *string +} + +// CreateRun creates a new run. +func (c *Client) CreateRun( + ctx context.Context, + threadID string, + request RunRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRun retrieves a run. +func (c *Client) RetrieveRun( + ctx context.Context, + threadID string, + runID string, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ModifyRun modifies a run. +func (c *Client) ModifyRun( + ctx context.Context, + threadID string, + runID string, + request RunModifyRequest, +) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRuns lists runs. +func (c *Client) ListRuns( + ctx context.Context, + threadID string, + pagination Pagination, +) (response RunList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs%s", threadID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// SubmitToolOutputs submits tool outputs. +func (c *Client) SubmitToolOutputs( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CancelRun cancels a run. +func (c *Client) CancelRun( + ctx context.Context, + threadID string, + runID string) (response Run, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/cancel", threadID, runID) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// CreateThreadAndRun submits tool outputs. +func (c *Client) CreateThreadAndRun( + ctx context.Context, + request CreateThreadAndRunRequest) (response Run, err error) { + urlSuffix := "/threads/runs" + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// RetrieveRunStep retrieves a run step. +func (c *Client) RetrieveRunStep( + ctx context.Context, + threadID string, + runID string, + stepID string, +) (response RunStep, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps/%s", threadID, runID, stepID) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +// ListRunSteps lists run steps. +func (c *Client) ListRunSteps( + ctx context.Context, + threadID string, + runID string, + pagination Pagination, +) (response RunStepList, err error) { + urlValues := url.Values{} + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/steps%s", threadID, runID, encodedValues) + req, err := c.newRequest( + ctx, + http.MethodGet, + c.fullURL(urlSuffix), + withBetaAssistantV1(), + ) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/run_test.go b/run_test.go new file mode 100644 index 000000000..cdf99db05 --- /dev/null +++ b/run_test.go @@ -0,0 +1,237 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestAssistant Tests the assistant endpoint of the API using the mocked server. +func TestRun(t *testing.T) { + assistantID := "asst_abc123" + threadID := "thread_abc123" + runID := "run_abc123" + stepID := "step_abc123" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps/"+stepID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStep{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/steps", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunStepList{ + RunSteps: []openai.RunStep{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStepStatusCompleted, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID+"/submit_tool_outputs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusCancelling, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs/"+runID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.RunModifyRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.RunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.RunList{ + Runs: []openai.Run{ + { + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/runs", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.CreateThreadAndRunRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Run{ + ID: runID, + Object: "run", + CreatedAt: 1234567890, + Status: openai.RunStatusQueued, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateRun(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRun error") + + _, err = client.RetrieveRun(ctx, threadID, runID) + checks.NoError(t, err, "RetrieveRun error") + + _, err = client.ModifyRun(ctx, threadID, runID, openai.RunModifyRequest{ + Metadata: map[string]any{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyRun error") + + _, err = client.ListRuns( + ctx, + threadID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRuns error") + + _, err = client.SubmitToolOutputs(ctx, threadID, runID, + openai.SubmitToolOutputsRequest{}) + checks.NoError(t, err, "SubmitToolOutputs error") + + _, err = client.CancelRun(ctx, threadID, runID) + checks.NoError(t, err, "CancelRun error") + + _, err = client.CreateThreadAndRun(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndRun error") + + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) + checks.NoError(t, err, "RetrieveRunStep error") + + _, err = client.ListRunSteps( + ctx, + threadID, + runID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }, + ) + checks.NoError(t, err, "ListRunSteps error") +} From 35495ccd364265f37800a6fa72fed7f05705eb82 Mon Sep 17 00:00:00 2001 From: Kyle Bolton Date: Sun, 12 Nov 2023 06:09:40 -0500 Subject: [PATCH 068/242] Add `json:"metadata,omitempty"` to RunRequest struct (#561) Metadata is an optional field per the api spec https://platform.openai.com/docs/api-reference/runs/createRun --- run.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/run.go b/run.go index 5d6ea58db..7ff730fea 100644 --- a/run.go +++ b/run.go @@ -70,11 +70,11 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model *string `json:"model,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any + AssistantID string `json:"assistant_id"` + Model *string `json:"model,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type RunModifyRequest struct { From 9fefd50e12ad138efa3f38756be5dd2ed5fefadd Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sun, 12 Nov 2023 20:10:00 +0900 Subject: [PATCH 069/242] Fix typo in chat_test.go (#564) requetsts -> requests --- chat_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat_test.go b/chat_test.go index a8155edf2..8377809da 100644 --- a/chat_test.go +++ b/chat_test.go @@ -144,7 +144,7 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { } resetRequestsTime := headers.ResetRequests.Time() if resetRequestsTime.Before(time.Now()) { - t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + t.Errorf("unexpected reset requests: %v", resetRequestsTime) } bs1, _ := json.Marshal(headers) From b7cac703acb1a8be0e803c81ad3236be66be969a Mon Sep 17 00:00:00 2001 From: Urjit Singh Bhatia Date: Mon, 13 Nov 2023 08:33:26 -0600 Subject: [PATCH 070/242] Feat/messages api (#546) * fix test server setup: - go map access is not deterministic - this can lead to a route: /foo/bar/1 matching /foo/bar before matching /foo/bar/1 if the map iteration go through /foo/bar first since the regex match wasn't bound to start and end anchors - registering handlers now converts * in routes to .* for proper regex matching - test server route handling now tries to fully match the handler route * add missing /v1 prefix to fine-tuning job cancel test server handler * add create message call * add messages list call * add get message call * add modify message call, fix return types for other message calls * add message file retrieve call * add list message files call * code style fixes * add test for list messages with pagination options * add beta header to msg calls now that #545 is merged * Update messages.go Co-authored-by: Simone Vellei * Update messages.go Co-authored-by: Simone Vellei * add missing object details for message, fix tests * fix merge formatting * minor style fixes --------- Co-authored-by: Simone Vellei --- client_test.go | 18 ++++ messages.go | 178 +++++++++++++++++++++++++++++++++++ messages_test.go | 235 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 431 insertions(+) create mode 100644 messages.go create mode 100644 messages_test.go diff --git a/client_test.go b/client_test.go index d5d3e2644..24cb5ffa7 100644 --- a/client_test.go +++ b/client_test.go @@ -301,6 +301,24 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"DeleteAssistantFile", func() (any, error) { return nil, client.DeleteAssistantFile(ctx, "", "") }}, + {"CreateMessage", func() (any, error) { + return client.CreateMessage(ctx, "", MessageRequest{}) + }}, + {"ListMessage", func() (any, error) { + return client.ListMessage(ctx, "", nil, nil, nil, nil) + }}, + {"RetrieveMessage", func() (any, error) { + return client.RetrieveMessage(ctx, "", "") + }}, + {"ModifyMessage", func() (any, error) { + return client.ModifyMessage(ctx, "", "", nil) + }}, + {"RetrieveMessageFile", func() (any, error) { + return client.RetrieveMessageFile(ctx, "", "", "") + }}, + {"ListMessageFiles", func() (any, error) { + return client.ListMessageFiles(ctx, "", "") + }}, {"CreateThread", func() (any, error) { return client.CreateThread(ctx, ThreadRequest{}) }}, diff --git a/messages.go b/messages.go new file mode 100644 index 000000000..4e691a8ba --- /dev/null +++ b/messages.go @@ -0,0 +1,178 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + messagesSuffix = "messages" +) + +type Message struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + ThreadID string `json:"thread_id"` + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIds []string `json:"file_ids"` + AssistantID *string `json:"assistant_id,omitempty"` + RunID *string `json:"run_id,omitempty"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type MessagesList struct { + Messages []Message `json:"data"` + + httpHeader +} + +type MessageContent struct { + Type string `json:"type"` + Text *MessageText `json:"text,omitempty"` + ImageFile *ImageFile `json:"image_file,omitempty"` +} +type MessageText struct { + Value string `json:"value"` + Annotations []any `json:"annotations"` +} + +type ImageFile struct { + FileID string `json:"file_id"` +} + +type MessageRequest struct { + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type MessageFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int `json:"created_at"` + MessageID string `json:"message_id"` + + httpHeader +} + +type MessageFilesList struct { + MessageFiles []MessageFile `json:"data"` + + httpHeader +} + +// CreateMessage creates a new message. +func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ListMessage fetches all messages in the thread. +func (c *Client) ListMessage(ctx context.Context, threadID string, + limit *int, + order *string, + after *string, + before *string, +) (messages MessagesList, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if order != nil { + urlValues.Add("order", *order) + } + if after != nil { + urlValues.Add("after", *after) + } + if before != nil { + urlValues.Add("before", *before) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &messages) + return +} + +// RetrieveMessage retrieves a Message. +func (c *Client) RetrieveMessage( + ctx context.Context, + threadID, messageID string, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// ModifyMessage modifies a message. +func (c *Client) ModifyMessage( + ctx context.Context, + threadID, messageID string, + metadata map[string]any, +) (msg Message, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(metadata), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &msg) + return +} + +// RetrieveMessageFile fetches a message file. +func (c *Client) RetrieveMessageFile( + ctx context.Context, + threadID, messageID, fileID string, +) (file MessageFile, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + +// ListMessageFiles fetches all files attached to a message. +func (c *Client) ListMessageFiles( + ctx context.Context, + threadID, messageID string, +) (files MessageFilesList, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + if err != nil { + return + } + + err = c.sendRequest(req, &files) + return +} diff --git a/messages_test.go b/messages_test.go new file mode 100644 index 000000000..282b1cc9d --- /dev/null +++ b/messages_test.go @@ -0,0 +1,235 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +var emptyStr = "" + +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFile{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 1699061776, + MessageID: messageID, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID+"/files", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.MessageFilesList{MessageFiles: []openai.MessageFile{{ + ID: fileID, + Object: "thread.message.file", + CreatedAt: 0, + MessageID: messageID, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages/"+messageID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + metadata := map[string]any{} + err := json.NewDecoder(r.Body).Decode(&metadata) + checks.NoError(t, err, "unable to decode metadata in modify message call") + + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: metadata, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal( + openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + server.RegisterHandler( + "/v1/threads/"+threadID+"/messages", + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + resBytes, _ := json.Marshal(openai.Message{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodGet: + resBytes, _ := json.Marshal(openai.MessagesList{ + Messages: []openai.Message{{ + ID: messageID, + Object: "thread.message", + CreatedAt: 1234567890, + ThreadID: threadID, + Role: "user", + Content: []openai.MessageContent{{ + Type: "text", + Text: &openai.MessageText{ + Value: "How does AI work?", + Annotations: nil, + }, + }}, + FileIds: nil, + AssistantID: &emptyStr, + RunID: &emptyStr, + Metadata: nil, + }}}) + fmt.Fprintln(w, string(resBytes)) + default: + t.Fatalf("unsupported messages http method: %s", r.Method) + } + }, + ) + + ctx := context.Background() + + // static assertion of return type + var msg openai.Message + msg, err := client.CreateMessage(ctx, threadID, openai.MessageRequest{ + Role: "user", + Content: "How does AI work?", + FileIds: nil, + Metadata: nil, + }) + checks.NoError(t, err, "CreateMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + var msgs openai.MessagesList + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + // with pagination options set + limit := 1 + order := "desc" + after := "obj_foo" + before := "obj_bar" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListMessages error") + if len(msgs.Messages) != 1 { + t.Fatalf("unexpected length of fetched messages") + } + + msg, err = client.RetrieveMessage(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessage error") + if msg.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + + msg, err = client.ModifyMessage(ctx, threadID, messageID, + map[string]any{ + "foo": "bar", + }) + checks.NoError(t, err, "ModifyMessage error") + if msg.Metadata["foo"] != "bar" { + t.Fatalf("expected message metadata to get modified") + } + + // message files + var msgFile openai.MessageFile + msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) + checks.NoError(t, err, "RetrieveMessageFile error") + if msgFile.ID != fileID { + t.Fatalf("unexpected message file id: '%s'", msgFile.ID) + } + + var msgFiles openai.MessageFilesList + msgFiles, err = client.ListMessageFiles(ctx, threadID, messageID) + checks.NoError(t, err, "RetrieveMessageFile error") + if len(msgFiles.MessageFiles) != 1 { + t.Fatalf("unexpected count of message files: %d", len(msgFiles.MessageFiles)) + } + if msgFiles.MessageFiles[0].ID != fileID { + t.Fatalf("unexpected message file id: '%s' in list message files", msgFiles.MessageFiles[0].ID) + } +} From 515de0219d3b4d30351d44d8a0f508599de6c053 Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Mon, 13 Nov 2023 09:35:34 -0500 Subject: [PATCH 071/242] feat: initial TTS support (#528) * feat: initial TTS support * chore: lint, omitempty * chore: dont use pointer in struct * fix: add mocked server tests to speech_test.go Co-authored-by: Lachlan Laycock * chore: update imports * chore: fix lint * chore: add an error check * chore: ignore lint * chore: add error checks in package * chore: add test * chore: fix test --------- Co-authored-by: Lachlan Laycock --- client_test.go | 3 ++ speech.go | 87 +++++++++++++++++++++++++++++++++++++ speech_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 speech.go create mode 100644 speech_test.go diff --git a/client_test.go b/client_test.go index 24cb5ffa7..1c9084585 100644 --- a/client_test.go +++ b/client_test.go @@ -358,6 +358,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ListRunSteps", func() (any, error) { return client.ListRunSteps(ctx, "", "", Pagination{}) }}, + {"CreateSpeech", func() (any, error) { + return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) + }}, } for _, testCase := range testCases { diff --git a/speech.go b/speech.go new file mode 100644 index 000000000..a3d5f5dca --- /dev/null +++ b/speech.go @@ -0,0 +1,87 @@ +package openai + +import ( + "context" + "errors" + "io" + "net/http" +) + +type SpeechModel string + +const ( + TTSModel1 SpeechModel = "tts-1" + TTsModel1HD SpeechModel = "tts-1-hd" +) + +type SpeechVoice string + +const ( + VoiceAlloy SpeechVoice = "alloy" + VoiceEcho SpeechVoice = "echo" + VoiceFable SpeechVoice = "fable" + VoiceOnyx SpeechVoice = "onyx" + VoiceNova SpeechVoice = "nova" + VoiceShimmer SpeechVoice = "shimmer" +) + +type SpeechResponseFormat string + +const ( + SpeechResponseFormatMp3 SpeechResponseFormat = "mp3" + SpeechResponseFormatOpus SpeechResponseFormat = "opus" + SpeechResponseFormatAac SpeechResponseFormat = "aac" + SpeechResponseFormatFlac SpeechResponseFormat = "flac" +) + +var ( + ErrInvalidSpeechModel = errors.New("invalid speech model") + ErrInvalidVoice = errors.New("invalid voice") +) + +type CreateSpeechRequest struct { + Model SpeechModel `json:"model"` + Input string `json:"input"` + Voice SpeechVoice `json:"voice"` + ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 + Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 +} + +func contains[T comparable](s []T, e T) bool { + for _, v := range s { + if v == e { + return true + } + } + return false +} + +func isValidSpeechModel(model SpeechModel) bool { + return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) +} + +func isValidVoice(voice SpeechVoice) bool { + return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) +} + +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { + if !isValidSpeechModel(request.Model) { + err = ErrInvalidSpeechModel + return + } + if !isValidVoice(request.Voice) { + err = ErrInvalidVoice + return + } + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + withBody(request), + withContentType("application/json; charset=utf-8"), + ) + if err != nil { + return + } + + response, err = c.sendRequestRaw(req) + + return +} diff --git a/speech_test.go b/speech_test.go new file mode 100644 index 000000000..d9ba58b13 --- /dev/null +++ b/speech_test.go @@ -0,0 +1,115 @@ +package openai_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "os" + "path/filepath" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestSpeechIntegration(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { + dir, cleanup := test.CreateTestDirectory(t) + path := filepath.Join(dir, "fake.mp3") + test.CreateTestFile(t, path) + defer cleanup() + + // audio endpoints only accept POST requests + if r.Method != "POST" { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, "failed to parse media type", http.StatusBadRequest) + return + } + + if mediaType != "application/json" { + http.Error(w, "request is not json", http.StatusBadRequest) + return + } + + // Parse the JSON body of the request + var params map[string]interface{} + err = json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + http.Error(w, "failed to parse request body", http.StatusBadRequest) + return + } + + // Check if each required field is present in the parsed JSON object + reqParams := []string{"model", "input", "voice"} + for _, param := range reqParams { + _, ok := params[param] + if !ok { + http.Error(w, fmt.Sprintf("no %s in params", param), http.StatusBadRequest) + return + } + } + + // read audio file content + audioFile, err := os.ReadFile(path) + if err != nil { + http.Error(w, "failed to read audio file", http.StatusInternalServerError) + return + } + + // write audio file content to response + w.Header().Set("Content-Type", "audio/mpeg") + w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Connection", "keep-alive") + _, err = w.Write(audioFile) + if err != nil { + http.Error(w, "failed to write body", http.StatusInternalServerError) + return + } + }) + + t.Run("happy path", func(t *testing.T) { + res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.NoError(t, err, "CreateSpeech error") + defer res.Close() + + buf, err := io.ReadAll(res) + checks.NoError(t, err, "ReadAll error") + + // save buf to file as mp3 + err = os.WriteFile("test.mp3", buf, 0644) + checks.NoError(t, err, "Create error") + }) + t.Run("invalid model", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: "invalid_model", + Input: "Hello!", + Voice: openai.VoiceAlloy, + }) + checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") + }) + + t.Run("invalid voice", func(t *testing.T) { + _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ + Model: openai.TTSModel1, + Input: "Hello!", + Voice: "invalid_voice", + }) + checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") + }) +} From fe67abb97ed472bad359cc606c2d63289277cabf Mon Sep 17 00:00:00 2001 From: Donnie Flood Date: Wed, 15 Nov 2023 09:06:57 -0700 Subject: [PATCH 072/242] fix: add beta assistant header to CreateMessage call (#566) --- messages.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/messages.go b/messages.go index 4e691a8ba..3fd377fcb 100644 --- a/messages.go +++ b/messages.go @@ -71,7 +71,7 @@ type MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) if err != nil { return } From 71848ccf6928157d1487c5bbd5029ceaf3af53ed Mon Sep 17 00:00:00 2001 From: Donnie Flood Date: Wed, 15 Nov 2023 09:08:48 -0700 Subject: [PATCH 073/242] feat: support direct bytes for file upload (#568) * feat: support direct bytes for file upload * add test for errors * add coverage --- client_test.go | 3 +++ files.go | 49 ++++++++++++++++++++++++++++++++++++++++++++++ files_api_test.go | 13 ++++++++++++ files_test.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+) diff --git a/client_test.go b/client_test.go index 1c9084585..664f9fb92 100644 --- a/client_test.go +++ b/client_test.go @@ -247,6 +247,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateImage", func() (any, error) { return client.CreateImage(ctx, ImageRequest{}) }}, + {"CreateFileBytes", func() (any, error) { + return client.CreateFileBytes(ctx, FileBytesRequest{}) + }}, {"DeleteFile", func() (any, error) { return nil, client.DeleteFile(ctx, "") }}, diff --git a/files.go b/files.go index 9e521fbbe..371d06c69 100644 --- a/files.go +++ b/files.go @@ -15,6 +15,24 @@ type FileRequest struct { Purpose string `json:"purpose"` } +// PurposeType represents the purpose of the file when uploading. +type PurposeType string + +const ( + PurposeFineTune PurposeType = "fine-tune" + PurposeAssistants PurposeType = "assistants" +) + +// FileBytesRequest represents a file upload request. +type FileBytesRequest struct { + // the name of the uploaded file in OpenAI + Name string + // the bytes of the file + Bytes []byte + // the purpose of the file + Purpose PurposeType +} + // File struct represents an OpenAPI file. type File struct { Bytes int `json:"bytes"` @@ -36,6 +54,37 @@ type FilesList struct { httpHeader } +// CreateFileBytes uploads bytes directly to OpenAI without requiring a local file. +func (c *Client) CreateFileBytes(ctx context.Context, request FileBytesRequest) (file File, err error) { + var b bytes.Buffer + reader := bytes.NewReader(request.Bytes) + builder := c.createFormBuilder(&b) + + err = builder.WriteField("purpose", string(request.Purpose)) + if err != nil { + return + } + + err = builder.CreateFormFileReader("file", reader, request.Name) + if err != nil { + return + } + + err = builder.Close() + if err != nil { + return + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/files"), + withBody(&b), withContentType(builder.FormDataContentType())) + if err != nil { + return + } + + err = c.sendRequest(req, &file) + return +} + // CreateFile uploads a jsonl file to GPT3 // FilePath must be a local file path. func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { diff --git a/files_api_test.go b/files_api_test.go index 330b88159..6f62a3fbc 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -16,6 +16,19 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUpload(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: openai.PurposeFineTune, + } + _, err := client.CreateFileBytes(context.Background(), req) + checks.NoError(t, err, "CreateFile error") +} + func TestFileUpload(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/files_test.go b/files_test.go index f588b30dc..3c1b99fb4 100644 --- a/files_test.go +++ b/files_test.go @@ -11,6 +11,53 @@ import ( "github.com/sashabaranov/go-openai/internal/test/checks" ) +func TestFileBytesUploadWithFailingFormBuilder(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + mockBuilder := &mockFormBuilder{} + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return mockBuilder + } + + ctx := context.Background() + req := FileBytesRequest{ + Name: "foo", + Bytes: []byte("foo"), + Purpose: PurposeAssistants, + } + + mockError := fmt.Errorf("mockWriteField error") + mockBuilder.mockWriteField = func(string, string) error { + return mockError + } + _, err := client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockCreateFormFile error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") + + mockError = fmt.Errorf("mockClose error") + mockBuilder.mockWriteField = func(string, string) error { + return nil + } + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { + return nil + } + mockBuilder.mockClose = func() error { + return mockError + } + _, err = client.CreateFileBytes(ctx, req) + checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") +} + func TestFileUploadWithFailingFormBuilder(t *testing.T) { config := DefaultConfig("") config.BaseURL = "" @@ -55,6 +102,9 @@ func TestFileUploadWithFailingFormBuilder(t *testing.T) { return mockError } _, err = client.CreateFile(ctx, req) + if err == nil { + t.Fatal("CreateFile should return error if form builder fails") + } checks.ErrorIs(t, err, mockError, "CreateFile should return error if form builder fails") } From 464b85b6d766a53c922a15dd1138570e31ec661b Mon Sep 17 00:00:00 2001 From: Liron Levin Date: Wed, 15 Nov 2023 18:22:39 +0200 Subject: [PATCH 074/242] Pagination fields are missing from assistants list beta API (#571) curl "https://api.openai.com/v1/assistants?order=desc&limit=20" \ -H "Content-Type: application/json" \ -H "Authorization: Bearer $OPENAI_API_KEY" \ -H "OpenAI-Beta: assistants=v1" { "object": "list", "data": [], "first_id": null, "last_id": null, "has_more": false } --- assistant.go | 4 +++- assistant_test.go | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/assistant.go b/assistant.go index de49be680..59f78284f 100644 --- a/assistant.go +++ b/assistant.go @@ -52,7 +52,9 @@ type AssistantRequest struct { // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` - + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` httpHeader } diff --git a/assistant_test.go b/assistant_test.go index eb6f42458..30daec2b1 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -142,6 +142,8 @@ When asked a question, write and run Python code to answer the question.` fmt.Fprintln(w, string(resBytes)) } else if r.Method == http.MethodGet { resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, Assistants: []openai.Assistant{ { ID: assistantID, From 3220f19ee209de5e4bbc6db44261adcd4bbf1df1 Mon Sep 17 00:00:00 2001 From: Ccheers <1048315650@qq.com> Date: Thu, 16 Nov 2023 00:23:41 +0800 Subject: [PATCH 075/242] feat(runapi): add RunStepList response args https://platform.openai.com/docs/api-reference/runs/listRunSteps (#573) --- run.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/run.go b/run.go index 7ff730fea..f95bf0e35 100644 --- a/run.go +++ b/run.go @@ -157,6 +157,10 @@ type StepDetailsToolCalls struct { type RunStepList struct { RunSteps []RunStep `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` + httpHeader } From 18465723f7d96587045ce0a450d6874128b870cd Mon Sep 17 00:00:00 2001 From: Charlie Revett <2796074+revett@users.noreply.github.com> Date: Wed, 15 Nov 2023 16:25:18 +0000 Subject: [PATCH 076/242] Add missing struct properties. (#579) --- assistant.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/assistant.go b/assistant.go index 59f78284f..bd335833a 100644 --- a/assistant.go +++ b/assistant.go @@ -22,6 +22,8 @@ type Assistant struct { Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` Tools []AssistantTool `json:"tools,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` httpHeader } From 4fd904c2927c421cdbff89249979bc6a8a371d11 Mon Sep 17 00:00:00 2001 From: Charlie Revett <2796074+revett@users.noreply.github.com> Date: Sat, 18 Nov 2023 06:55:58 +0000 Subject: [PATCH 077/242] Add File purposes as constants (#577) * Add purposes. * Formatting. --- files.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/files.go b/files.go index 371d06c69..a37d45f18 100644 --- a/files.go +++ b/files.go @@ -19,8 +19,10 @@ type FileRequest struct { type PurposeType string const ( - PurposeFineTune PurposeType = "fine-tune" - PurposeAssistants PurposeType = "assistants" + PurposeFineTune PurposeType = "fine-tune" + PurposeFineTuneResults PurposeType = "fine-tune-results" + PurposeAssistants PurposeType = "assistants" + PurposeAssistantsOutput PurposeType = "assistants_output" ) // FileBytesRequest represents a file upload request. From 9efad284d02d90b2de3eeefc67a966743e47a2ac Mon Sep 17 00:00:00 2001 From: Albert Putra Purnama <14824254+albertpurnama@users.noreply.github.com> Date: Fri, 17 Nov 2023 22:59:01 -0800 Subject: [PATCH 078/242] Updates the tool call struct (#595) --- run.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/run.go b/run.go index f95bf0e35..dbb708a13 100644 --- a/run.go +++ b/run.go @@ -142,17 +142,13 @@ const ( type StepDetails struct { Type RunStepType `json:"type"` MessageCreation *StepDetailsMessageCreation `json:"message_creation,omitempty"` - ToolCalls *StepDetailsToolCalls `json:"tool_calls,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type StepDetailsMessageCreation struct { MessageID string `json:"message_id"` } -type StepDetailsToolCalls struct { - ToolCalls []ToolCall `json:"tool_calls"` -} - // RunStepList is a list of steps. type RunStepList struct { RunSteps []RunStep `json:"data"` From a130cfee26427b99ae0bf957be74e32ca8a7f567 Mon Sep 17 00:00:00 2001 From: Albert Putra Purnama <14824254+albertpurnama@users.noreply.github.com> Date: Fri, 17 Nov 2023 23:01:06 -0800 Subject: [PATCH 079/242] Add missing response fields for pagination (#584) --- messages.go | 5 +++++ messages_test.go | 7 ++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/messages.go b/messages.go index 3fd377fcb..ead247f5b 100644 --- a/messages.go +++ b/messages.go @@ -29,6 +29,11 @@ type Message struct { type MessagesList struct { Messages []Message `json:"data"` + Object string `json:"object"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` + httpHeader } diff --git a/messages_test.go b/messages_test.go index 282b1cc9d..9168d6ccf 100644 --- a/messages_test.go +++ b/messages_test.go @@ -142,6 +142,7 @@ func TestMessages(t *testing.T) { fmt.Fprintln(w, string(resBytes)) case http.MethodGet: resBytes, _ := json.Marshal(openai.MessagesList{ + Object: "list", Messages: []openai.Message{{ ID: messageID, Object: "thread.message", @@ -159,7 +160,11 @@ func TestMessages(t *testing.T) { AssistantID: &emptyStr, RunID: &emptyStr, Metadata: nil, - }}}) + }}, + FirstID: &messageID, + LastID: &messageID, + HasMore: false, + }) fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) From f87909596f8b0d293142ca00c4d4adc872c52ded Mon Sep 17 00:00:00 2001 From: pjuhasz Date: Fri, 24 Nov 2023 07:34:25 +0000 Subject: [PATCH 080/242] Add canary-tts to speech models (#603) Co-authored-by: Peter Juhasz --- speech.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/speech.go b/speech.go index a3d5f5dca..f2442b921 100644 --- a/speech.go +++ b/speech.go @@ -10,8 +10,9 @@ import ( type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTsModel1HD SpeechModel = "tts-1-hd" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" ) type SpeechVoice string @@ -57,7 +58,7 @@ func contains[T comparable](s []T, e T) bool { } func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTsModel1HD}, model) + return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) } func isValidVoice(voice SpeechVoice) bool { From 726099132704fd5ebc1680166f45bbd280bdb546 Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 24 Nov 2023 13:36:10 +0400 Subject: [PATCH 081/242] Update PULL_REQUEST_TEMPLATE.md (#606) --- .github/PULL_REQUEST_TEMPLATE.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 44bf697ed..222c065ce 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,11 +8,14 @@ Thanks for submitting a pull request! Please provide enough information so that **Describe the change** Please provide a clear and concise description of the changes you're proposing. Explain what problem it solves or what feature it adds. +**Provide OpenAI documentation link** +Provide a relevant API doc from https://platform.openai.com/docs/api-reference + **Describe your solution** Describe how your changes address the problem or how they add the feature. This should include a brief description of your approach and any new libraries or dependencies you're using. **Tests** -Briefly describe how you have tested these changes. +Briefly describe how you have tested these changes. If possible — please add integration tests. **Additional context** Add any other context or screenshots or logs about your pull request here. If the pull request relates to an open issue, please link to it. From 03caea89b75c4e6a5ac32f6e60e69e309d852e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rados=C5=82aw=20Kintzi?= Date: Fri, 24 Nov 2023 13:17:00 +0000 Subject: [PATCH 082/242] Add support for multi part chat messages (and gpt-4-vision-preview model) (#580) * Add support for multi part chat messages OpenAI has recently introduced a new model called gpt-4-visual-preview, which now supports images as input. The chat completion endpoint accepts multi-part chat messages, where the content can be an array of structs in addition to the usual string format. This commit introduces new structures and constants to represent different types of content parts. It also implements the json.Marshaler and json.Unmarshaler interfaces on ChatCompletionMessage. * Add ImageURLDetail and ChatMessagePartType types * Optimize ChatCompletionMessage deserialization * Add ErrContentFieldsMisused error --- chat.go | 91 ++++++++++++++++++++++++++++++++++++++++++++- chat_test.go | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index ebdc0e24b..5b87b6bd7 100644 --- a/chat.go +++ b/chat.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "errors" "net/http" ) @@ -20,6 +21,7 @@ const chatCompletionsSuffix = "/chat/completions" var ( ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll + ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously") ) type Hate struct { @@ -51,9 +53,36 @@ type PromptAnnotation struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type ImageURLDetail string + +const ( + ImageURLDetailHigh ImageURLDetail = "high" + ImageURLDetailLow ImageURLDetail = "low" + ImageURLDetailAuto ImageURLDetail = "auto" +) + +type ChatMessageImageURL struct { + URL string `json:"url,omitempty"` + Detail ImageURLDetail `json:"detail,omitempty"` +} + +type ChatMessagePartType string + +const ( + ChatMessagePartTypeText ChatMessagePartType = "text" + ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" +) + +type ChatMessagePart struct { + Type ChatMessagePartType `json:"type,omitempty"` + Text string `json:"text,omitempty"` + ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` +} + type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in // the documentation for the official library for python: @@ -70,6 +99,64 @@ type ChatCompletionMessage struct { ToolCallID string `json:"tool_call_id,omitempty"` } +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + if m.Content != "" && m.MultiContent != nil { + return nil, ErrContentFieldsMisused + } + if len(m.MultiContent) > 0 { + msg := struct { + Role string `json:"role"` + Content string `json:"-"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) + } + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }(m) + return json.Marshal(msg) +} + +func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { + msg := struct { + Role string `json:"role"` + Content string `json:"content"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &msg); err == nil { + *m = ChatCompletionMessage(msg) + return nil + } + multiMsg := struct { + Role string `json:"role"` + Content string + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + }{} + if err := json.Unmarshal(bs, &multiMsg); err != nil { + return err + } + *m = ChatCompletionMessage(multiMsg) + return nil +} + type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` diff --git a/chat_test.go b/chat_test.go index 8377809da..520bf5ca4 100644 --- a/chat_test.go +++ b/chat_test.go @@ -3,6 +3,7 @@ package openai_test import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateAzureChatCompletion error") } +func TestMultipartChatCompletions(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint) + + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: "Hello!", + }, + { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: "URL", + Detail: openai.ImageURLDetailLow, + }, + }, + }, + }, + }, + }) + checks.NoError(t, err, "CreateAzureChatCompletion error") +} + +func TestMultipartChatMessageSerialization(t *testing.T) { + jsonText := `[{"role":"system","content":"system-message"},` + + `{"role":"user","content":[{"type":"text","text":"nice-text"},` + + `{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]` + + var msgs []openai.ChatCompletionMessage + err := json.Unmarshal([]byte(jsonText), &msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + if len(msgs) != 2 { + t.Errorf("unexpected number of messages") + } + if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil { + t.Errorf("invalid user message: %v", msgs[0]) + } + if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 { + t.Errorf("invalid user message") + } + parts := msgs[1].MultiContent + if parts[0].Type != "text" || parts[0].Text != "nice-text" { + t.Errorf("invalid text part: %v", parts[0]) + } + if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" { + t.Errorf("invalid image_url part") + } + + s, err := json.Marshal(msgs) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + res := strings.ReplaceAll(string(s), " ", "") + if res != jsonText { + t.Fatalf("invalid message: %s", string(s)) + } + + invalidMsg := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "some-text", + MultiContent: []openai.ChatMessagePart{ + { + Type: "text", + Text: "nice-text", + }, + }, + }, + } + _, err = json.Marshal(invalidMsg) + if !errors.Is(err, openai.ErrContentFieldsMisused) { + t.Fatalf("Expected error: %s", err) + } + + err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs) + if err == nil { + t.Fatalf("Expected error") + } + + emptyMultiContentMsg := openai.ChatCompletionMessage{ + Role: "user", + MultiContent: []openai.ChatMessagePart{}, + } + s, err = json.Marshal(emptyMultiContentMsg) + if err != nil { + t.Fatalf("Unexpected error") + } + res = strings.ReplaceAll(string(s), " ", "") + if res != `{"role":"user","content":""}` { + t.Fatalf("invalid message: %s", string(s)) + } +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error From a09cb0c528c110a6955a9ee9a5d021a57ed44b90 Mon Sep 17 00:00:00 2001 From: mikeb26 <83850730+mikeb26@users.noreply.github.com> Date: Sun, 26 Nov 2023 08:45:28 +0000 Subject: [PATCH 083/242] Add completion-with-tool example (#598) As a user of this go SDK it was not immediately intuitive to me how to correctly utilize the function calling capability of GPT4 (https://platform.openai.com/docs/guides/function-calling). While the aformentioned link provides a helpful example written in python, I initially tripped over how to correclty translate the specification of function arguments when usingthis go SDK. To make it easier for others in the future this commit adds a completion-with-tool example showing how to correctly utilize the function calling capability of GPT4 using this SDK end-to-end in a CreateChatCompletion() sequence. --- examples/completion-with-tool/main.go | 94 +++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) create mode 100644 examples/completion-with-tool/main.go diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go new file mode 100644 index 000000000..2c7fedc5e --- /dev/null +++ b/examples/completion-with-tool/main.go @@ -0,0 +1,94 @@ +package main + +import ( + "context" + "fmt" + "os" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + ctx := context.Background() + client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) + + // describe the function & its inputs + params := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + } + f := openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather in a given location", + Parameters: params, + } + t := openai.Tool{ + Type: openai.ToolTypeFunction, + Function: f, + } + + // simulate user asking a question that requires the function + dialogue := []openai.ChatCompletionMessage{ + {Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, + } + fmt.Printf("Asking OpenAI '%v' and providing it a '%v()' function...\n", + dialogue[0].Content, f.Name) + resp, err := client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("Completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + msg := resp.Choices[0].Message + if len(msg.ToolCalls) != 1 { + fmt.Printf("Completion error: len(toolcalls): %v\n", len(msg.ToolCalls)) + return + } + + // simulate calling the function & responding to OpenAI + dialogue = append(dialogue, msg) + fmt.Printf("OpenAI called us back wanting to invoke our function '%v' with params '%v'\n", + msg.ToolCalls[0].Function.Name, msg.ToolCalls[0].Function.Arguments) + dialogue = append(dialogue, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleTool, + Content: "Sunny and 80 degrees.", + Name: msg.ToolCalls[0].Function.Name, + ToolCallID: msg.ToolCalls[0].ID, + }) + fmt.Printf("Sending OpenAI our '%v()' function's response and requesting the reply to the original question...\n", + f.Name) + resp, err = client.CreateChatCompletion(ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: dialogue, + Tools: []openai.Tool{t}, + }, + ) + if err != nil || len(resp.Choices) != 1 { + fmt.Printf("2nd completion error: err:%v len(choices):%v\n", err, + len(resp.Choices)) + return + } + + // display OpenAI's response to the original question utilizing our function + msg = resp.Choices[0].Message + fmt.Printf("OpenAI answered the original request with: %v\n", + msg.Content) +} From c9615e0cbe3b68088ee04221acdfde63d6d20766 Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Wed, 3 Jan 2024 19:42:57 +0800 Subject: [PATCH 084/242] Added support for createImage Azure models (#608) --- image.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/image.go b/image.go index 4fe8b3a32..afd4e196b 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,7 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return } From f10955ce090c7b0d8f38458c753c01cd9b88aca5 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Tue, 9 Jan 2024 16:50:56 +0000 Subject: [PATCH 085/242] Log probabilities for chat completion output tokens (#625) * Add logprobs * Logprobs pointer * Move toplogporbs * Create toplogprobs struct * Remove pointers --- chat.go | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 5b87b6bd7..33b8755ce 100644 --- a/chat.go +++ b/chat.go @@ -200,7 +200,15 @@ type ChatCompletionRequest struct { // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + // LogProbs indicates whether to return log probabilities of the output tokens or not. + // If true, returns the log probabilities of each output token returned in the content of message. + // This option is currently not available on the gpt-4-vision-preview model. + LogProbs bool `json:"logprobs,omitempty"` + // TopLogProbs is an integer between 0 and 5 specifying the number of most likely tokens to return at each + // token position, each with an associated log probability. + // logprobs must be set to true if this parameter is used. + TopLogProbs int `json:"top_logprobs,omitempty"` + User string `json:"user,omitempty"` // Deprecated: use Tools instead. Functions []FunctionDefinition `json:"functions,omitempty"` // Deprecated: use ToolChoice instead. @@ -244,6 +252,28 @@ type FunctionDefinition struct { // Deprecated: use FunctionDefinition instead. type FunctionDefine = FunctionDefinition +type TopLogProbs struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` +} + +// LogProb represents the probability information for a token. +type LogProb struct { + Token string `json:"token"` + LogProb float64 `json:"logprob"` + Bytes []byte `json:"bytes,omitempty"` // Omitting the field if it is null + // TopLogProbs is a list of the most likely tokens and their log probability, at this token position. + // In rare cases, there may be fewer than the number of requested top_logprobs returned. + TopLogProbs []TopLogProbs `json:"top_logprobs"` +} + +// LogProbs is the top-level structure containing the log probability information. +type LogProbs struct { + // Content is a list of message content tokens with log probability information. + Content []LogProb `json:"content"` +} + type FinishReason string const ( @@ -273,6 +303,7 @@ type ChatCompletionChoice struct { // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. From 682b7adb0bd645f290031fbca6028feb5c22ab9c Mon Sep 17 00:00:00 2001 From: Alexander Kledal Date: Thu, 11 Jan 2024 11:45:15 +0100 Subject: [PATCH 086/242] Update README.md (#631) Ensure variables in examples are valid --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4cb77db6b..9a479c0a0 100644 --- a/README.md +++ b/README.md @@ -453,7 +453,7 @@ func main() { config := openai.DefaultAzureConfig("your Azure OpenAI Key", "https://your Azure OpenAI Endpoint") // If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function // config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ + // azureModelMapping := map[string]string{ // "gpt-3.5-turbo": "your gpt-3.5-turbo deployment name", // } // return azureModelMapping[model] @@ -559,7 +559,7 @@ func main() { //If you use a deployment name different from the model name, you can customize the AzureModelMapperFunc function //config.AzureModelMapperFunc = func(model string) string { - // azureModelMapping = map[string]string{ + // azureModelMapping := map[string]string{ // "gpt-3.5-turbo":"your gpt-3.5-turbo deployment name", // } // return azureModelMapping[model] From e01a2d7231fafec2c1cbdd176806e3be767df965 Mon Sep 17 00:00:00 2001 From: Matthew Jaffee Date: Mon, 15 Jan 2024 03:33:02 -0600 Subject: [PATCH 087/242] convert EmbeddingModel to string type (#629) This gives the user the ability to pass in models for embeddings that are not already defined in the library. Also more closely matches how the completions API works. --- embeddings.go | 120 ++++++++------------------------------------- embeddings_test.go | 22 ++------- 2 files changed, 24 insertions(+), 118 deletions(-) diff --git a/embeddings.go b/embeddings.go index 7e2aa7eb0..f79df9df5 100644 --- a/embeddings.go +++ b/embeddings.go @@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} +type EmbeddingModel string const ( - Unknown EmbeddingModel = iota - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchText - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchText - AdaEmbeddingV2 + // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + AdaSimilarity EmbeddingModel = "text-similarity-ada-001" + BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" + CurieSimilarity EmbeddingModel = "text-similarity-curie-001" + DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001" + AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001" + AdaSearchQuery EmbeddingModel = "text-search-ada-query-001" + BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001" + BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001" + CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001" + CurieSearchQuery EmbeddingModel = "text-search-curie-query-001" + DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001" + DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001" + AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001" + AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001" + BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" + BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" + + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" ) -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - "text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} - // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, @@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index af04d96bf..846d1995d 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) { // the AdaSearchQuery type marshaled, err := json.Marshal(embeddingReq) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqStrings) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqTokens) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } } } -func TestEmbeddingModel(t *testing.T) { - var em openai.EmbeddingModel - err := em.UnmarshalText([]byte("text-similarity-ada-001")) - checks.NoError(t, err, "Could not marshal embedding model") - - if em != openai.AdaSimilarity { - t.Errorf("Model is not equal to AdaSimilarity") - } - - err = em.UnmarshalText([]byte("some-non-existent-model")) - checks.NoError(t, err, "Could not marshal embedding model") - if em != openai.Unknown { - t.Errorf("Model is not equal to Unknown") - } -} - func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() From 09f6920ad04666f65dd86ed542e5ebf8bffc93a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E5=AE=8F=E6=95=8F?= Date: Mon, 15 Jan 2024 20:01:49 +0800 Subject: [PATCH 088/242] fixed #594 (#609) APITypeAzure dall-e3 model url Co-authored-by: HanHongmin --- image.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/image.go b/image.go index afd4e196b..665de1a74 100644 --- a/image.go +++ b/image.go @@ -82,6 +82,7 @@ type ImageEditRequest struct { Image *os.File `json:"image,omitempty"` Mask *os.File `json:"mask,omitempty"` Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -131,7 +132,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits"), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return @@ -144,6 +145,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. type ImageVariRequest struct { Image *os.File `json:"image,omitempty"` + Model string `json:"model,omitempty"` N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` @@ -181,7 +183,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations"), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), withBody(body), withContentType(builder.FormDataContentType())) if err != nil { return From 4ce03a919ae9fdcb62e8098a03500ef77eafe348 Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Tue, 16 Jan 2024 04:32:48 -0500 Subject: [PATCH 089/242] Fix Azure embeddings model detection by passing string to `fullURL` (#637) --- embeddings.go | 2 +- embeddings_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/embeddings.go b/embeddings.go index f79df9df5..c144119f8 100644 --- a/embeddings.go +++ b/embeddings.go @@ -228,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index 846d1995d..ed6384f3f 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -158,6 +158,32 @@ func TestEmbeddingEndpoint(t *testing.T) { checks.HasError(t, err, "CreateEmbeddings error") } +func TestAzureEmbeddingEndpoint(t *testing.T) { + client, server, teardown := setupAzureTestServer() + defer teardown() + + sampleEmbeddings := []openai.Embedding{ + {Embedding: []float32{1.23, 4.56, 7.89}}, + {Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}}, + } + + server.RegisterHandler( + "/openai/deployments/text-embedding-ada-002/embeddings", + func(w http.ResponseWriter, r *http.Request) { + resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) + fmt.Fprintln(w, string(resBytes)) + }, + ) + // test create embeddings with strings (simple embedding request) + res, err := client.CreateEmbeddings(context.Background(), openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + }) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } +} + func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) { type fields struct { Object string From eff8dc1118ea82a1b50ee316608e24d83df74d6b Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Thu, 18 Jan 2024 01:42:07 +0800 Subject: [PATCH 090/242] fix(audio): fix audioTextResponse decode (#638) * fix(audio): fix audioTextResponse decode * test(audio): add audioTextResponse decode test * test(audio): simplify code --- client.go | 10 +++++++--- client_test.go | 48 ++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 056226c61..8bbbb875a 100644 --- a/client.go +++ b/client.go @@ -193,10 +193,14 @@ func decodeResponse(body io.Reader, v any) error { return nil } - if result, ok := v.(*string); ok { - return decodeString(body, result) + switch o := v.(type) { + case *string: + return decodeString(body, o) + case *audioTextResponse: + return decodeString(body, &o.Text) + default: + return json.NewDecoder(body).Decode(v) } - return json.NewDecoder(body).Decode(v) } func decodeString(body io.Reader, output *string) error { diff --git a/client_test.go b/client_test.go index 664f9fb92..bc5133edc 100644 --- a/client_test.go +++ b/client_test.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "net/http" + "reflect" "testing" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) var errTestRequestBuilderFailed = errors.New("test request builder failed") @@ -43,23 +45,29 @@ func TestDecodeResponse(t *testing.T) { testCases := []struct { name string value interface{} + expected interface{} body io.Reader hasError bool }{ { - name: "nil input", - value: nil, - body: bytes.NewReader([]byte("")), + name: "nil input", + value: nil, + body: bytes.NewReader([]byte("")), + expected: nil, }, { - name: "string input", - value: &stringInput, - body: bytes.NewReader([]byte("test")), + name: "string input", + value: &stringInput, + body: bytes.NewReader([]byte("test")), + expected: "test", }, { name: "map input", value: &map[string]interface{}{}, body: bytes.NewReader([]byte(`{"test": "test"}`)), + expected: map[string]interface{}{ + "test": "test", + }, }, { name: "reader return error", @@ -67,14 +75,38 @@ func TestDecodeResponse(t *testing.T) { body: &errorReader{err: errors.New("dummy")}, hasError: true, }, + { + name: "audio text input", + value: &audioTextResponse{}, + body: bytes.NewReader([]byte("test")), + expected: audioTextResponse{ + Text: "test", + }, + }, + } + + assertEqual := func(t *testing.T, expected, actual interface{}) { + t.Helper() + if expected == actual { + return + } + v := reflect.ValueOf(actual).Elem().Interface() + if !reflect.DeepEqual(v, expected) { + t.Fatalf("Unexpected value: %v, expected: %v", v, expected) + } } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err := decodeResponse(tc.body, tc.value) - if (err != nil) != tc.hasError { - t.Errorf("Unexpected error: %v", err) + if tc.hasError { + checks.HasError(t, err, "Unexpected nil error") + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) } + assertEqual(t, tc.expected, tc.value) }) } } From 4c41f24a99ad56f707df7c25b8833fb0a374c8c5 Mon Sep 17 00:00:00 2001 From: Daniil <7709243+bazuker@users.noreply.github.com> Date: Fri, 26 Jan 2024 00:41:48 -0800 Subject: [PATCH 091/242] Support January 25, 2024, models update. (#644) --- completion.go | 6 +++++- embeddings.go | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 2709c8b03..6326a72a8 100644 --- a/completion.go +++ b/completion.go @@ -22,7 +22,9 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" - GPT4TurboPreview = "gpt-4-1106-preview" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" @@ -78,6 +80,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4TurboPreview: true, GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, GPT40314: true, GPT40613: true, GPT432K: true, diff --git a/embeddings.go b/embeddings.go index c144119f8..517027f5a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -34,7 +34,9 @@ const ( BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" - AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" + SmallEmbedding3 EmbeddingModel = "text-embedding-3-small" + LargeEmbedding3 EmbeddingModel = "text-embedding-3-large" ) // Embedding is a special format of data representation that can be easily utilized by machine From 06ff541559eaf66482a89202da946644b6c96510 Mon Sep 17 00:00:00 2001 From: chenhhA <463474838@qq.com> Date: Mon, 29 Jan 2024 15:09:56 +0800 Subject: [PATCH 092/242] Add new struct filed dimensions for embedding API (#645) * add new struct filed dimensions for embedding API * docs: remove long single-line comments * change embedding request param Dimensions type to int --- embeddings.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/embeddings.go b/embeddings.go index 517027f5a..c5633a313 100644 --- a/embeddings.go +++ b/embeddings.go @@ -157,6 +157,9 @@ type EmbeddingRequest struct { Model EmbeddingModel `json:"model"` User string `json:"user"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -181,6 +184,9 @@ type EmbeddingRequestStrings struct { // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. // If not specified OpenAI will use "float". EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -189,6 +195,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { Model: r.Model, User: r.User, EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, } } @@ -209,6 +216,9 @@ type EmbeddingRequestTokens struct { // Currently, only "float" and "base64" are supported, however, "base64" is not officially documented. // If not specified OpenAI will use "float". EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` + // Dimensions The number of dimensions the resulting output embeddings should have. + // Only supported in text-embedding-3 and later models. + Dimensions int `json:"dimensions,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -217,6 +227,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { Model: r.Model, User: r.User, EncodingFormat: r.EncodingFormat, + Dimensions: r.Dimensions, } } From bc8cdd33d158ea165fcecde4a64fc5f1580f0192 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Fri, 2 Feb 2024 18:30:24 +0800 Subject: [PATCH 093/242] add GPT3Dot5Turbo0125 model (#648) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 6326a72a8..ab1dbd6c5 100644 --- a/completion.go +++ b/completion.go @@ -27,6 +27,7 @@ const ( GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" @@ -75,6 +76,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, From bb6ed545306ba56b99d297a77da0a93b0bcfb80e Mon Sep 17 00:00:00 2001 From: shadowpigy <71599610+shadowpigy@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:41:39 +0800 Subject: [PATCH 094/242] Fix: Add RunStatusCancelled (#650) Co-authored-by: shadowpigy --- run.go | 1 + 1 file changed, 1 insertion(+) diff --git a/run.go b/run.go index dbb708a13..d06756572 100644 --- a/run.go +++ b/run.go @@ -40,6 +40,7 @@ const ( RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" RunStatusExpired RunStatus = "expired" + RunStatusCancelled RunStatus = "cancelled" ) type RunRequiredAction struct { From 69e3fcbc2726d208d34e9d89089b47ebebdff01b Mon Sep 17 00:00:00 2001 From: chrbsg <52408325+chrbsg@users.noreply.github.com> Date: Tue, 6 Feb 2024 19:04:40 +0000 Subject: [PATCH 095/242] Fix typo assitantInstructions (#655) --- assistant_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/assistant_test.go b/assistant_test.go index 30daec2b1..9e1e3f38d 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -17,7 +17,7 @@ func TestAssistant(t *testing.T) { assistantID := "asst_abc123" assistantName := "Ambrogio" assistantDescription := "Ambrogio is a friendly assistant." - assitantInstructions := `You are a personal math tutor. + assistantInstructions := `You are a personal math tutor. When asked a question, write and run Python code to answer the question.` assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" limit := 20 @@ -92,7 +92,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Model: openai.GPT4TurboPreview, Description: &assistantDescription, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: @@ -152,7 +152,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Model: openai.GPT4TurboPreview, Description: &assistantDescription, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }, }, }) @@ -167,7 +167,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Description: &assistantDescription, Model: openai.GPT4TurboPreview, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) checks.NoError(t, err, "CreateAssistant error") @@ -178,7 +178,7 @@ When asked a question, write and run Python code to answer the question.` Name: &assistantName, Description: &assistantDescription, Model: openai.GPT4TurboPreview, - Instructions: &assitantInstructions, + Instructions: &assistantInstructions, }) checks.NoError(t, err, "ModifyAssistant error") From 6c2e3162dfe3b32cbd1d026043957f8e589e987c Mon Sep 17 00:00:00 2001 From: "xuanming.zhang" Date: Thu, 8 Feb 2024 15:40:39 +0800 Subject: [PATCH 096/242] Added support for CreateSpeech Azure models (#657) --- speech.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/speech.go b/speech.go index f2442b921..b9344ac66 100644 --- a/speech.go +++ b/speech.go @@ -74,7 +74,7 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) err = ErrInvalidVoice return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", request.Model), + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json; charset=utf-8"), ) From a7954c854c89f45d3f5df62aab8df688b4c20b20 Mon Sep 17 00:00:00 2001 From: shadowpigy <71599610+shadowpigy@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:08:30 +0800 Subject: [PATCH 097/242] Feat: Add assistant usage (#649) * Feat: Add assistant usage --------- Co-authored-by: shadowpigy --- run.go | 1 + 1 file changed, 1 insertion(+) diff --git a/run.go b/run.go index d06756572..4befe0b44 100644 --- a/run.go +++ b/run.go @@ -26,6 +26,7 @@ type Run struct { Tools []Tool `json:"tools"` FileIDS []string `json:"file_ids"` Metadata map[string]any `json:"metadata"` + Usage Usage `json:"usage,omitempty"` httpHeader } From 11ad4b69d0f0dc61ed8777ac2d54a6787c8d2fea Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:02:48 +0400 Subject: [PATCH 098/242] make linter happy (#661) --- embeddings_test.go | 2 +- files_api_test.go | 10 +++++----- image_test.go | 10 +++++----- messages.go | 4 ++-- models_test.go | 2 +- run.go | 2 +- stream_test.go | 14 +++++++------- 7 files changed, 22 insertions(+), 22 deletions(-) diff --git a/embeddings_test.go b/embeddings_test.go index ed6384f3f..438978169 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -169,7 +169,7 @@ func TestAzureEmbeddingEndpoint(t *testing.T) { server.RegisterHandler( "/openai/deployments/text-embedding-ada-002/embeddings", - func(w http.ResponseWriter, r *http.Request) { + func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.EmbeddingResponse{Data: sampleEmbeddings}) fmt.Fprintln(w, string(resBytes)) }, diff --git a/files_api_test.go b/files_api_test.go index 6f62a3fbc..c92162a84 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -86,7 +86,7 @@ func handleCreateFile(w http.ResponseWriter, r *http.Request) { func TestDeleteFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {}) + server.RegisterHandler("/v1/files/deadbeef", func(http.ResponseWriter, *http.Request) {}) err := client.DeleteFile(context.Background(), "deadbeef") checks.NoError(t, err, "DeleteFile error") } @@ -94,7 +94,7 @@ func TestDeleteFile(t *testing.T) { func TestListFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.FilesList{}) fmt.Fprintln(w, string(resBytes)) }) @@ -105,7 +105,7 @@ func TestListFile(t *testing.T) { func TestGetFile(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, _ *http.Request) { resBytes, _ := json.Marshal(openai.File{}) fmt.Fprintln(w, string(resBytes)) }) @@ -151,7 +151,7 @@ func TestGetFileContentReturnError(t *testing.T) { }` client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) @@ -178,7 +178,7 @@ func TestGetFileContentReturnError(t *testing.T) { func TestGetFileContentReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/files/deadbeef/content", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() diff --git a/image_test.go b/image_test.go index 81fff6cba..9332dd5cd 100644 --- a/image_test.go +++ b/image_test.go @@ -60,7 +60,7 @@ func TestImageFormBuilderFailures(t *testing.T) { _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { if name == "mask" { return mockFailedErr } @@ -69,12 +69,12 @@ func TestImageFormBuilderFailures(t *testing.T) { _, err = client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(string, *os.File) error { return nil } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } @@ -125,12 +125,12 @@ func TestVariImageFormBuilderFailures(t *testing.T) { _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, file *os.File) error { + mockBuilder.mockCreateFormFile = func(string, *os.File) error { return nil } var failForField string - mockBuilder.mockWriteField = func(fieldname, value string) error { + mockBuilder.mockWriteField = func(fieldname, _ string) error { if fieldname == failForField { return mockFailedErr } diff --git a/messages.go b/messages.go index ead247f5b..861463235 100644 --- a/messages.go +++ b/messages.go @@ -18,7 +18,7 @@ type Message struct { ThreadID string `json:"thread_id"` Role string `json:"role"` Content []MessageContent `json:"content"` - FileIds []string `json:"file_ids"` + FileIds []string `json:"file_ids"` //nolint:revive //backwards-compatibility AssistantID *string `json:"assistant_id,omitempty"` RunID *string `json:"run_id,omitempty"` Metadata map[string]any `json:"metadata"` @@ -54,7 +54,7 @@ type ImageFile struct { type MessageRequest struct { Role string `json:"role"` Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility Metadata map[string]any `json:"metadata,omitempty"` } diff --git a/models_test.go b/models_test.go index 4a4c759dc..24a28ed23 100644 --- a/models_test.go +++ b/models_test.go @@ -64,7 +64,7 @@ func handleGetModelEndpoint(w http.ResponseWriter, _ *http.Request) { func TestGetModelReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/models/text-davinci-003", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/models/text-davinci-003", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() diff --git a/run.go b/run.go index 4befe0b44..ba09366cb 100644 --- a/run.go +++ b/run.go @@ -24,7 +24,7 @@ type Run struct { Model string `json:"model"` Instructions string `json:"instructions,omitempty"` Tools []Tool `json:"tools"` - FileIDS []string `json:"file_ids"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` diff --git a/stream_test.go b/stream_test.go index 35c52ae3b..2822a3535 100644 --- a/stream_test.go +++ b/stream_test.go @@ -34,7 +34,7 @@ func TestCompletionsStreamWrongModel(t *testing.T) { func TestCreateCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -106,7 +106,7 @@ func TestCreateCompletionStream(t *testing.T) { func TestCreateCompletionStreamError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -151,7 +151,7 @@ func TestCreateCompletionStreamError(t *testing.T) { func TestCreateCompletionStreamRateLimitError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(429) @@ -182,7 +182,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -228,7 +228,7 @@ func TestCreateCompletionStreamTooManyEmptyStreamMessagesError(t *testing.T) { func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -263,7 +263,7 @@ func TestCreateCompletionStreamUnexpectedTerminatedError(t *testing.T) { func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/event-stream") // Send test responses @@ -305,7 +305,7 @@ func TestCreateCompletionStreamBrokenJSONError(t *testing.T) { func TestCreateCompletionStreamReturnTimeoutError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() - server.RegisterHandler("/v1/completions", func(w http.ResponseWriter, r *http.Request) { + server.RegisterHandler("/v1/completions", func(http.ResponseWriter, *http.Request) { time.Sleep(10 * time.Nanosecond) }) ctx := context.Background() From 66bae3ee7329619b27ba8bcb185e0d333e9b3e26 Mon Sep 17 00:00:00 2001 From: grulex Date: Thu, 15 Feb 2024 16:11:58 +0000 Subject: [PATCH 099/242] Content-type fix (#659) * charset fixes * make linter happy (#661) --------- Co-authored-by: grulex Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- client.go | 4 ++-- speech.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 8bbbb875a..55c48bd47 100644 --- a/client.go +++ b/client.go @@ -107,13 +107,13 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... } func (c *Client) sendRequest(req *http.Request, v Response) error { - req.Header.Set("Accept", "application/json; charset=utf-8") + req.Header.Set("Accept", "application/json") // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data contentType := req.Header.Get("Content-Type") if contentType == "" { - req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("Content-Type", "application/json") } res, err := c.config.HTTPClient.Do(req) diff --git a/speech.go b/speech.go index b9344ac66..be8950218 100644 --- a/speech.go +++ b/speech.go @@ -76,7 +76,7 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), - withContentType("application/json; charset=utf-8"), + withContentType("application/json"), ) if err != nil { return From ff61bbb32253aad84c6cc96bf9be3884aa8cde88 Mon Sep 17 00:00:00 2001 From: chrbsg <52408325+chrbsg@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:12:22 +0000 Subject: [PATCH 100/242] Add RunRequest field AdditionalInstructions (#656) AdditionalInstructions is an optional string field used to append additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. Also, change the Model and Instructions *string fields to string. --- run.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/run.go b/run.go index ba09366cb..1f3cb7eb7 100644 --- a/run.go +++ b/run.go @@ -72,11 +72,12 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model *string `json:"model,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } type RunModifyRequest struct { From 69e3bbb1eb05a5c1b27a29fc9a83d02d0d040e27 Mon Sep 17 00:00:00 2001 From: Igor Berlenko Date: Fri, 16 Feb 2024 18:22:38 +0800 Subject: [PATCH 101/242] Update client.go - allow to skip Authorization header (#658) * Update client.go - allow to skip Authorization header * Update client.go --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 55c48bd47..7fdc36caa 100644 --- a/client.go +++ b/client.go @@ -175,7 +175,7 @@ func (c *Client) setCommonHeaders(req *http.Request) { // Azure API Key authentication if c.config.APIType == APITypeAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { + } else if c.config.authToken != "" { // OpenAI or Azure AD authentication req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) } From e8b347891b21187740d594409b1c11fb0846577e Mon Sep 17 00:00:00 2001 From: CaoPengFlying Date: Mon, 19 Feb 2024 20:26:04 +0800 Subject: [PATCH 102/242] fix:fix open ai original validation. modify Tool's Function to pointer (#664) Co-authored-by: caopengfei1 --- chat.go | 4 ++-- examples/completion-with-tool/main.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/chat.go b/chat.go index 33b8755ce..efb14fd4c 100644 --- a/chat.go +++ b/chat.go @@ -225,8 +225,8 @@ const ( ) type Tool struct { - Type ToolType `json:"type"` - Function FunctionDefinition `json:"function,omitempty"` + Type ToolType `json:"type"` + Function *FunctionDefinition `json:"function,omitempty"` } type ToolChoice struct { diff --git a/examples/completion-with-tool/main.go b/examples/completion-with-tool/main.go index 2c7fedc5e..26126e41b 100644 --- a/examples/completion-with-tool/main.go +++ b/examples/completion-with-tool/main.go @@ -35,7 +35,7 @@ func main() { } t := openai.Tool{ Type: openai.ToolTypeFunction, - Function: f, + Function: &f, } // simulate user asking a question that requires the function From 7381d18a75a673d569c7dc7657407381e5c84dd5 Mon Sep 17 00:00:00 2001 From: Rich Coggins <57115183+coggsflod@users.noreply.github.com> Date: Wed, 21 Feb 2024 07:45:15 -0500 Subject: [PATCH 103/242] Fix for broken Azure Assistants url (#665) * fix:fix url for Azure assistants api * test:add unit tests for Azure Assistants api * fix:minor liniting issue --- assistant_test.go | 190 ++++++++++++++++++++++++++++++++++++++++++++++ client.go | 2 +- 2 files changed, 191 insertions(+), 1 deletion(-) diff --git a/assistant_test.go b/assistant_test.go index 9e1e3f38d..48bc6f91d 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -202,3 +202,193 @@ When asked a question, write and run Python code to answer the question.` err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) checks.NoError(t, err, "DeleteAssistantFile error") } + +func TestAzureAssistant(t *testing.T) { + assistantID := "asst_abc123" + assistantName := "Ambrogio" + assistantDescription := "Ambrogio is a friendly assistant." + assistantInstructions := `You are a personal math tutor. +When asked a question, write and run Python code to answer the question.` + assistantFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + limit := 20 + order := "desc" + after := "asst_abc122" + before := "asst_abc124" + + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files/"+assistantFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "assistant.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantFilesList{ + AssistantFiles: []openai.AssistantFile{ + { + ID: assistantFileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.AssistantFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.AssistantFile{ + ID: request.FileID, + Object: "assistant.file", + CreatedAt: 1234567890, + AssistantID: assistantID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants/"+assistantID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/assistants", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.AssistantRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Assistant{ + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: request.Name, + Model: request.Model, + Description: request.Description, + Instructions: request.Instructions, + Tools: request.Tools, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.AssistantsList{ + LastID: &assistantID, + FirstID: &assistantID, + Assistants: []openai.Assistant{ + { + ID: assistantID, + Object: "assistant", + CreatedAt: 1234567890, + Name: &assistantName, + Model: openai.GPT4TurboPreview, + Description: &assistantDescription, + Instructions: &assistantInstructions, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") + + _, err = client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + + _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + _, err = client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") + + _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + + _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + + _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") + + _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + + err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") +} diff --git a/client.go b/client.go index 7fdc36caa..e7a4d5beb 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") { + if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" From c5401e9e6417ac2b5374993ccff1f40010e03f52 Mon Sep 17 00:00:00 2001 From: Rich Coggins <57115183+coggsflod@users.noreply.github.com> Date: Mon, 26 Feb 2024 03:46:35 -0500 Subject: [PATCH 104/242] Fix for broken Azure Threads url (#668) --- client.go | 11 ++++++- thread_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index e7a4d5beb..7b1a313a8 100644 --- a/client.go +++ b/client.go @@ -221,7 +221,7 @@ func (c *Client) fullURL(suffix string, args ...any) string { baseURL = strings.TrimRight(baseURL, "/") // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if strings.Contains(suffix, "/models") || strings.Contains(suffix, "/assistants") { + if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) } azureDeploymentName := "UNKNOWN" @@ -258,3 +258,12 @@ func (c *Client) handleErrorResp(resp *http.Response) error { errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } + +func containsSubstr(s []string, e string) bool { + for _, v := range s { + if strings.Contains(e, v) { + return true + } + } + return false +} diff --git a/thread_test.go b/thread_test.go index 227ab6330..1ac0f3c0e 100644 --- a/thread_test.go +++ b/thread_test.go @@ -93,3 +93,86 @@ func TestThread(t *testing.T) { _, err = client.DeleteThread(ctx, threadID) checks.NoError(t, err, "DeleteThread error") } + +// TestAzureThread Tests the thread endpoint of the API using the Azure mocked server. +func TestAzureThread(t *testing.T) { + threadID := "thread_abc123" + client, server, teardown := setupAzureTestServer() + defer teardown() + + server.RegisterHandler( + "/openai/threads/"+threadID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.ThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/openai/threads", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.ModifyThreadRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.Thread{ + ID: threadID, + Object: "thread", + CreatedAt: 1234567890, + Metadata: request.Metadata, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + _, err := client.CreateThread(ctx, openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }) + checks.NoError(t, err, "CreateThread error") + + _, err = client.RetrieveThread(ctx, threadID) + checks.NoError(t, err, "RetrieveThread error") + + _, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{ + Metadata: map[string]interface{}{ + "key": "value", + }, + }) + checks.NoError(t, err, "ModifyThread error") + + _, err = client.DeleteThread(ctx, threadID) + checks.NoError(t, err, "DeleteThread error") +} From f2204439857a1085207e74c8f05abf6c8248d336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Oester?= <56402078+raphoester@users.noreply.github.com> Date: Mon, 26 Feb 2024 10:48:09 +0200 Subject: [PATCH 105/242] Added fields for moderation (#662) --- moderation.go | 36 ++++++++++++++++++++++-------------- moderation_test.go | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 20 deletions(-) diff --git a/moderation.go b/moderation.go index f8d20ee51..45d05248e 100644 --- a/moderation.go +++ b/moderation.go @@ -44,24 +44,32 @@ type Result struct { // ResultCategories represents Categories of Result. type ResultCategories struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - SelfHarm bool `json:"self-harm"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate float32 `json:"hate"` - HateThreatening float32 `json:"hate/threatening"` - SelfHarm float32 `json:"self-harm"` - Sexual float32 `json:"sexual"` - SexualMinors float32 `json:"sexual/minors"` - Violence float32 `json:"violence"` - ViolenceGraphic float32 `json:"violence/graphic"` + Hate bool `json:"hate"` + HateThreatening bool `json:"hate/threatening"` + Harassment bool `json:"harassment"` + HarassmentThreatening bool `json:"harassment/threatening"` + SelfHarm bool `json:"self-harm"` + SelfHarmIntent bool `json:"self-harm/intent"` + SelfHarmInstructions bool `json:"self-harm/instructions"` + Sexual bool `json:"sexual"` + SexualMinors bool `json:"sexual/minors"` + Violence bool `json:"violence"` + ViolenceGraphic bool `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 059f0d1c7..7fdeb9baf 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -80,18 +80,49 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { resCat := openai.ResultCategories{} resCatScore := openai.ResultCategoryScores{} switch { - case strings.Contains(moderationReq.Input, "kill"): - resCat = openai.ResultCategories{Violence: true} - resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "hate"): resCat = openai.ResultCategories{Hate: true} - resCatScore = openai.ResultCategoryScores{Hate: 1} + resCatScore = openai.ResultCategoryScores{Hate: true} + + case strings.Contains(moderationReq.Input, "hate more"): + resCat = openai.ResultCategories{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: true} + + case strings.Contains(moderationReq.Input, "harass"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: true} + + case strings.Contains(moderationReq.Input, "harass hard"): + resCat = openai.ResultCategories{Harassment: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: true} + case strings.Contains(moderationReq.Input, "suicide"): resCat = openai.ResultCategories{SelfHarm: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: 1} + resCatScore = openai.ResultCategoryScores{SelfHarm: true} + + case strings.Contains(moderationReq.Input, "wanna suicide"): + resCat = openai.ResultCategories{SelfHarmIntent: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: true} + + case strings.Contains(moderationReq.Input, "drink bleach"): + resCat = openai.ResultCategories{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: true} + case strings.Contains(moderationReq.Input, "porn"): resCat = openai.ResultCategories{Sexual: true} - resCatScore = openai.ResultCategoryScores{Sexual: 1} + resCatScore = openai.ResultCategoryScores{Sexual: true} + + case strings.Contains(moderationReq.Input, "child porn"): + resCat = openai.ResultCategories{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: true} + + case strings.Contains(moderationReq.Input, "kill"): + resCat = openai.ResultCategories{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: true} + + case strings.Contains(moderationReq.Input, "corpse"): + resCat = openai.ResultCategories{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: true} } result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} From 41037783bc7668998900248ed697b90ec36c3f09 Mon Sep 17 00:00:00 2001 From: Guillaume Dussault <146769929+guillaume-dussault@users.noreply.github.com> Date: Mon, 26 Feb 2024 03:48:53 -0500 Subject: [PATCH 106/242] fix: when no Assistant Tools are specified, an empty list should be sent (#669) --- assistant.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/assistant.go b/assistant.go index bd335833a..7a7a7652e 100644 --- a/assistant.go +++ b/assistant.go @@ -46,7 +46,7 @@ type AssistantRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } From bb6149f64fcb22381b2ef0b5c7d8287a520dc110 Mon Sep 17 00:00:00 2001 From: Martin Heck Date: Wed, 28 Feb 2024 10:25:47 +0100 Subject: [PATCH 107/242] fix: repair json decoding of moderation response (#670) --- moderation.go | 22 +++++++++++----------- moderation_test.go | 22 +++++++++++----------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/moderation.go b/moderation.go index 45d05248e..ae285ef83 100644 --- a/moderation.go +++ b/moderation.go @@ -59,17 +59,17 @@ type ResultCategories struct { // ResultCategoryScores represents CategoryScores of Result. type ResultCategoryScores struct { - Hate bool `json:"hate"` - HateThreatening bool `json:"hate/threatening"` - Harassment bool `json:"harassment"` - HarassmentThreatening bool `json:"harassment/threatening"` - SelfHarm bool `json:"self-harm"` - SelfHarmIntent bool `json:"self-harm/intent"` - SelfHarmInstructions bool `json:"self-harm/instructions"` - Sexual bool `json:"sexual"` - SexualMinors bool `json:"sexual/minors"` - Violence bool `json:"violence"` - ViolenceGraphic bool `json:"violence/graphic"` + Hate float32 `json:"hate"` + HateThreatening float32 `json:"hate/threatening"` + Harassment float32 `json:"harassment"` + HarassmentThreatening float32 `json:"harassment/threatening"` + SelfHarm float32 `json:"self-harm"` + SelfHarmIntent float32 `json:"self-harm/intent"` + SelfHarmInstructions float32 `json:"self-harm/instructions"` + Sexual float32 `json:"sexual"` + SexualMinors float32 `json:"sexual/minors"` + Violence float32 `json:"violence"` + ViolenceGraphic float32 `json:"violence/graphic"` } // ModerationResponse represents a response structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 7fdeb9baf..61171c384 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -82,47 +82,47 @@ func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(moderationReq.Input, "hate"): resCat = openai.ResultCategories{Hate: true} - resCatScore = openai.ResultCategoryScores{Hate: true} + resCatScore = openai.ResultCategoryScores{Hate: 1} case strings.Contains(moderationReq.Input, "hate more"): resCat = openai.ResultCategories{HateThreatening: true} - resCatScore = openai.ResultCategoryScores{HateThreatening: true} + resCatScore = openai.ResultCategoryScores{HateThreatening: 1} case strings.Contains(moderationReq.Input, "harass"): resCat = openai.ResultCategories{Harassment: true} - resCatScore = openai.ResultCategoryScores{Harassment: true} + resCatScore = openai.ResultCategoryScores{Harassment: 1} case strings.Contains(moderationReq.Input, "harass hard"): resCat = openai.ResultCategories{Harassment: true} - resCatScore = openai.ResultCategoryScores{HarassmentThreatening: true} + resCatScore = openai.ResultCategoryScores{HarassmentThreatening: 1} case strings.Contains(moderationReq.Input, "suicide"): resCat = openai.ResultCategories{SelfHarm: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "wanna suicide"): resCat = openai.ResultCategories{SelfHarmIntent: true} - resCatScore = openai.ResultCategoryScores{SelfHarm: true} + resCatScore = openai.ResultCategoryScores{SelfHarm: 1} case strings.Contains(moderationReq.Input, "drink bleach"): resCat = openai.ResultCategories{SelfHarmInstructions: true} - resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: true} + resCatScore = openai.ResultCategoryScores{SelfHarmInstructions: 1} case strings.Contains(moderationReq.Input, "porn"): resCat = openai.ResultCategories{Sexual: true} - resCatScore = openai.ResultCategoryScores{Sexual: true} + resCatScore = openai.ResultCategoryScores{Sexual: 1} case strings.Contains(moderationReq.Input, "child porn"): resCat = openai.ResultCategories{SexualMinors: true} - resCatScore = openai.ResultCategoryScores{SexualMinors: true} + resCatScore = openai.ResultCategoryScores{SexualMinors: 1} case strings.Contains(moderationReq.Input, "kill"): resCat = openai.ResultCategories{Violence: true} - resCatScore = openai.ResultCategoryScores{Violence: true} + resCatScore = openai.ResultCategoryScores{Violence: 1} case strings.Contains(moderationReq.Input, "corpse"): resCat = openai.ResultCategories{ViolenceGraphic: true} - resCatScore = openai.ResultCategoryScores{ViolenceGraphic: true} + resCatScore = openai.ResultCategoryScores{ViolenceGraphic: 1} } result := openai.Result{Categories: resCat, CategoryScores: resCatScore, Flagged: true} From 38b16a3c413a3ea076cf4082ea5cd1754b72c70f Mon Sep 17 00:00:00 2001 From: Bilal Hameed <68427058+LinuxSploit@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:56:50 +0500 Subject: [PATCH 108/242] Added 'wav' and 'pcm' Audio Formats (#671) * Added 'wav' and 'pcm' Audio Formats Added "wav" and "pcm" audio formats as per OpenAI API documentation for createSpeech endpoint. Ref: https://platform.openai.com/docs/api-reference/audio/createSpeech Supported formats are mp3, opus, aac, flac, wav, and pcm. * Removed Extra Newline for Sanity Check * fix: run goimports to get accepted by the linter --- speech.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/speech.go b/speech.go index be8950218..92b30b55b 100644 --- a/speech.go +++ b/speech.go @@ -33,6 +33,8 @@ const ( SpeechResponseFormatOpus SpeechResponseFormat = "opus" SpeechResponseFormatAac SpeechResponseFormat = "aac" SpeechResponseFormatFlac SpeechResponseFormat = "flac" + SpeechResponseFormatWav SpeechResponseFormat = "wav" + SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) var ( From 699f397c36d05e42210f65456436a447885cc07a Mon Sep 17 00:00:00 2001 From: sunshineplan Date: Mon, 11 Mar 2024 15:27:48 +0800 Subject: [PATCH 109/242] Update streamReader Close() method to return error (#681) --- stream_reader.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index d17412591..4210a1948 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -108,6 +108,6 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { return } -func (stream *streamReader[T]) Close() { - stream.response.Body.Close() +func (stream *streamReader[T]) Close() error { + return stream.response.Body.Close() } From 0925563e86c2fdc5011310aa616ba493989cfe0a Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Fri, 15 Mar 2024 18:59:16 +0800 Subject: [PATCH 110/242] Fix broken implementation AssistantModify implementation (#685) * add custom marshaller, documentation and isolate tests * fix linter --- assistant.go | 30 ++++++++++++- assistant_test.go | 109 ++++++++++++++++++++++++++++++++++------------ 2 files changed, 109 insertions(+), 30 deletions(-) diff --git a/assistant.go b/assistant.go index 7a7a7652e..4ca2dda62 100644 --- a/assistant.go +++ b/assistant.go @@ -2,6 +2,7 @@ package openai import ( "context" + "encoding/json" "fmt" "net/http" "net/url" @@ -21,7 +22,7 @@ type Assistant struct { Description *string `json:"description,omitempty"` Model string `json:"model"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools,omitempty"` + Tools []AssistantTool `json:"tools"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` @@ -41,16 +42,41 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +// AssistantRequest provides the assistant request parameters. +// When modifying the tools the API functions as the following: +// If Tools is undefined, no changes are made to the Assistant's tools. +// If Tools is empty slice it will effectively delete all of the Assistant's tools. +// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { Model string `json:"model"` Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` + Tools []AssistantTool `json:"-"` FileIDs []string `json:"file_ids,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } +// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases +// If Tools is nil, the field is omitted from the JSON. +// If Tools is an empty slice, it's included in the JSON as an empty array ([]). +// If Tools is populated, it's included in the JSON with the elements. +func (a AssistantRequest) MarshalJSON() ([]byte, error) { + type Alias AssistantRequest + assistantAlias := &struct { + Tools *[]AssistantTool `json:"tools,omitempty"` + *Alias + }{ + Alias: (*Alias)(&a), + } + + if a.Tools != nil { + assistantAlias.Tools = &a.Tools + } + + return json.Marshal(assistantAlias) +} + // AssistantsList is a list of assistants. type AssistantsList struct { Assistants []Assistant `json:"data"` diff --git a/assistant_test.go b/assistant_test.go index 48bc6f91d..40de0e50f 100644 --- a/assistant_test.go +++ b/assistant_test.go @@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.` }) fmt.Fprintln(w, string(resBytes)) case http.MethodPost: - var request openai.AssistantRequest + var request openai.Assistant err := json.NewDecoder(r.Body).Decode(&request) checks.NoError(t, err, "Decode error") @@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.` ctx := context.Background() - _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("create_assistant", func(t *testing.T) { + _, err := client.CreateAssistant(ctx, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "CreateAssistant error") }) - checks.NoError(t, err, "CreateAssistant error") - _, err = client.RetrieveAssistant(ctx, assistantID) - checks.NoError(t, err, "RetrieveAssistant error") + t.Run("retrieve_assistant", func(t *testing.T) { + _, err := client.RetrieveAssistant(ctx, assistantID) + checks.NoError(t, err, "RetrieveAssistant error") + }) - _, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ - Name: &assistantName, - Description: &assistantDescription, - Model: openai.GPT4TurboPreview, - Instructions: &assistantInstructions, + t.Run("delete_assistant", func(t *testing.T) { + _, err := client.DeleteAssistant(ctx, assistantID) + checks.NoError(t, err, "DeleteAssistant error") }) - checks.NoError(t, err, "ModifyAssistant error") - _, err = client.DeleteAssistant(ctx, assistantID) - checks.NoError(t, err, "DeleteAssistant error") + t.Run("list_assistant", func(t *testing.T) { + _, err := client.ListAssistants(ctx, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistants error") + }) - _, err = client.ListAssistants(ctx, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistants error") + t.Run("create_assistant_file", func(t *testing.T) { + _, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ + FileID: assistantFileID, + }) + checks.NoError(t, err, "CreateAssistantFile error") + }) - _, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{ - FileID: assistantFileID, + t.Run("list_assistant_files", func(t *testing.T) { + _, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) + checks.NoError(t, err, "ListAssistantFiles error") }) - checks.NoError(t, err, "CreateAssistantFile error") - _, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before) - checks.NoError(t, err, "ListAssistantFiles error") + t.Run("retrieve_assistant_file", func(t *testing.T) { + _, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "RetrieveAssistantFile error") + }) - _, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "RetrieveAssistantFile error") + t.Run("delete_assistant_file", func(t *testing.T) { + err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID) + checks.NoError(t, err, "DeleteAssistantFile error") + }) - err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID) - checks.NoError(t, err, "DeleteAssistantFile error") + t.Run("modify_assistant_no_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools != nil { + t.Errorf("expected nil got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_with_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}}, + }) + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil || len(assistant.Tools) != 1 { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) + + t.Run("modify_assistant_empty_tools", func(t *testing.T) { + assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{ + Name: &assistantName, + Description: &assistantDescription, + Model: openai.GPT4TurboPreview, + Instructions: &assistantInstructions, + Tools: make([]openai.AssistantTool, 0), + }) + + checks.NoError(t, err, "ModifyAssistant error") + + if assistant.Tools == nil { + t.Errorf("expected a slice got %v", assistant.Tools) + } + }) } func TestAzureAssistant(t *testing.T) { From 2646bce71c0cc907e2a3d050130b712c1e5688db Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sat, 6 Apr 2024 03:15:54 +0800 Subject: [PATCH 111/242] feat: get header from sendRequestRaw (#694) * feat: get header from sendRequestRaw * Fix ci lint --- client.go | 15 ++++++++++++--- files.go | 6 ++---- speech.go | 7 ++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/client.go b/client.go index 7b1a313a8..9a1c8958d 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,12 @@ func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { return newRateLimitHeaders(h.Header()) } +type RawResponse struct { + io.ReadCloser + + httpHeader +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -134,8 +140,8 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { return decodeResponse(res.Body, v) } -func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { - resp, err := c.config.HTTPClient.Do(req) +func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err error) { + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body should be closed by outer function if err != nil { return } @@ -144,7 +150,10 @@ func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err erro err = c.handleErrorResp(resp) return } - return resp.Body, nil + + response.SetHeader(resp.Header) + response.ReadCloser = resp.Body + return } func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { diff --git a/files.go b/files.go index a37d45f18..b40a44f15 100644 --- a/files.go +++ b/files.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "io" "net/http" "os" ) @@ -159,13 +158,12 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err return } -func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content RawResponse, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - content, err = c.sendRequestRaw(req) - return + return c.sendRequestRaw(req) } diff --git a/speech.go b/speech.go index 92b30b55b..7e22e755c 100644 --- a/speech.go +++ b/speech.go @@ -3,7 +3,6 @@ package openai import ( "context" "errors" - "io" "net/http" ) @@ -67,7 +66,7 @@ func isValidVoice(voice SpeechVoice) bool { return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) } -func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { if !isValidSpeechModel(request.Model) { err = ErrInvalidSpeechModel return @@ -84,7 +83,5 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) return } - response, err = c.sendRequestRaw(req) - - return + return c.sendRequestRaw(req) } From 774fc9dd12ed60c10a9f9f03319ddb9cd5f8780c Mon Sep 17 00:00:00 2001 From: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 5 Apr 2024 23:24:30 +0400 Subject: [PATCH 112/242] make linter happy (#701) --- fine_tunes.go | 1 - 1 file changed, 1 deletion(-) diff --git a/fine_tunes.go b/fine_tunes.go index 46f89f165..ca840781c 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,6 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - //nolint:goconst // Decreases readability req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) if err != nil { return From 187f4169f8898d78716f7944d87e5d95aa9a7c41 Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Tue, 9 Apr 2024 16:22:31 +0800 Subject: [PATCH 113/242] [BREAKING_CHANGES] Fix update message payload (#699) * add custom marshaller, documentation and isolate tests * fix linter * wrap payload as expected from the API and update test * modify input to accept map[string]string only --- messages.go | 4 ++-- messages_test.go | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/messages.go b/messages.go index 861463235..6fd0adbc9 100644 --- a/messages.go +++ b/messages.go @@ -139,11 +139,11 @@ func (c *Client) RetrieveMessage( func (c *Client) ModifyMessage( ctx context.Context, threadID, messageID string, - metadata map[string]any, + metadata map[string]string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(metadata), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) if err != nil { return } diff --git a/messages_test.go b/messages_test.go index 9168d6ccf..a18be20bd 100644 --- a/messages_test.go +++ b/messages_test.go @@ -68,6 +68,10 @@ func TestMessages(t *testing.T) { metadata := map[string]any{} err := json.NewDecoder(r.Body).Decode(&metadata) checks.NoError(t, err, "unable to decode metadata in modify message call") + payload, ok := metadata["metadata"].(map[string]any) + if !ok { + t.Fatalf("metadata payload improperly wrapped %+v", metadata) + } resBytes, _ := json.Marshal( openai.Message{ @@ -86,8 +90,9 @@ func TestMessages(t *testing.T) { FileIds: nil, AssistantID: &emptyStr, RunID: &emptyStr, - Metadata: metadata, + Metadata: payload, }) + fmt.Fprintln(w, string(resBytes)) case http.MethodGet: resBytes, _ := json.Marshal( @@ -212,7 +217,7 @@ func TestMessages(t *testing.T) { } msg, err = client.ModifyMessage(ctx, threadID, messageID, - map[string]any{ + map[string]string{ "foo": "bar", }) checks.NoError(t, err, "ModifyMessage error") From e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 11 Apr 2024 16:39:10 +0800 Subject: [PATCH 114/242] feat: add GPT4Turbo and GPT4Turbo20240409 (#703) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index ab1dbd6c5..00f43ff1c 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" GPT4Turbo1106 = "gpt-4-1106-preview" GPT4TurboPreview = "gpt-4-turbo-preview" @@ -84,6 +86,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4VisionPreview: true, GPT4Turbo1106: true, GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, GPT40314: true, GPT40613: true, GPT432K: true, From ea551f422e5f38a0afc7d938eea5cff1f69494c5 Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:32:38 +0200 Subject: [PATCH 115/242] Fixing typos (#706) --- README.md | 2 +- assistant.go | 4 ++-- client_test.go | 2 +- error.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 9a479c0a0..7946f4d9b 100644 --- a/README.md +++ b/README.md @@ -636,7 +636,7 @@ FunctionDefinition{ }, "unit": { Type: jsonschema.String, - Enum: []string{"celcius", "fahrenheit"}, + Enum: []string{"celsius", "fahrenheit"}, }, }, Required: []string{"location"}, diff --git a/assistant.go b/assistant.go index 4ca2dda62..9415325f8 100644 --- a/assistant.go +++ b/assistant.go @@ -181,7 +181,7 @@ func (c *Client) ListAssistants( order *string, after *string, before *string, -) (reponse AssistantsList, err error) { +) (response AssistantsList, err error) { urlValues := url.Values{} if limit != nil { urlValues.Add("limit", fmt.Sprintf("%d", *limit)) @@ -208,7 +208,7 @@ func (c *Client) ListAssistants( return } - err = c.sendRequest(req, &reponse) + err = c.sendRequest(req, &response) return } diff --git a/client_test.go b/client_test.go index bc5133edc..a08d10f21 100644 --- a/client_test.go +++ b/client_test.go @@ -406,7 +406,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } } -func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { +func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { config := DefaultConfig(test.GetTestToken()) client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} diff --git a/error.go b/error.go index b2d01e22e..37959a272 100644 --- a/error.go +++ b/error.go @@ -23,7 +23,7 @@ type InnerError struct { ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } -// RequestError provides informations about generic request errors. +// RequestError provides information about generic request errors. type RequestError struct { HTTPStatusCode int Err error From 2446f08f94b2750287c40bb9593377f349f5578e Mon Sep 17 00:00:00 2001 From: Andreas Deininger Date: Sat, 13 Apr 2024 13:34:23 +0200 Subject: [PATCH 116/242] Bump GitHub workflow actions to latest versions (#707) --- .github/workflows/close-inactive-issues.yml | 2 +- .github/workflows/pr.yml | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index bfe9b5c96..32723c4e9 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -10,7 +10,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v9 with: days-before-issue-stale: 30 days-before-issue-close: 14 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 8df721f0f..a41fff92f 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -9,19 +9,19 @@ jobs: name: Sanity check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: '1.19' + go-version: '1.21' - name: Run vet run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 From a42f51967f5c2f8462f8d8dfd25f7d6a8d7a46fc Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Wed, 17 Apr 2024 03:26:14 +0800 Subject: [PATCH 117/242] [New_Features] Adds recently added Assistant cost saving parameters (#710) * add cost saving parameters * add periods at the end of comments * shorten commnet * further lower comment length * fix type --- run.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/run.go b/run.go index 1f3cb7eb7..7c14779c5 100644 --- a/run.go +++ b/run.go @@ -28,6 +28,16 @@ type Run struct { Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + httpHeader } @@ -78,8 +88,42 @@ type RunRequest struct { AdditionalInstructions string `json:"additional_instructions,omitempty"` Tools []Tool `json:"tools,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` + + // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. + // lower values are more focused and deterministic. + Temperature *float32 `json:"temperature,omitempty"` + + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` } +// ThreadTruncationStrategy defines the truncation strategy to use for the thread. +// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy. +type ThreadTruncationStrategy struct { + // default 'auto'. + Type TruncationStrategy `json:"type,omitempty"` + // this field should be set if the truncation strategy is set to LastMessages. + LastMessages *int `json:"last_messages,omitempty"` +} + +// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant. +type TruncationStrategy string + +const ( + // TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model. + TruncationStrategyAuto = TruncationStrategy("auto") + // TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread. + TruncationStrategyLastMessages = TruncationStrategy("last_messages") +) + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From c6a63ed19aeb0e91facc5409c5a08612db550fb2 Mon Sep 17 00:00:00 2001 From: Mike Chaykowsky Date: Tue, 16 Apr 2024 12:28:06 -0700 Subject: [PATCH 118/242] Add PromptFilterResult (#702) --- chat_stream.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 57cfa789f..6ff7078e2 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -19,13 +19,19 @@ type ChatCompletionStreamChoice struct { ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } +type PromptFilterResult struct { + Index int `json:"index"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } // ChatCompletionStream From 8d15a377ec4fa3aaf2e706cd1e2ad986dd6b8242 Mon Sep 17 00:00:00 2001 From: Danai Antoniou <32068609+danai-antoniou@users.noreply.github.com> Date: Wed, 24 Apr 2024 12:59:50 +0100 Subject: [PATCH 119/242] Remove hardcoded assistants version (#719) --- assistant.go | 19 +++++++++---------- client.go | 4 ++-- config.go | 14 +++++++++----- messages.go | 17 +++++++++++------ run.go | 27 +++++++++------------------ thread.go | 8 ++++---- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/assistant.go b/assistant.go index 9415325f8..661681e83 100644 --- a/assistant.go +++ b/assistant.go @@ -11,7 +11,6 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" - openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { @@ -116,7 +115,7 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -132,7 +131,7 @@ func (c *Client) RetrieveAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -149,7 +148,7 @@ func (c *Client) ModifyAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -165,7 +164,7 @@ func (c *Client) DeleteAssistant( ) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -203,7 +202,7 @@ func (c *Client) ListAssistants( urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -221,7 +220,7 @@ func (c *Client) CreateAssistantFile( urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -238,7 +237,7 @@ func (c *Client) RetrieveAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -255,7 +254,7 @@ func (c *Client) DeleteAssistantFile( ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -294,7 +293,7 @@ func (c *Client) ListAssistantFiles( urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/client.go b/client.go index 9a1c8958d..77d693226 100644 --- a/client.go +++ b/client.go @@ -89,9 +89,9 @@ func withContentType(contentType string) requestOption { } } -func withBetaAssistantV1() requestOption { +func withBetaAssistantVersion(version string) requestOption { return func(args *requestOptions) { - args.header.Set("OpenAI-Beta", "assistants=v1") + args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version)) } } diff --git a/config.go b/config.go index c58b71ec6..599fa89c0 100644 --- a/config.go +++ b/config.go @@ -23,6 +23,8 @@ const ( const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -30,7 +32,8 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient *http.Client @@ -39,10 +42,11 @@ type ClientConfig struct { func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, diff --git a/messages.go b/messages.go index 6fd0adbc9..6af118445 100644 --- a/messages.go +++ b/messages.go @@ -76,7 +76,8 @@ type MessageFilesList struct { // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -111,7 +112,8 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, } urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -126,7 +128,8 @@ func (c *Client) RetrieveMessage( threadID, messageID string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -143,7 +146,7 @@ func (c *Client) ModifyMessage( ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -158,7 +161,8 @@ func (c *Client) RetrieveMessageFile( threadID, messageID, fileID string, ) (file MessageFile, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -173,7 +177,8 @@ func (c *Client) ListMessageFiles( threadID, messageID string, ) (files MessageFilesList, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/run.go b/run.go index 7c14779c5..094b0a4db 100644 --- a/run.go +++ b/run.go @@ -226,8 +226,7 @@ func (c *Client) CreateRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -247,8 +246,7 @@ func (c *Client) RetrieveRun( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -270,8 +268,7 @@ func (c *Client) ModifyRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -310,8 +307,7 @@ func (c *Client) ListRuns( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -332,8 +328,7 @@ func (c *Client) SubmitToolOutputs( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -352,8 +347,7 @@ func (c *Client) CancelRun( ctx, http.MethodPost, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -372,8 +366,7 @@ func (c *Client) CreateThreadAndRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -394,8 +387,7 @@ func (c *Client) RetrieveRunStep( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -435,8 +427,7 @@ func (c *Client) ListRunSteps( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/thread.go b/thread.go index 291f3dcab..900e3f2ea 100644 --- a/thread.go +++ b/thread.go @@ -51,7 +51,7 @@ type ThreadDeleteResponse struct { // CreateThread creates a new thread. func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -64,7 +64,7 @@ func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (respo func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -81,7 +81,7 @@ func (c *Client) ModifyThread( ) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -97,7 +97,7 @@ func (c *Client) DeleteThread( ) (response ThreadDeleteResponse, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } From 2d58f8f4b87be26dc0b7ba2b1f0c9496ecf1dfa5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=80=E6=97=A5=E3=80=82?= Date: Wed, 24 Apr 2024 20:02:03 +0800 Subject: [PATCH 120/242] chore: add SystemFingerprint for chat completion stream response (#716) * chore: add SystemFingerprint for stream response * chore: add test * lint: format for test --- chat_stream.go | 1 + chat_stream_test.go | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index 6ff7078e2..159f9f472 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -30,6 +30,7 @@ type ChatCompletionStreamResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` } diff --git a/chat_stream_test.go b/chat_stream_test.go index bd571cb48..bd1c737dd 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -46,12 +46,12 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) @@ -77,10 +77,11 @@ func TestCreateChatCompletionStream(t *testing.T) { expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: openai.GPT3Dot5Turbo, + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -91,10 +92,11 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: openai.GPT3Dot5Turbo, + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ From c84ab5f6ae8da3a78826ed2c8dc4c5cf93e30589 Mon Sep 17 00:00:00 2001 From: wurui <1009479218@qq.com> Date: Wed, 24 Apr 2024 20:08:58 +0800 Subject: [PATCH 121/242] feat: support cloudflare AI Gateway flavored azure openai (#715) * feat: support cloudflare AI Gateway flavored azure openai Signed-off-by: STRRL * test: add test for cloudflare azure fullURL --------- Signed-off-by: STRRL Co-authored-by: STRRL --- api_internal_test.go | 36 ++++++++++++++++++++++++++++++++++++ client.go | 10 ++++++++-- config.go | 7 ++++--- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f8993..a590ec9ab 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -148,3 +148,39 @@ func TestAzureFullURL(t *testing.T) { }) } } + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "CloudflareAzureBaseURLWithoutSlashOK", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL("/chat/completions") + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} diff --git a/client.go b/client.go index 77d693226..c57ba17c7 100644 --- a/client.go +++ b/client.go @@ -182,7 +182,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication // Azure API Key authentication - if c.config.APIType == APITypeAzure { + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { req.Header.Set(AzureAPIKeyHeader, c.config.authToken) } else if c.config.authToken != "" { // OpenAI or Azure AD authentication @@ -246,7 +246,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { ) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" + // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ + if c.config.APIType == APITypeCloudflareAzure { + baseURL := c.config.BaseURL + baseURL = strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) + } + return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) } diff --git a/config.go b/config.go index 599fa89c0..bb437c97f 100644 --- a/config.go +++ b/config.go @@ -16,9 +16,10 @@ const ( type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" ) const AzureAPIKeyHeader = "api-key" From c9953a7b051bd661254fb071029553e61c78f8bd Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Sat, 27 Apr 2024 12:55:49 +0330 Subject: [PATCH 122/242] Fixup minor copy-pasta comment typo (#728) imagess -> images --- image_api_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/image_api_test.go b/image_api_test.go index 2eb46f2b4..48416b1e2 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -36,7 +36,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -146,7 +146,7 @@ func TestImageEditWithoutMask(t *testing.T) { func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -202,7 +202,7 @@ func TestImageVariation(t *testing.T) { func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } From 3334a9c78a9d594934e33af184e4e6313c4a942b Mon Sep 17 00:00:00 2001 From: Alireza Ghasemi Date: Tue, 7 May 2024 16:10:07 +0330 Subject: [PATCH 123/242] Add support for word-level audio transcription timestamp granularity (#733) * Add support for audio transcription timestamp_granularities word * Fixup multiple timestamp granularities --- audio.go | 31 ++++++++++++++++++++++++++----- audio_api_test.go | 4 ++++ audio_test.go | 6 +++++- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/audio.go b/audio.go index 4cbe4fe64..dbc26d154 100644 --- a/audio.go +++ b/audio.go @@ -27,8 +27,14 @@ const ( AudioResponseFormatVTT AudioResponseFormat = "vtt" ) +type TranscriptionTimestampGranularity string + +const ( + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" +) + // AudioRequest represents a request structure for audio API. -// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { Model string @@ -38,10 +44,11 @@ type AudioRequest struct { // Reader is an optional io.Reader when you do not want to use an existing file. Reader io.Reader - Prompt string // For translation, it should be in English - Temperature float32 - Language string // For translation, just do not use it. It seems "en" works, not confirmed... - Format AudioResponseFormat + Prompt string + Temperature float32 + Language string // Only for transcription. + Format AudioResponseFormat + TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. } // AudioResponse represents a response structure for audio API. @@ -62,6 +69,11 @@ type AudioResponse struct { NoSpeechProb float64 `json:"no_speech_prob"` Transient bool `json:"transient"` } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` Text string `json:"text"` httpHeader @@ -179,6 +191,15 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { } } + if len(request.TimestampGranularities) > 0 { + for _, tg := range request.TimestampGranularities { + err = b.WriteField("timestamp_granularities[]", string(tg)) + if err != nil { + return fmt.Errorf("writing timestamp_granularities[]: %w", err) + } + } + } + // Close the multipart writer return b.Close() } diff --git a/audio_api_test.go b/audio_api_test.go index a0efc7921..c24598443 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -105,6 +105,10 @@ func TestAudioWithOptionalArgs(t *testing.T) { Temperature: 0.5, Language: "zh", Format: openai.AudioResponseFormatSRT, + TimestampGranularities: []openai.TranscriptionTimestampGranularity{ + openai.TranscriptionTimestampGranularitySegment, + openai.TranscriptionTimestampGranularityWord, + }, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index 5346244c8..235931f36 100644 --- a/audio_test.go +++ b/audio_test.go @@ -24,6 +24,10 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Temperature: 0.5, Language: "en", Format: AudioResponseFormatSRT, + TimestampGranularities: []TranscriptionTimestampGranularity{ + TranscriptionTimestampGranularitySegment, + TranscriptionTimestampGranularityWord, + }, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -47,7 +51,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { return nil } - failOn := []string{"model", "prompt", "temperature", "language", "response_format"} + failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} for _, failingField := range failOn { failForField = failingField mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) From 6af32202d1ce469674050600efa07c90ec286d03 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 7 May 2024 20:42:24 +0800 Subject: [PATCH 124/242] feat: support stream_options (#736) * feat: support stream_options * fix lint * fix lint --- chat.go | 10 ++++ chat_stream.go | 4 ++ chat_stream_test.go | 123 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) diff --git a/chat.go b/chat.go index efb14fd4c..a1eb11720 100644 --- a/chat.go +++ b/chat.go @@ -216,6 +216,16 @@ type ChatCompletionRequest struct { Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` +} + +type StreamOptions struct { + // If set, an additional chunk will be streamed before the data: [DONE] message. + // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` } type ToolType string diff --git a/chat_stream.go b/chat_stream.go index 159f9f472..ffd512ff6 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -33,6 +33,10 @@ type ChatCompletionStreamResponse struct { SystemFingerprint string `json:"system_fingerprint"` PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + // An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + // When present, it contains a null value except for the last chunk which contains the token usage statistics + // for the entire request. + Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream diff --git a/chat_stream_test.go b/chat_stream_test.go index bd1c737dd..63e45ee23 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -388,6 +388,120 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{}, + Usage: &openai.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + // Helper funcs. func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { @@ -401,6 +515,15 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { return false } } + if r1.Usage != nil || r2.Usage != nil { + if r1.Usage == nil || r2.Usage == nil { + return false + } + if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || + r1.Usage.TotalTokens != r2.Usage.TotalTokens { + return false + } + } return true } From 3b25e09da90715681fe4049955d7c7ce645e218c Mon Sep 17 00:00:00 2001 From: Kevin Mesiab Date: Mon, 13 May 2024 11:48:14 -0700 Subject: [PATCH 125/242] enhancement: Add new GPT4-o and alias to completion enums (#744) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 00f43ff1c..3b4f8952a 100644 --- a/completion.go +++ b/completion.go @@ -22,6 +22,8 @@ const ( GPT432K = "gpt-4-32k" GPT40613 = "gpt-4-0613" GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" From 9f19d1c93bf986f2a8925be62f35aa5c413a706a Mon Sep 17 00:00:00 2001 From: nullswan Date: Mon, 13 May 2024 21:07:07 +0200 Subject: [PATCH 126/242] Add gpt4o (#742) * Add gpt4o * disabled model for endpoint seen in https://github.com/sashabaranov/go-openai/commit/e0d0801ac73cdc87d1b56ced0a0eb71e574546c3 * Update completion.go --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 3b4f8952a..ced8e0606 100644 --- a/completion.go +++ b/completion.go @@ -84,6 +84,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT3Dot5Turbo16K: true, GPT3Dot5Turbo16K0613: true, GPT4: true, + GPT4o: true, + GPT4o20240513: true, GPT4TurboPreview: true, GPT4VisionPreview: true, GPT4Turbo1106: true, From 4f4a85687be31607536997e924b27693f5e5211a Mon Sep 17 00:00:00 2001 From: Kshirodra Meher Date: Tue, 14 May 2024 00:38:14 +0530 Subject: [PATCH 127/242] Added DALL.E 3 to readme.md (#741) * Added DALL.E 3 to readme.md Added DALL.E 3 to readme.md as its supported now as per issue https://github.com/sashabaranov/go-openai/issues/494 * Update README.md --------- Co-authored-by: Alexander Baranov <677093+sashabaranov@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7946f4d9b..799dc602b 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * ChatGPT * GPT-3, GPT-4 -* DALL·E 2 +* DALL·E 2, DALL·E 3 * Whisper ## Installation From 211cb49fc22766f4174fef15301c4d39aef609d3 Mon Sep 17 00:00:00 2001 From: ando-masaki Date: Fri, 24 May 2024 16:18:47 +0900 Subject: [PATCH 128/242] Update client.go to get response header whether there is an error or not. (#751) Update client.go to get response header whether there is an error or not. Because 429 Too Many Requests error response has "Retry-After" header. --- client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c57ba17c7..7bc28e984 100644 --- a/client.go +++ b/client.go @@ -129,14 +129,14 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { defer res.Body.Close() - if isFailureStatusCode(res) { - return c.handleErrorResp(res) - } - if v != nil { v.SetHeader(res.Header) } + if isFailureStatusCode(res) { + return c.handleErrorResp(res) + } + return decodeResponse(res.Body, v) } From 30cf7b879cff5eb56f06fda19c51c9e92fce8b13 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:50:22 -0700 Subject: [PATCH 129/242] feat: add params to RunRequest (#754) --- run.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/run.go b/run.go index 094b0a4db..6bd3933b1 100644 --- a/run.go +++ b/run.go @@ -92,6 +92,7 @@ type RunRequest struct { // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. @@ -103,6 +104,11 @@ type RunRequest struct { // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. @@ -124,6 +130,13 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From 8618492b98bb91edbb43f8080b3a68275e183663 Mon Sep 17 00:00:00 2001 From: shosato0306 <38198918+shosato0306@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:03:57 +0900 Subject: [PATCH 130/242] feat: add incomplete run status (#763) --- run.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/run.go b/run.go index 6bd3933b1..5598f1dfb 100644 --- a/run.go +++ b/run.go @@ -30,10 +30,10 @@ type Run struct { Temperature *float32 `json:"temperature,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` @@ -50,6 +50,7 @@ const ( RunStatusCancelling RunStatus = "cancelling" RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" RunStatusExpired RunStatus = "expired" RunStatusCancelled RunStatus = "cancelled" ) @@ -95,11 +96,11 @@ type RunRequest struct { TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. From fd41f7a5f49e6723d97642c186e5e090abaebfe2 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Thu, 13 Jun 2024 06:23:07 -0700 Subject: [PATCH 131/242] Fix integration test (#762) * added TestCompletionStream test moved completion stream testing to seperate function added NoErrorF fixes nil pointer reference on stream object * update integration test models --- api_integration_test.go | 64 ++++++++++++++++++++-------------- completion.go | 31 ++++++++-------- embeddings.go | 2 +- internal/test/checks/checks.go | 7 ++++ 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 736040c50..f34685188 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -26,7 +26,7 @@ func TestAPI(t *testing.T) { _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") - _, err = c.GetEngine(ctx, "davinci") + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) @@ -42,7 +42,7 @@ func TestAPI(t *testing.T) { "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: openai.AdaEmbeddingV2, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") @@ -77,31 +77,6 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: openai.GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } - _, err = c.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -134,6 +109,41 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + func TestAPIError(t *testing.T) { apiToken := os.Getenv("OPENAI_TOKEN") if apiToken == "" { diff --git a/completion.go b/completion.go index ced8e0606..024f09b14 100644 --- a/completion.go +++ b/completion.go @@ -39,30 +39,33 @@ const ( GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci002 = "text-davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextCurie001 = "text-curie-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextBabbage001 = "text-babbage-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextAda001 = "text-ada-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci001 = "text-davinci-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3Davinci002 = "davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3CurieInstructBeta = "curie-instruct-beta" GPT3Curie = "curie" GPT3Curie002 = "curie-002" - GPT3Ada = "ada" - GPT3Ada002 = "ada-002" - GPT3Babbage = "babbage" - GPT3Babbage002 = "babbage-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/embeddings.go b/embeddings.go index c5633a313..b513ba6a7 100644 --- a/embeddings.go +++ b/embeddings.go @@ -16,7 +16,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..6bd0964c6 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { From 7e96c712cbdad50b9cf67324b1ca5ef6541b6235 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:15:27 +0400 Subject: [PATCH 132/242] run integration tests (#769) --- .github/workflows/integration-tests.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/integration-tests.yml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 000000000..19f158e40 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,19 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + run: go test -v -tags=integration ./api_integration_test.go From c69c3bb1d259375d5de801f890aca40c0b2a8867 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:21:25 +0400 Subject: [PATCH 133/242] integration tests: pass openai secret (#770) * pass openai secret * only run in master branch --- .github/workflows/integration-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 19f158e40..7260b00b4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -16,4 +16,6 @@ jobs: with: go-version: '1.21' - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} run: go test -v -tags=integration ./api_integration_test.go From 99cc170b5414bd21fc1c55bccba1d6c1bad04516 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 13 Jun 2024 23:24:37 +0800 Subject: [PATCH 134/242] feat: support batches api (#746) * feat: support batches api * update batch_test.go * fix golangci-lint check * fix golangci-lint check * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix: create batch api * update batch_test.go * feat: add `CreateBatchWithUploadFile` * feat: add `UploadBatchFile` * optimize variable and type naming * expose `BatchLineItem` interface * update batches const --- batch.go | 275 ++++++++++++++++++++++++++++++++++++ batch_test.go | 368 +++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 11 ++ files.go | 1 + 4 files changed, 655 insertions(+) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..4aba966bc --- /dev/null +++ b/batch.go @@ -0,0 +1,275 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +var ErrUploadBatchFileFailed = errors.New("upload batch file failed") + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + err = errors.Join(ErrUploadBatchFileFailed, err) + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. +func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 000000000..4b2261e0e --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/client_test.go b/client_test.go index a08d10f21..e49da9b3d 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { diff --git a/files.go b/files.go index b40a44f15..26ad6bd70 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. From 68acf22a43903c1b460006e7c4b883ce73e35857 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 13 Jun 2024 17:26:37 +0200 Subject: [PATCH 135/242] Support Tool Resources properties for Threads (#760) * Support Tool Resources properties for Threads * Add Chunking Strategy for Threads vector stores --- thread.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/thread.go b/thread.go index 900e3f2ea..6f7521454 100644 --- a/thread.go +++ b/thread.go @@ -10,21 +10,74 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` } +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string From 0a421308993425afed7796da8f8e0e1abafd4582 Mon Sep 17 00:00:00 2001 From: Peng Guan-Cheng Date: Wed, 19 Jun 2024 16:37:21 +0800 Subject: [PATCH 136/242] feat: provide vector store (#772) * implement vectore store feature * fix after integration testing * fix golint error * improve test to increare code coverage * fix golint anc code coverage problem * add tool_resource in assistant response * chore: code style * feat: use pagination param * feat: use pagination param * test: use pagination param * test: rm unused code --------- Co-authored-by: Denny Depok <61371551+kodernubie@users.noreply.github.com> Co-authored-by: eric.p --- assistant.go | 50 ++++--- config.go | 2 +- vector_store.go | 345 ++++++++++++++++++++++++++++++++++++++++++ vector_store_test.go | 349 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 18 deletions(-) create mode 100644 vector_store.go create mode 100644 vector_store_test.go diff --git a/assistant.go b/assistant.go index 661681e83..cc13a3020 100644 --- a/assistant.go +++ b/assistant.go @@ -14,16 +14,17 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` httpHeader } @@ -34,6 +35,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -41,19 +43,33 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases diff --git a/config.go b/config.go index bb437c97f..1347567d7 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,7 @@ const ( const AzureAPIKeyHeader = "api-key" -const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store // ClientConfig is a configuration of a client. type ClientConfig struct { diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 000000000..5c364362a --- /dev/null +++ b/vector_store.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. +type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. +func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) { + req, _ := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(vectorStoresSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStore retrieves an vector store. +func (c *Client) RetrieveVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ModifyVectorStore modifies a vector store. +func (c *Client) ModifyVectorStore( + ctx context.Context, + vectorStoreID string, + request VectorStoreRequest, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStore deletes an vector store. +func (c *Client) DeleteVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStoreDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStores Lists the currently available vector store. +func (c *Client) ListVectorStores( + ctx context.Context, + pagination Pagination, +) (response VectorStoresList, err error) { + urlValues := url.Values{} + + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. +func (c *Client) RetrieveVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CancelVectorStoreFileBatch cancel a new vector store file batch. +func (c *Client) CancelVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel") + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFilesInBatch( + ctx context.Context, + vectorStoreID string, + batchID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} diff --git a/vector_store_test.go b/vector_store_test.go new file mode 100644 index 000000000..58b9a857e --- /dev/null +++ b/vector_store_test.go @@ -0,0 +1,349 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestVectorStore Tests the vector store endpoint of the API using the mocked server. +func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + "object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +} From e31185974c45949cc58c24a6cbf5ca969fb0f622 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:06:52 +0100 Subject: [PATCH 137/242] remove errors.Join (#778) --- batch.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/batch.go b/batch.go index 4aba966bc..a43d401ab 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -109,8 +108,6 @@ type BatchResponse struct { Batch } -var ErrUploadBatchFileFailed = errors.New("upload batch file failed") - // CreateBatch — API call to Create batch. func (c *Client) CreateBatch( ctx context.Context, @@ -202,7 +199,6 @@ func (c *Client) CreateBatchWithUploadFile( Lines: request.Lines, }) if err != nil { - err = errors.Join(ErrUploadBatchFileFailed, err) return } return c.CreateBatch(ctx, CreateBatchRequest{ From 03851d20327b7df5358ff9fb0ac96f476be1875a Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Sun, 30 Jun 2024 17:20:10 +0200 Subject: [PATCH 138/242] allow custom voice and speech models (#691) --- speech.go | 31 ------------------------------- speech_test.go | 17 ----------------- 2 files changed, 48 deletions(-) diff --git a/speech.go b/speech.go index 7e22e755c..19b21bdf1 100644 --- a/speech.go +++ b/speech.go @@ -2,7 +2,6 @@ package openai import ( "context" - "errors" "net/http" ) @@ -36,11 +35,6 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` @@ -49,32 +43,7 @@ type CreateSpeechRequest struct { Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), withBody(request), withContentType("application/json"), diff --git a/speech_test.go b/speech_test.go index d9ba58b13..f1e405c39 100644 --- a/speech_test.go +++ b/speech_test.go @@ -95,21 +95,4 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) } From 727944c47886924800128d1c33df706b4159eb23 Mon Sep 17 00:00:00 2001 From: Luca Giannini <68999840+LGXerxes@users.noreply.github.com> Date: Fri, 12 Jul 2024 12:31:11 +0200 Subject: [PATCH 139/242] feat: ParallelToolCalls to ChatCompletionRequest with helper functions (#787) * added ParallelToolCalls to ChatCompletionRequest with helper functions * added tests for coverage * changed ParallelToolCalls to any --- chat.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chat.go b/chat.go index a1eb11720..eb494f41f 100644 --- a/chat.go +++ b/chat.go @@ -218,6 +218,8 @@ type ChatCompletionRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // Options for streaming response. Only set this when you set stream: true. StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } type StreamOptions struct { From 3e47e6fef4ac861dd5e07f73a8fb240374e8cad3 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 19 Jul 2024 22:06:27 +0800 Subject: [PATCH 140/242] fix: #790 (#798) --- files.go | 1 + 1 file changed, 1 insertion(+) diff --git a/files.go b/files.go index 26ad6bd70..edc9f2a20 100644 --- a/files.go +++ b/files.go @@ -102,6 +102,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File if err != nil { return } + defer fileData.Close() err = builder.CreateFormFile("file", fileData) if err != nil { From 27c1c56f0b50a84740425f7534c46825e227b437 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Fri, 19 Jul 2024 07:06:51 -0700 Subject: [PATCH 141/242] feat: Add GPT-4o Mini model support (#796) --- completion.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/completion.go b/completion.go index 024f09b14..4ff1123c4 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,8 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" @@ -89,6 +91,8 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4oMini: true, + GPT4oMini20240718: true, GPT4TurboPreview: true, GPT4VisionPreview: true, GPT4Turbo1106: true, From 92f483055f666847f7954e148b7f46771c5581b8 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 19 Jul 2024 22:10:17 +0800 Subject: [PATCH 142/242] fix: #794 (#797) --- client.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 7bc28e984..d5d555c3d 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" utils "github.com/sashabaranov/go-openai/internal" @@ -228,10 +229,13 @@ func (c *Client) fullURL(suffix string, args ...any) string { if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { baseURL := c.config.BaseURL baseURL = strings.TrimRight(baseURL, "/") + parseURL, _ := url.Parse(baseURL) + query := parseURL.Query() + query.Add("api-version", c.config.APIVersion) // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) + return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) } azureDeploymentName := "UNKNOWN" if len(args) > 0 { @@ -240,9 +244,9 @@ func (c *Client) fullURL(suffix string, args ...any) string { azureDeploymentName = c.config.GetAzureDeploymentByModel(model) } } - return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", + return fmt.Sprintf("%s/%s/%s/%s%s?%s", baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, c.config.APIVersion, + azureDeploymentName, suffix, query.Encode(), ) } From ae903d7465c4b48654fac6103472767ee4d95e41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edin=20=C4=86orali=C4=87?= <73831203+ecoralic@users.noreply.github.com> Date: Fri, 19 Jul 2024 17:12:20 +0300 Subject: [PATCH 143/242] fix: Updated ThreadMessage struct with latest fields based on OpenAI docs (#792) * fix: Updated ThreadMessage struct with latest fields based on OpenAI docs * fix: Reverted FileIDs for backward compatibility of v1 --- thread.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/thread.go b/thread.go index 6f7521454..bc08e2bcb 100644 --- a/thread.go +++ b/thread.go @@ -83,14 +83,25 @@ type ModifyThreadRequest struct { type ThreadMessageRole string const ( - ThreadMessageRoleUser ThreadMessageRole = "user" + ThreadMessageRoleAssistant ThreadMessageRole = "assistant" + ThreadMessageRoleUser ThreadMessageRole = "user" ) type ThreadMessage struct { - Role ThreadMessageRole `json:"role"` - Content string `json:"content"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadAttachment struct { + FileID string `json:"file_id"` + Tools []ThreadAttachmentTool `json:"tools"` +} + +type ThreadAttachmentTool struct { + Type string `json:"type"` } type ThreadDeleteResponse struct { From a7e9f0e3880d1487fe8e06a43820f42046b5b622 Mon Sep 17 00:00:00 2001 From: Janusch Jacoby Date: Fri, 19 Jul 2024 16:13:02 +0200 Subject: [PATCH 144/242] add hyperparams (#793) --- fine_tuning_job.go | 4 +++- fine_tuning_job_test.go | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 9dcb49de1..5a9f54a92 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -26,7 +26,9 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs any `json:"n_epochs,omitempty"` + Epochs any `json:"n_epochs,omitempty"` + LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` + BatchSize any `json:"batch_size,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index d2fbcd4c7..5f63ef24c 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -33,7 +33,9 @@ func TestFineTuningJob(t *testing.T) { ValidationFile: "", TrainingFile: "file-abc123", Hyperparameters: openai.Hyperparameters{ - Epochs: "auto", + Epochs: "auto", + LearningRateMultiplier: "auto", + BatchSize: "auto", }, TrainedTokens: 5768, }) From 966ee682b11ca580c2c2c3ac067c27b51bd6d749 Mon Sep 17 00:00:00 2001 From: VanessaMae23 <60029664+Vanessamae23@users.noreply.github.com> Date: Fri, 19 Jul 2024 22:18:16 +0800 Subject: [PATCH 145/242] Add New Optional Parameters to `AssistantRequest` Struct (#795) * Add more parameters to support Assistant v2 * Add goimports --- assistant.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/assistant.go b/assistant.go index cc13a3020..4c89c1b2f 100644 --- a/assistant.go +++ b/assistant.go @@ -62,14 +62,17 @@ type AssistantToolResource struct { // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases From 581da2f12d52617368bdfe2625f5b0ef1dd32758 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Mon, 29 Jul 2024 01:43:45 +0800 Subject: [PATCH 146/242] fix: #804 (#807) --- batch.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/batch.go b/batch.go index a43d401ab..3c1a9d0d7 100644 --- a/batch.go +++ b/batch.go @@ -65,7 +65,7 @@ type Batch struct { Endpoint BatchEndpoint `json:"endpoint"` Errors *struct { Object string `json:"object,omitempty"` - Data struct { + Data []struct { Code string `json:"code,omitempty"` Message string `json:"message,omitempty"` Param *string `json:"param,omitempty"` From dbe726c59f6df65965a4ee25e37706c33e391dc4 Mon Sep 17 00:00:00 2001 From: Daniel Chalef <131175+danielchalef@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:21:38 +1000 Subject: [PATCH 147/242] Add support for `gpt-4o-2024-08-06` (#812) * feat: Add GPT-4o Mini model support * feat: Add GPT-4o-2024-08-06 model support --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 4ff1123c4..d435eb382 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,7 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" @@ -91,6 +92,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4o20240806: true, GPT4oMini: true, GPT4oMini20240718: true, GPT4TurboPreview: true, From 623074c14a110b97d9a7aac7896bbdccf335257f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Wed, 7 Aug 2024 21:47:48 +0800 Subject: [PATCH 148/242] feat: Support Structured Outputs (#813) * feat: Support Structured Outputs * feat: Support Structured Outputs * update imports * add integration test * update JSON schema comments --- api_integration_test.go | 61 +++++++++++++++++++++++++++++++++++++++++ chat.go | 13 ++++++++- jsonschema/json.go | 8 +++++- 3 files changed, 80 insertions(+), 2 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index f34685188..a487f588a 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "os" @@ -178,3 +179,63 @@ func TestAPIError(t *testing.T) { t.Fatal("Empty error message occurred") } } + +func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": jsonschema.Definition{Type: jsonschema.String}, + "CamelCase": jsonschema.Definition{Type: jsonschema.String}, + "KebabCase": jsonschema.Definition{Type: jsonschema.String}, + "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + Strict: true, + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index eb494f41f..8bfe558b5 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -175,11 +177,20 @@ type ChatCompletionResponseFormatType string const ( ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema jsonschema.Definition `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. diff --git a/jsonschema/json.go b/jsonschema/json.go index cb941eb75..7fd1e11bf 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -29,11 +29,17 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. - Properties map[string]Definition `json:"properties"` + Properties map[string]Definition `json:"properties,omitempty"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` + // AdditionalProperties is used to control the handling of properties in an object + // that are not explicitly defined in the properties section of the schema. example: + // additionalProperties: true + // additionalProperties: false + // additionalProperties: jsonschema.Definition{Type: jsonschema.String} + AdditionalProperties any `json:"additionalProperties,omitempty"` } func (d Definition) MarshalJSON() ([]byte, error) { From 6439e1fcc93fc5175accf5d51358e45fa5ea9099 Mon Sep 17 00:00:00 2001 From: Tyler Gannon Date: Wed, 7 Aug 2024 12:40:45 -0700 Subject: [PATCH 149/242] Make reponse format JSONSchema optional (#820) --- chat.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index 8bfe558b5..31fa887d6 100644 --- a/chat.go +++ b/chat.go @@ -182,8 +182,8 @@ const ( ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` - JSONSchema ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` } type ChatCompletionResponseFormatJSONSchema struct { From 18803333812ea21c409e84d426141606b9a6e692 Mon Sep 17 00:00:00 2001 From: Harald Nordgren Date: Fri, 9 Aug 2024 18:30:32 +0200 Subject: [PATCH 150/242] Run integration tests for PRs (#823) * Unbreak integration tests * Update integration-tests.yml --- api_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api_integration_test.go b/api_integration_test.go index a487f588a..3084268e6 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -211,7 +211,7 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { }, ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, - JSONSchema: openai.ChatCompletionResponseFormatJSONSchema{ + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ Name: "cases", Schema: jsonschema.Definition{ Type: jsonschema.Object, From 2c6889e0818b93c4fd724d9528b610896f5e9421 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sun, 11 Aug 2024 05:05:06 +0800 Subject: [PATCH 151/242] fix: #788 (#800) --- completion.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/completion.go b/completion.go index d435eb382..bc2a63795 100644 --- a/completion.go +++ b/completion.go @@ -138,25 +138,26 @@ func checkPromptType(prompt any) bool { // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { - Model string `json:"model"` - Prompt any `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` + Model string `json:"model"` + Prompt any `json:"prompt,omitempty"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. From dd7f5824f9a4c3860cccfaf8350d5d09e864038f Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sat, 17 Aug 2024 01:11:38 +0800 Subject: [PATCH 152/242] fix: fullURL endpoint generation (#817) --- api_internal_test.go | 24 ++++++++--- audio.go | 9 ++++- chat.go | 7 +++- chat_stream.go | 7 +++- client.go | 84 ++++++++++++++++++++++++-------------- client_test.go | 96 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 7 +++- edits.go | 7 +++- embeddings.go | 7 +++- example_test.go | 2 +- image.go | 25 +++++++++--- moderation.go | 7 +++- speech.go | 5 ++- stream.go | 8 +++- 14 files changed, 244 insertions(+), 51 deletions(-) diff --git a/api_internal_test.go b/api_internal_test.go index a590ec9ab..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org/", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Suffix string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { - "CloudflareAzureBaseURLWithoutSlashOK", + "", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + - "chat/completions?api-version=2023-05-15", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + "/assistants?api-version=2023-05-15&limit=10", }, } @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/audio.go b/audio.go index dbc26d154..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } diff --git a/chat.go b/chat.go index 31fa887d6..826fd3bd5 100644 --- a/chat.go +++ b/chat.go @@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index ffd512ff6..3f90bc019 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/client.go b/client.go index d5d555c3d..9f547e7cb 100644 --- a/client.go +++ b/client.go @@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error { return nil } +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. -func (c *Client) fullURL(suffix string, args ...any) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - parseURL, _ := url.Parse(baseURL) - query := parseURL.Query() - query.Add("api-version", c.config.APIVersion) - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode()) - } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } - return fmt.Sprintf("%s/%s/%s/%s%s?%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, query.Encode(), - ) + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } + + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) } + return fmt.Sprintf("%s%s", baseURL, suffix) +} - // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ - if c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { diff --git a/client_test.go b/client_test.go index e49da9b3d..a0d3bb390 100644 --- a/client_test.go +++ b/client_test.go @@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + wantPanic string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + defer func() { + if r := recover(); r != nil { + if r.(string) != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + } + } + }() + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) + } + }) + } +} diff --git a/completion.go b/completion.go index bc2a63795..e8e9242c9 100644 --- a/completion.go +++ b/completion.go @@ -213,7 +213,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/edits.go b/edits.go index 97d026029..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index b513ba6a7..74eb8aa57 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(baseReq), + ) if err != nil { return } diff --git a/example_test.go b/example_test.go index de67c57cd..1bdb8496e 100644 --- a/example_test.go +++ b/example_test.go @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } diff --git a/image.go b/image.go index 665de1a74..577d7db95 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/moderation.go b/moderation.go index ae285ef83..c8652efc8 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/speech.go b/speech.go index 19b21bdf1..20b52e334 100644 --- a/speech.go +++ b/speech.go @@ -44,7 +44,10 @@ type CreateSpeechRequest struct { } func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/stream.go b/stream.go index b277f3c29..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + "net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } From d86425a5cfd09bb76fe2f9239a03a9dbcdca8a9c Mon Sep 17 00:00:00 2001 From: Grey Baker Date: Fri, 16 Aug 2024 13:41:39 -0400 Subject: [PATCH 153/242] Allow structured outputs via function calling (#828) --- api_integration_test.go | 76 +++++++++++++++++++++++++++++++++++++++++ chat.go | 1 + chat_test.go | 26 ++++++++++++++ 3 files changed, 103 insertions(+) diff --git a/api_integration_test.go b/api_integration_test.go index 3084268e6..57f7c40fb 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { } } } + +func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + ToolChoice: openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "display_cases", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/chat.go b/chat.go index 826fd3bd5..97c89a497 100644 --- a/chat.go +++ b/chat.go @@ -264,6 +264,7 @@ type ToolFunction struct { type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` // Parameters is an object describing the function. // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. diff --git a/chat_test.go b/chat_test.go index 520bf5ca4..37dc09d4d 100644 --- a/chat_test.go +++ b/chat_test.go @@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) + t.Run("StructuredOutputs", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Strict: true, + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) } func TestAzureChatCompletions(t *testing.T) { From 6d021190f05410a44d9401984815c55f4736b755 Mon Sep 17 00:00:00 2001 From: Yamagami ken-ichi Date: Thu, 22 Aug 2024 23:27:44 +0900 Subject: [PATCH 154/242] feat: Support Delete Message API (#799) * feat: Add DeleteMessage function to API client * fix: linter nolint : Deprecated method split function: cognitive complexity 21 * rename func name for unit-test --- client_test.go | 3 +++ fine_tunes.go | 2 +- messages.go | 24 ++++++++++++++++++++++++ messages_test.go | 36 +++++++++++++++++++++++++++++++----- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/client_test.go b/client_test.go index a0d3bb390..7119d8a7e 100644 --- a/client_test.go +++ b/client_test.go @@ -348,6 +348,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ModifyMessage", func() (any, error) { return client.ModifyMessage(ctx, "", "", nil) }}, + {"DeleteMessage", func() (any, error) { + return client.DeleteMessage(ctx, "", "") + }}, {"RetrieveMessageFile", func() (any, error) { return client.RetrieveMessageFile(ctx, "", "", "") }}, diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..74b47bf3f 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated if err != nil { return } diff --git a/messages.go b/messages.go index 6af118445..1fddd6314 100644 --- a/messages.go +++ b/messages.go @@ -73,6 +73,14 @@ type MessageFilesList struct { httpHeader } +type MessageDeletionStatus struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) @@ -186,3 +194,19 @@ func (c *Client) ListMessageFiles( err = c.sendRequest(req, &files) return } + +// DeleteMessage deletes a message.. +func (c *Client) DeleteMessage( + ctx context.Context, + threadID, messageID string, +) (status MessageDeletionStatus, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &status) + return +} diff --git a/messages_test.go b/messages_test.go index a18be20bd..71ceb4d3a 100644 --- a/messages_test.go +++ b/messages_test.go @@ -8,20 +8,17 @@ import ( "testing" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) var emptyStr = "" -// TestMessages Tests the messages endpoint of the API using the mocked server. -func TestMessages(t *testing.T) { +func setupServerForTestMessage(t *testing.T, server *test.ServerTest) { threadID := "thread_abc123" messageID := "msg_abc123" fileID := "file_abc123" - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler( "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, func(w http.ResponseWriter, r *http.Request) { @@ -115,6 +112,13 @@ func TestMessages(t *testing.T) { Metadata: nil, }) fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + resBytes, _ := json.Marshal(openai.MessageDeletionStatus{ + ID: messageID, + Object: "thread.message.deleted", + Deleted: true, + }) + fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) } @@ -176,7 +180,18 @@ func TestMessages(t *testing.T) { } }, ) +} +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + setupServerForTestMessage(t, server) ctx := context.Background() // static assertion of return type @@ -225,6 +240,17 @@ func TestMessages(t *testing.T) { t.Fatalf("expected message metadata to get modified") } + msgDel, err := client.DeleteMessage(ctx, threadID, messageID) + checks.NoError(t, err, "DeleteMessage error") + if msgDel.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + if !msgDel.Deleted { + t.Fatalf("expected deleted is true") + } + _, err = client.DeleteMessage(ctx, threadID, "not_exist_id") + checks.HasError(t, err, "DeleteMessage error") + // message files var msgFile openai.MessageFile msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) From 5162adbbf90cef77b8462c1f33c81f7d258a1447 Mon Sep 17 00:00:00 2001 From: Alexey Michurin Date: Fri, 23 Aug 2024 13:47:11 +0300 Subject: [PATCH 155/242] Support http client middlewareing (#830) --- config.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 1347567d7..8a9183558 100644 --- a/config.go +++ b/config.go @@ -26,6 +26,10 @@ const AzureAPIKeyHeader = "api-key" const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store +type HTTPDoer interface { + Do(req *http.Request) (*http.Response, error) +} + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -36,7 +40,7 @@ type ClientConfig struct { APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func - HTTPClient *http.Client + HTTPClient HTTPDoer EmptyMessagesLimit uint } From a3bd2569ac51f1c54d704ec80dcbb91ab9f46acf Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Sun, 25 Aug 2024 01:06:08 +0800 Subject: [PATCH 156/242] Improve handling of JSON Schema in OpenAI API Response Context (#819) * feat: add jsonschema.Validate and jsonschema.Unmarshal * fix Sanity check * remove slices.Contains * fix Sanity check * add SchemaWrapper * update api_integration_test.go * update method 'reflectSchema' to support 'omitempty' in JSON tag * add GenerateSchemaForType * update json_test.go * update `Warp` to `Wrap` * fix Sanity check * fix Sanity check * update api_internal_test.go * update README.md * update README.md * remove jsonschema.SchemaWrapper * remove jsonschema.SchemaWrapper * fix Sanity check * optimize code formatting --- README.md | 64 +++++++++++++++++ api_integration_test.go | 36 +++++----- chat.go | 10 ++- example_test.go | 2 +- jsonschema/json.go | 105 +++++++++++++++++++++++++++- jsonschema/validate.go | 89 +++++++++++++++++++++++ jsonschema/validate_test.go | 136 ++++++++++++++++++++++++++++++++++++ 7 files changed, 412 insertions(+), 30 deletions(-) create mode 100644 jsonschema/validate.go create mode 100644 jsonschema/validate_test.go diff --git a/README.md b/README.md index 799dc602b..0d6aafa40 100644 --- a/README.md +++ b/README.md @@ -743,6 +743,70 @@ func main() { } ```
+ +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } `json:"steps"` + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. ## Frequently Asked Questions diff --git a/api_integration_test.go b/api_integration_test.go index 57f7c40fb..8c9f3384f 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,7 +4,6 @@ package openai_test import ( "context" - "encoding/json" "errors" "io" "os" @@ -190,6 +189,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { c := openai.NewClient(apiToken) ctx := context.Background() + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } resp, err := c.CreateChatCompletion( ctx, openai.ChatCompletionRequest{ @@ -212,31 +222,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { ResponseFormat: &openai.ChatCompletionResponseFormat{ Type: openai.ChatCompletionResponseFormatTypeJSONSchema, JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ - Name: "cases", - Schema: jsonschema.Definition{ - Type: jsonschema.Object, - Properties: map[string]jsonschema.Definition{ - "PascalCase": jsonschema.Definition{Type: jsonschema.String}, - "CamelCase": jsonschema.Definition{Type: jsonschema.String}, - "KebabCase": jsonschema.Definition{Type: jsonschema.String}, - "SnakeCase": jsonschema.Definition{Type: jsonschema.String}, - }, - Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, - AdditionalProperties: false, - }, + Name: "cases", + Schema: schema, Strict: true, }, }, }, ) checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") - var result = make(map[string]string) - err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result) - checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") - for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { - if _, ok := result[key]; !ok { - t.Errorf("key:%s does not exist.", key) - } + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") } } diff --git a/chat.go b/chat.go index 97c89a497..56e99a78b 100644 --- a/chat.go +++ b/chat.go @@ -5,8 +5,6 @@ import ( "encoding/json" "errors" "net/http" - - "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct { } type ChatCompletionResponseFormatJSONSchema struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Schema jsonschema.Definition `json:"schema"` - Strict bool `json:"strict"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. diff --git a/example_test.go b/example_test.go index 1bdb8496e..e5dbf44bf 100644 --- a/example_test.go +++ b/example_test.go @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() diff --git a/jsonschema/json.go b/jsonschema/json.go index 7fd1e11bf..bcb253fae 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,13 @@ // and/or pass in the schema in []byte format. package jsonschema -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) type DataType string @@ -42,7 +48,7 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -50,6 +56,99 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + return reflectSchema(reflect.TypeOf(v)) +} + +func reflectSchema(t reflect.Type) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + if jsonTag == "" { + jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..f14ffd4c4 --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,89 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +func Validate(schema Definition, data any) bool { + switch schema.Type { + case Object: + return validateObject(schema, data) + case Array: + return validateArray(schema, data) + case String: + _, ok := data.(string) + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + return false + } +} + +func validateObject(schema Definition, data any) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..c2c47a2ce --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,136 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func Test_Validate(t *testing.T) { + type args struct { + data any + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema jsonschema.Definition + content []byte + v any + } + var result1 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + var result2 struct { + String string `json:"string"` + Number float64 `json:"number"` + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &result1, + }, false}, + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: result2, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil { + t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} From 030b7cb7ed60fc4a8b2fd608f538c470b65b1131 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 24 Aug 2024 18:11:27 +0100 Subject: [PATCH 157/242] fix integration tests (#834) --- api_integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/api_integration_test.go b/api_integration_test.go index 8c9f3384f..7828d9451 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "os" From c37cf9ab5b887fe0195d3cc6240780e9b1928a04 Mon Sep 17 00:00:00 2001 From: Tommy Mathisen Date: Sun, 1 Sep 2024 18:30:29 +0300 Subject: [PATCH 158/242] Dynamic model (#838) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index e8e9242c9..12ce4b558 100644 --- a/completion.go +++ b/completion.go @@ -25,6 +25,7 @@ const ( GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4oLatest = "chatgpt-4o-latest" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" @@ -93,6 +94,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4o: true, GPT4o20240513: true, GPT4o20240806: true, + GPT4oLatest: true, GPT4oMini: true, GPT4oMini20240718: true, GPT4TurboPreview: true, From 643da8d650b1f7db4706076a53b9d0acddccbd17 Mon Sep 17 00:00:00 2001 From: Arun Das <89579096+Arundas666@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:19:57 +0530 Subject: [PATCH 159/242] depricated model GPT3Ada changed to GPT3Babbage002 (#843) * depricated model GPT3Ada changed to GPT3Babbage002 * Delete test.mp3 --- README.md | 4 ++-- example_test.go | 4 ++-- examples/completion/main.go | 2 +- stream_test.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 0d6aafa40..b3ebc1471 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", } @@ -174,7 +174,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/example_test.go b/example_test.go index e5dbf44bf..5910ffb84 100644 --- a/example_test.go +++ b/example_test.go @@ -82,7 +82,7 @@ func ExampleClient_CreateCompletion() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, @@ -99,7 +99,7 @@ func ExampleClient_CreateCompletionStream() { stream, err := client.CreateCompletionStream( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/examples/completion/main.go b/examples/completion/main.go index 22af1fd82..8c5cbd5ca 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -13,7 +13,7 @@ func main() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, diff --git a/stream_test.go b/stream_test.go index 2822a3535..9dd95bb5f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -169,7 +169,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { var apiErr *openai.APIError _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, Prompt: "Hello!", Stream: true, }) From 194a03e763f0d71333a6088bf613a35f65c50447 Mon Sep 17 00:00:00 2001 From: Quest Henkart Date: Wed, 11 Sep 2024 22:24:49 +0200 Subject: [PATCH 160/242] Add refusal (#844) * add custom marshaller, documentation and isolate tests * fix linter * add missing field --- chat.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/chat.go b/chat.go index 56e99a78b..dc60f35b9 100644 --- a/chat.go +++ b/chat.go @@ -82,6 +82,7 @@ type ChatMessagePart struct { type ChatCompletionMessage struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in @@ -107,6 +108,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { msg := struct { Role string `json:"role"` Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"content,omitempty"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` @@ -115,9 +117,11 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { }(m) return json.Marshal(msg) } + msg := struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"-"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` @@ -131,12 +135,14 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { Role string `json:"role"` Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` }{} + if err := json.Unmarshal(bs, &msg); err == nil { *m = ChatCompletionMessage(msg) return nil @@ -144,6 +150,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { multiMsg := struct { Role string `json:"role"` Content string + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"content"` Name string `json:"name,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` From a5fb55321b43aa6b31bb3ff57d43cb5a8f2e17ef Mon Sep 17 00:00:00 2001 From: Aaron Batilo Date: Tue, 17 Sep 2024 14:19:47 -0600 Subject: [PATCH 161/242] Support OpenAI reasoning models (#850) These model strings are now available for use. More info: https://openai.com/index/introducing-openai-o1-preview/ https://platform.openai.com/docs/guides/reasoning --- completion.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/completion.go b/completion.go index 12ce4b558..e1e065a8b 100644 --- a/completion.go +++ b/completion.go @@ -17,6 +17,10 @@ var ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -83,6 +87,10 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, From 1ec8c24ea7ae0e31d5e8332f8a0349d2ecd5b913 Mon Sep 17 00:00:00 2001 From: Wei-An Yen Date: Sat, 21 Sep 2024 02:22:01 +0800 Subject: [PATCH 162/242] fix: jsonschema integer validation (#852) --- jsonschema/validate.go | 4 ++++ jsonschema/validate_test.go | 48 +++++++++++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/jsonschema/validate.go b/jsonschema/validate.go index f14ffd4c4..49f9b8859 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -36,6 +36,10 @@ func Validate(schema Definition, data any) bool { _, ok := data.(bool) return ok case Integer: + // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer + if num, ok := data.(float64); ok { + return num == float64(int64(num)) + } _, ok := data.(int) return ok case Null: diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index c2c47a2ce..6fa30ab0c 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -86,14 +86,6 @@ func TestUnmarshal(t *testing.T) { content []byte v any } - var result1 struct { - String string `json:"string"` - Number float64 `json:"number"` - } - var result2 struct { - String string `json:"string"` - Number float64 `json:"number"` - } tests := []struct { name string args args @@ -108,7 +100,10 @@ func TestUnmarshal(t *testing.T) { }, }, content: []byte(`{"string":"abc","number":123.4}`), - v: &result1, + v: &struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, }, false}, {"", args{ schema: jsonschema.Definition{ @@ -120,7 +115,40 @@ func TestUnmarshal(t *testing.T) { Required: []string{"string", "number"}, }, content: []byte(`{"string":"abc"}`), - v: result2, + v: struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, true}, + {"validate integer", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, false}, + {"validate integer failed", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123.4}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, }, true}, } for _, tt := range tests { From 9add1c348607c14e8fde9966713c97f9a2351919 Mon Sep 17 00:00:00 2001 From: Ivan Timofeev Date: Fri, 20 Sep 2024 23:40:24 +0300 Subject: [PATCH 163/242] add max_completions_tokens for o1 series models (#857) * add max_completions_tokens for o1 series models * add validation for o1 series models validataion + beta limitations --- chat.go | 35 +++++--- chat_stream.go | 4 + chat_stream_test.go | 21 +++++ chat_test.go | 211 ++++++++++++++++++++++++++++++++++++++++++++ completion.go | 82 +++++++++++++++++ 5 files changed, 341 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index dc60f35b9..d47c95e4f 100644 --- a/chat.go +++ b/chat.go @@ -200,18 +200,25 @@ type ChatCompletionResponseFormatJSONSchema struct { // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens The maximum number of tokens that can be generated in the chat completion. + // This value can be used to control costs for text generated via API. + // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens + MaxTokens int `json:"max_tokens,omitempty"` + // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, + // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias @@ -364,6 +371,10 @@ func (c *Client) CreateChatCompletion( return } + if err = validateRequestForO1Models(request); err != nil { + return + } + req, err := c.newRequest( ctx, http.MethodPost, diff --git a/chat_stream.go b/chat_stream.go index 3f90bc019..f43d01834 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,6 +60,10 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true + if err = validateRequestForO1Models(request); err != nil { + return + } + req, err := c.newRequest( ctx, http.MethodPost, diff --git a/chat_stream_test.go b/chat_stream_test.go index 63e45ee23..2e7c99b45 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -36,6 +36,27 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } } +func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1/chat/completions" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + req := openai.ChatCompletionRequest{ + Model: openai.O1Preview, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + } + _, err := client.CreateChatCompletionStream(ctx, req) + if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) { + t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err) + } +} + func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 37dc09d4d..a54dd35e0 100644 --- a/chat_test.go +++ b/chat_test.go @@ -52,6 +52,199 @@ func TestChatCompletionsWrongModel(t *testing.T) { checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } +func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "o1-preview_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Preview, + }, + expectedError: openai.ErrO1MaxTokensDeprecated, + }, + { + name: "o1-mini_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Mini, + }, + expectedError: openai.ErrO1MaxTokensDeprecated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, + }, + expectedError: openai.ErrO1BetaLimitationsLogprobs, + }, + { + name: "message_type_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + }, + }, + }, + expectedError: openai.ErrO1BetaLimitationsMessageTypes, + }, + { + name: "tool_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + }, + }, + }, + expectedError: openai.ErrO1BetaLimitationsTools, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionsTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrO1BetaLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense @@ -97,6 +290,24 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestO1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O1Preview, + MaxCompletionsTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() diff --git a/completion.go b/completion.go index e1e065a8b..8e3172ace 100644 --- a/completion.go +++ b/completion.go @@ -7,11 +7,20 @@ import ( ) var ( + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionsTokens") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll ) +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsStreaming = errors.New("this model has beta-limitations, streaming not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific @@ -85,6 +94,15 @@ const ( CodexCodeDavinci001 = "code-davinci-001" ) +// O1SeriesModels List of new Series of OpenAI models. +// Some old api attributes not supported. +var O1SeriesModels = map[string]struct{}{ + O1Mini: {}, + O1Mini20240912: {}, + O1Preview: {}, + O1Preview20240912: {}, +} + var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { O1Mini: true, @@ -146,6 +164,70 @@ func checkPromptType(prompt any) bool { return isString || isStringSlice } +var unsupportedToolsForO1Models = map[ToolType]struct{}{ + ToolTypeFunction: {}, +} + +var availableMessageRoleForO1Models = map[string]struct{}{ + ChatMessageRoleUser: {}, + ChatMessageRoleAssistant: {}, +} + +// validateRequestForO1Models checks for deprecated fields of OpenAI models. +func validateRequestForO1Models(request ChatCompletionRequest) error { + if _, found := O1SeriesModels[request.Model]; !found { + return nil + } + + if request.MaxTokens > 0 { + return ErrO1MaxTokensDeprecated + } + + // Beta Limitations + // refs:https://platform.openai.com/docs/guides/reasoning/beta-limitations + // Streaming: not supported + if request.Stream { + return ErrO1BetaLimitationsStreaming + } + // Logprobs: not supported. + if request.LogProbs { + return ErrO1BetaLimitationsLogprobs + } + + // Message types: user and assistant messages only, system messages are not supported. + for _, m := range request.Messages { + if _, found := availableMessageRoleForO1Models[m.Role]; !found { + return ErrO1BetaLimitationsMessageTypes + } + } + + // Tools: tools, function calling, and response format parameters are not supported + for _, t := range request.Tools { + if _, found := unsupportedToolsForO1Models[t.Type]; found { + return ErrO1BetaLimitationsTools + } + } + + // Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0. + if request.Temperature > 0 && request.Temperature != 1 { + return ErrO1BetaLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrO1BetaLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrO1BetaLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrO1BetaLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrO1BetaLimitationsOther + } + + return nil +} + // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` From 9a4f3a7dbf8f29408848c94cf933d1530ae64526 Mon Sep 17 00:00:00 2001 From: Jialin Tian Date: Sat, 21 Sep 2024 04:49:28 +0800 Subject: [PATCH 164/242] feat: add ParallelToolCalls to RunRequest (#847) --- run.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/run.go b/run.go index 5598f1dfb..0cdec2bdc 100644 --- a/run.go +++ b/run.go @@ -37,6 +37,8 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } From e095df5325a39ed94940dbe3882d2aa14eb64ad0 Mon Sep 17 00:00:00 2001 From: floodwm Date: Fri, 20 Sep 2024 23:54:25 +0300 Subject: [PATCH 165/242] run_id string Optional (#855) Filter messages by the run ID that generated them. Co-authored-by: wappi --- .zshrc | 0 client_test.go | 2 +- messages.go | 5 +++++ messages_test.go | 5 +++-- 4 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 .zshrc diff --git a/.zshrc b/.zshrc new file mode 100644 index 000000000..e69de29bb diff --git a/client_test.go b/client_test.go index 7119d8a7e..3f27b9dd7 100644 --- a/client_test.go +++ b/client_test.go @@ -340,7 +340,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { return client.CreateMessage(ctx, "", MessageRequest{}) }}, {"ListMessage", func() (any, error) { - return client.ListMessage(ctx, "", nil, nil, nil, nil) + return client.ListMessage(ctx, "", nil, nil, nil, nil, nil) }}, {"RetrieveMessage", func() (any, error) { return client.RetrieveMessage(ctx, "", "") diff --git a/messages.go b/messages.go index 1fddd6314..eefc29a36 100644 --- a/messages.go +++ b/messages.go @@ -100,6 +100,7 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, order *string, after *string, before *string, + runID *string, ) (messages MessagesList, err error) { urlValues := url.Values{} if limit != nil { @@ -114,6 +115,10 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, if before != nil { urlValues.Add("before", *before) } + if runID != nil { + urlValues.Add("run_id", *runID) + } + encodedValues := "" if len(urlValues) > 0 { encodedValues = "?" + urlValues.Encode() diff --git a/messages_test.go b/messages_test.go index 71ceb4d3a..b25755f98 100644 --- a/messages_test.go +++ b/messages_test.go @@ -208,7 +208,7 @@ func TestMessages(t *testing.T) { } var msgs openai.MessagesList - msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") @@ -219,7 +219,8 @@ func TestMessages(t *testing.T) { order := "desc" after := "obj_foo" before := "obj_bar" - msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + runID := "run_abc123" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") From 38bdc812df391bcec3d7defda2a456ea00bb54e5 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 26 Sep 2024 18:25:56 +0800 Subject: [PATCH 166/242] Optimize Client Error Return (#856) * update client error return * update client_test.go * update client_test.go * update file_api_test.go * update client_test.go * update client_test.go --- client.go | 9 ++++++ client_test.go | 76 +++++++++++++++++++++++++++++++++-------------- error.go | 6 ++-- files_api_test.go | 1 + 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/client.go b/client.go index 9f547e7cb..583244fe1 100644 --- a/client.go +++ b/client.go @@ -285,10 +285,18 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } + return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) + } var errRes ErrorResponse err := json.NewDecoder(resp.Body).Decode(&errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, } @@ -298,6 +306,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } diff --git a/client_test.go b/client_test.go index 3f27b9dd7..18da787a0 100644 --- a/client_test.go +++ b/client_test.go @@ -134,14 +134,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -152,11 +155,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. ....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -165,11 +169,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -179,22 +184,53 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", }, { - name: "503 no message (Unknown response)", - httpCode: http.StatusServiceUnavailable, + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{} }`)), - expected: "error, status code: 503, message: ", + expected: "error, status code: 503, status: , message: ", + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` +413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ +`)), + expected: `error, status code: 413, status: , body: +413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ +`, + }, + { + name: "errorReader", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: &errorReader{err: errors.New("errorReader")}, + expected: "error, reading response body: errorReader", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testCase := &http.Response{} + testCase := &http.Response{ + Header: map[string][]string{ + "Content-Type": {tc.contentType}, + }, + } testCase.StatusCode = tc.httpCode testCase.Body = io.NopCloser(tc.body) err := client.handleErrorResp(testCase) @@ -203,12 +239,6 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } - - e := &APIError{} - if !errors.As(err, &e) { - t.Errorf("(%s) Expected error to be of type APIError", tc.name) - t.Fail() - } }) } } diff --git a/error.go b/error.go index 37959a272..1f6a8971d 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ type APIError struct { Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` + HTTPStatus string `json:"-"` HTTPStatusCode int `json:"-"` InnerError *InnerError `json:"innererror,omitempty"` } @@ -25,6 +26,7 @@ type InnerError struct { // RequestError provides information about generic request errors. type RequestError struct { + HTTPStatus string HTTPStatusCode int Err error } @@ -35,7 +37,7 @@ type ErrorResponse struct { func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) } return e.Message @@ -101,7 +103,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err) } func (e *RequestError) Unwrap() error { diff --git a/files_api_test.go b/files_api_test.go index c92162a84..aa4fda458 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) From 7f80303cc393edf2f6806ca37668346f8fa6247e Mon Sep 17 00:00:00 2001 From: Alex Philipp Date: Thu, 26 Sep 2024 05:26:22 -0500 Subject: [PATCH 167/242] Fix max_completion_tokens (#860) The json tag is incorrect, and results in an error from the API when using the o1 model. I didn't modify the struct field name to maintain compatibility if anyone else had started using it, but it wouldn't work for them either. --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index d47c95e4f..dd99c530e 100644 --- a/chat.go +++ b/chat.go @@ -209,7 +209,7 @@ type ChatCompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning - MaxCompletionsTokens int `json:"max_completions_tokens,omitempty"` + MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"` Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` N int `json:"n,omitempty"` From e9d8485e90092b8adcce82fdd0dcd7cf10327e8d Mon Sep 17 00:00:00 2001 From: Jialin Tian Date: Thu, 26 Sep 2024 18:26:54 +0800 Subject: [PATCH 168/242] fix: ParallelToolCalls should be added to RunRequest (#861) --- run.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/run.go b/run.go index 0cdec2bdc..d3e755f05 100644 --- a/run.go +++ b/run.go @@ -37,8 +37,6 @@ type Run struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` - // Disable the default behavior of parallel tool calls by setting it: false. - ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` httpHeader } @@ -112,6 +110,8 @@ type RunRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // This can be either a string or a ResponseFormat object. ResponseFormat any `json:"response_format,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. From fdd59d93413154cd07b2e46a428b15eda40b26e2 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Thu, 26 Sep 2024 18:30:56 +0800 Subject: [PATCH 169/242] feat: usage struct add CompletionTokensDetails (#863) --- common.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/common.go b/common.go index cbfda4e3c..cde14154a 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,13 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` +} + +// CompletionTokensDetails Breakdown of tokens used in a completion. +type CompletionTokensDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` } From bac7d5936108965a9666a65d0d4d55bd0fe78808 Mon Sep 17 00:00:00 2001 From: Winston Liu Date: Thu, 3 Oct 2024 12:17:16 -0700 Subject: [PATCH 170/242] fix MaxCompletionTokens typo (#862) * fix spelling error * fix lint * Update chat.go * Update chat.go --- chat.go | 22 +++++++++++----------- chat_test.go | 38 +++++++++++++++++++------------------- completion.go | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/chat.go b/chat.go index dd99c530e..9adf2808d 100644 --- a/chat.go +++ b/chat.go @@ -207,18 +207,18 @@ type ChatCompletionRequest struct { // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens MaxTokens int `json:"max_tokens,omitempty"` - // MaxCompletionsTokens An upper bound for the number of tokens that can be generated for a completion, + // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning - MaxCompletionsTokens int `json:"max_completion_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias diff --git a/chat_test.go b/chat_test.go index a54dd35e0..134026cdb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -100,17 +100,17 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "log_probs_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - LogProbs: true, - Model: openai.O1Preview, + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, }, expectedError: openai.ErrO1BetaLimitationsLogprobs, }, { name: "message_type_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleSystem, @@ -122,8 +122,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "tool_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -143,8 +143,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_temperature_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -160,8 +160,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_top_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -178,8 +178,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_n_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -197,8 +197,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_presence_penalty_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -214,8 +214,8 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { { name: "set_frequency_penalty_unsupported", in: openai.ChatCompletionRequest{ - MaxCompletionsTokens: 1000, - Model: openai.O1Mini, + MaxCompletionTokens: 1000, + Model: openai.O1Mini, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -296,8 +296,8 @@ func TestO1ModelChatCompletions(t *testing.T) { defer teardown() server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ - Model: openai.O1Preview, - MaxCompletionsTokens: 1000, + Model: openai.O1Preview, + MaxCompletionTokens: 1000, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, diff --git a/completion.go b/completion.go index 8e3172ace..80c4d39ae 100644 --- a/completion.go +++ b/completion.go @@ -7,7 +7,7 @@ import ( ) var ( - ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionsTokens") //nolint:lll + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll From 7c145ebb4be68610bc3bb5377b754944307d44fd Mon Sep 17 00:00:00 2001 From: Julio Martins <89476495+juliomartinsdev@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:19:48 -0300 Subject: [PATCH 171/242] add jailbreak filter result, add ContentFilterResults on output (#864) * add jailbreak filter result * add content filter results on completion output * add profanity content filter --- chat.go | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/chat.go b/chat.go index 9adf2808d..a7dee8e03 100644 --- a/chat.go +++ b/chat.go @@ -41,11 +41,23 @@ type Violence struct { Severity string `json:"severity,omitempty"` } +type JailBreak struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + +type Profanity struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + type ContentFilterResults struct { - Hate Hate `json:"hate,omitempty"` - SelfHarm SelfHarm `json:"self_harm,omitempty"` - Sexual Sexual `json:"sexual,omitempty"` - Violence Violence `json:"violence,omitempty"` + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` + JailBreak JailBreak `json:"jailbreak,omitempty"` + Profanity Profanity `json:"profanity,omitempty"` } type PromptAnnotation struct { @@ -338,19 +350,21 @@ type ChatCompletionChoice struct { // function_call: The model decided to call a function // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete - FinishReason FinishReason `json:"finish_reason"` - LogProbs *LogProbs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` - SystemFingerprint string `json:"system_fingerprint"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` httpHeader } From 991326480f84981b6e89032b9f9710a3a83a6f0f Mon Sep 17 00:00:00 2001 From: Isaac Seymour Date: Wed, 9 Oct 2024 10:50:27 +0100 Subject: [PATCH 172/242] Completion API: add new params (#870) * Completion API: add 'store' param This param allows you to opt a completion request in to being stored, for use in distillations and evals. * Add cached and audio tokens to usage structs These have been added to the completions API recently: https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage --- common.go | 8 ++++++++ completion.go | 27 +++++++++++++++------------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/common.go b/common.go index cde14154a..8cc7289c0 100644 --- a/common.go +++ b/common.go @@ -7,10 +7,18 @@ type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` } // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { + AudioTokens int `json:"audio_tokens"` ReasoningTokens int `json:"reasoning_tokens"` } + +// PromptTokensDetails Breakdown of tokens used in the prompt. +type PromptTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + CachedTokens int `json:"cached_tokens"` +} diff --git a/completion.go b/completion.go index 80c4d39ae..afcf84671 100644 --- a/completion.go +++ b/completion.go @@ -238,18 +238,21 @@ type CompletionRequest struct { // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias - LogitBias map[string]int `json:"logit_bias,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - Seed *int `json:"seed,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - Suffix string `json:"suffix,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - User string `json:"user,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. From cfe15ffd00bb908c32cf0d9e277786a14afdd2c7 Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Mon, 14 Oct 2024 18:50:39 +0530 Subject: [PATCH 173/242] return response body as byte slice for RequestError type (#873) --- client.go | 11 ++++++----- error.go | 1 + 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 583244fe1..1e228a097 100644 --- a/client.go +++ b/client.go @@ -285,20 +285,21 @@ func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newB } func (c *Client) handleErrorResp(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("error, reading response body: %w", err) - } return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) } var errRes ErrorResponse - err := json.NewDecoder(resp.Body).Decode(&errRes) + err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, + Body: body, } if errRes.Error != nil { reqErr.Err = errRes.Error diff --git a/error.go b/error.go index 1f6a8971d..fc9e7cdb9 100644 --- a/error.go +++ b/error.go @@ -29,6 +29,7 @@ type RequestError struct { HTTPStatus string HTTPStatusCode int Err error + Body []byte } type ErrorResponse struct { From 21f713457449b1ab386529b9495cbf1f27c0db5a Mon Sep 17 00:00:00 2001 From: Matt Jacobs Date: Mon, 14 Oct 2024 09:21:39 -0400 Subject: [PATCH 174/242] Adding new moderation model constants (#875) --- moderation.go | 12 ++++++++---- moderation_test.go | 2 ++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/moderation.go b/moderation.go index c8652efc8..a0e09c0ee 100644 --- a/moderation.go +++ b/moderation.go @@ -14,8 +14,10 @@ import ( // If you use text-moderation-stable, we will provide advanced notice before updating the model. // Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest. const ( - ModerationTextStable = "text-moderation-stable" - ModerationTextLatest = "text-moderation-latest" + ModerationOmniLatest = "omni-moderation-latest" + ModerationOmni20240926 = "omni-moderation-2024-09-26" + ModerationTextStable = "text-moderation-stable" + ModerationTextLatest = "text-moderation-latest" // Deprecated: use ModerationTextStable and ModerationTextLatest instead. ModerationText001 = "text-moderation-001" ) @@ -25,8 +27,10 @@ var ( ) var validModerationModel = map[string]struct{}{ - ModerationTextStable: {}, - ModerationTextLatest: {}, + ModerationOmniLatest: {}, + ModerationOmni20240926: {}, + ModerationTextStable: {}, + ModerationTextLatest: {}, } // ModerationRequest represents a request structure for moderation API. diff --git a/moderation_test.go b/moderation_test.go index 61171c384..a97f25bc6 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -37,6 +37,8 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), getModerationModelTestOption(openai.ModerationTextStable, nil), getModerationModelTestOption(openai.ModerationTextLatest, nil), + getModerationModelTestOption(openai.ModerationOmni20240926, nil), + getModerationModelTestOption(openai.ModerationOmniLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() From b162541513db0cf3d4d48da03be22b05861269cb Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 15 Oct 2024 20:09:34 +0100 Subject: [PATCH 175/242] Cleanup (#879) * remove obsolete files * update readme --- .zshrc | 0 Makefile | 35 ----------------------------------- README.md | 2 +- 3 files changed, 1 insertion(+), 36 deletions(-) delete mode 100644 .zshrc delete mode 100644 Makefile diff --git a/.zshrc b/.zshrc deleted file mode 100644 index e69de29bb..000000000 diff --git a/Makefile b/Makefile deleted file mode 100644 index 2e608aa0c..000000000 --- a/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -##@ General - -# The help target prints out all targets with their descriptions organized -# beneath their categories. The categories are represented by '##@' and the -# target descriptions by '##'. The awk commands is responsible for reading the -# entire set of makefiles included in this invocation, looking for lines of the -# file as xyz: ## something, and then pretty-format the target and help. Then, -# if there's a line with ##@ something, that gets pretty-printed as a category. -# More info on the usage of ANSI control characters for terminal formatting: -# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters -# More info on the awk command: -# http://linuxcommand.org/lc3_adv_awk.php - -.PHONY: help -help: ## Display this help. - @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) - - -##@ Development - -.PHONY: test -TEST_ARGS ?= -v -TEST_TARGETS ?= ./... -test: ## Test the Go modules within this package. - @ echo ▶️ go test $(TEST_ARGS) $(TEST_TARGETS) - go test $(TEST_ARGS) $(TEST_TARGETS) - @ echo ✅ success! - - -.PHONY: lint -LINT_TARGETS ?= ./... -lint: ## Lint Go code with the installed golangci-lint - @ echo "▶️ golangci-lint run" - golangci-lint run $(LINT_TARGETS) - @ echo "✅ golangci-lint run" diff --git a/README.md b/README.md index b3ebc1471..57d1d35bf 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: -* ChatGPT +* ChatGPT 4o, o1 * GPT-3, GPT-4 * DALL·E 2, DALL·E 3 * Whisper From 9fe2c6ce1f5b756cd172ae9a7786beea69b2956f Mon Sep 17 00:00:00 2001 From: Sander Mack-Crane <71154168+smackcrane@users.noreply.github.com> Date: Tue, 15 Oct 2024 14:16:57 -0600 Subject: [PATCH 176/242] Completion API: add Store and Metadata parameters (#878) --- chat.go | 5 +++++ completion.go | 26 ++++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/chat.go b/chat.go index a7dee8e03..2b13f8dd7 100644 --- a/chat.go +++ b/chat.go @@ -255,6 +255,11 @@ type ChatCompletionRequest struct { StreamOptions *StreamOptions `json:"stream_options,omitempty"` // Disable the default behavior of parallel tool calls by setting it: false. ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` } type StreamOptions struct { diff --git a/completion.go b/completion.go index afcf84671..84ef2ad26 100644 --- a/completion.go +++ b/completion.go @@ -241,18 +241,20 @@ type CompletionRequest struct { LogitBias map[string]int `json:"logit_bias,omitempty"` // Store can be set to true to store the output of this completion request for use in distillations and evals. // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store - Store bool `json:"store,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - Seed *int `json:"seed,omitempty"` - Stop []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - Suffix string `json:"suffix,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - User string `json:"user,omitempty"` + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. From fb15ff9dcd861e601fc2c54078aac2bbd3c06ce8 Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Tue, 22 Oct 2024 02:19:34 +0530 Subject: [PATCH 177/242] Handling for non-json response (#881) * removed handling for non-json response * added response body in RequestError.Error() and updated tests * done linting --- client.go | 3 --- client_test.go | 35 ++++++++++++++++++++--------------- error.go | 5 ++++- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/client.go b/client.go index 1e228a097..ed8595e0b 100644 --- a/client.go +++ b/client.go @@ -289,9 +289,6 @@ func (c *Client) handleErrorResp(resp *http.Response) error { if err != nil { return fmt.Errorf("error, reading response body: %w", err) } - if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") { - return fmt.Errorf("error, status code: %d, status: %s, body: %s", resp.StatusCode, resp.Status, body) - } var errRes ErrorResponse err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { diff --git a/client_test.go b/client_test.go index 18da787a0..354a6b3f5 100644 --- a/client_test.go +++ b/client_test.go @@ -194,26 +194,31 @@ func TestHandleErrorResp(t *testing.T) { { "error":{} }`)), - expected: "error, status code: 503, status: , message: ", + expected: `error, status code: 503, status: , message: , body: + { + "error":{} + }`, }, { name: "413 Request Entity Too Large", httpCode: http.StatusRequestEntityTooLarge, contentType: "text/html", - body: bytes.NewReader([]byte(` -413 Request Entity Too Large - -

413 Request Entity Too Large

-
nginx
- -`)), - expected: `error, status code: 413, status: , body: -413 Request Entity Too Large - -

413 Request Entity Too Large

-
nginx
- -`, + body: bytes.NewReader([]byte(` + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `)), + expected: `error, status code: 413, status: , message: invalid character '<' looking for beginning of value, body: + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `, }, { name: "errorReader", diff --git a/error.go b/error.go index fc9e7cdb9..8a74bd52c 100644 --- a/error.go +++ b/error.go @@ -104,7 +104,10 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Err) + return fmt.Sprintf( + "error, status code: %d, status: %s, message: %s, body: %s", + e.HTTPStatusCode, e.HTTPStatus, e.Err, e.Body, + ) } func (e *RequestError) Unwrap() error { From 3672c0dec601f89037d8d54e7df653d7df1f0c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edin=20=C4=86orali=C4=87?= <73831203+ecoralic@users.noreply.github.com> Date: Mon, 21 Oct 2024 22:57:02 +0200 Subject: [PATCH 178/242] fix: Updated Assistent struct with latest fields based on OpenAI docs (#883) --- assistant.go | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/assistant.go b/assistant.go index 4c89c1b2f..8aab5bcf0 100644 --- a/assistant.go +++ b/assistant.go @@ -14,17 +14,20 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` // Deprecated in v2 + Metadata map[string]any `json:"metadata,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` httpHeader } From 6e087322b77693e6e9227d9950a0c8d8a10a8d1a Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Fri, 25 Oct 2024 19:11:45 +0530 Subject: [PATCH 179/242] Updated checkPromptType function to handle prompt list in completions (#885) * updated checkPromptType function to handle prompt list in completions * removed generated test file * added corresponding unit testcases * Updated to use less nesting with early returns --- completion.go | 18 ++++++++++- completion_test.go | 78 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/completion.go b/completion.go index 84ef2ad26..77ea8c3ab 100644 --- a/completion.go +++ b/completion.go @@ -161,7 +161,23 @@ func checkEndpointSupportsModel(endpoint, model string) bool { func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) - return isString || isStringSlice + if isString || isStringSlice { + return true + } + + // check if it is prompt is []string hidden under []any + slice, isSlice := prompt.([]any) + if !isSlice { + return false + } + + for _, item := range slice { + _, itemIsString := item.(string) + if !itemIsString { + return false + } + } + return true // all items in the slice are string, so it is []string } var unsupportedToolsForO1Models = map[ToolType]struct{}{ diff --git a/completion_test.go b/completion_test.go index 89950bf94..935bbe864 100644 --- a/completion_test.go +++ b/completion_test.go @@ -59,6 +59,38 @@ func TestCompletions(t *testing.T) { checks.NoError(t, err, "CreateCompletion error") } +// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts with wrong type. +func TestMultiplePromptsCompletionsWrong(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", 9}, + } + _, err := client.CreateCompletion(context.Background(), req) + if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err) + } +} + +// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts. +func TestMultiplePromptsCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"}, + } + _, err := client.CreateCompletion(context.Background(), req) + checks.NoError(t, err, "CreateCompletion error") +} + // handleCompletionEndpoint Handles the completion endpoint by the test server. func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -87,24 +119,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if n == 0 { n = 1 } + // Handle different types of prompts: single string or list of strings + prompts := []string{} + switch v := completionReq.Prompt.(type) { + case string: + prompts = append(prompts, v) + case []interface{}: + for _, item := range v { + if str, ok := item.(string); ok { + prompts = append(prompts, str) + } + } + default: + http.Error(w, "Invalid prompt type", http.StatusBadRequest) + return + } + for i := 0; i < n; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt.(string) + completionStr + for _, prompt := range prompts { + // Generate a random string of length completionReq.MaxTokens + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = prompt + completionStr + } + + res.Choices = append(res.Choices, openai.CompletionChoice{ + Text: completionStr, + Index: len(res.Choices), + }) } - res.Choices = append(res.Choices, openai.CompletionChoice{ - Text: completionStr, - Index: i, - }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * n - completionTokens := completionReq.MaxTokens * n + + inputTokens := 0 + for _, prompt := range prompts { + inputTokens += numTokens(prompt) + } + inputTokens *= n + completionTokens := completionReq.MaxTokens * len(prompts) * n res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, } + + // Serialize the response and send it back resBytes, _ = json.Marshal(res) fmt.Fprintln(w, string(resBytes)) } From d10f1b81995ddce1aacacfa671d79f2784a68ef4 Mon Sep 17 00:00:00 2001 From: genglixia <62233468+Yu0u@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:22:52 +0800 Subject: [PATCH 180/242] add chatcompletion stream delta refusal and logprobs (#882) * add chatcompletion stream refusal and logprobs * fix slice to struct * add integration test * fix lint * fix lint * fix: the object should be pointer --------- Co-authored-by: genglixia --- chat_stream.go | 28 ++++- chat_stream_test.go | 265 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 289 insertions(+), 4 deletions(-) diff --git a/chat_stream.go b/chat_stream.go index f43d01834..58b2651c0 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -10,13 +10,33 @@ type ChatCompletionStreamChoiceDelta struct { Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` +} + +type ChatCompletionStreamChoiceLogprobs struct { + Content []ChatCompletionTokenLogprob `json:"content,omitempty"` + Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` +} + +type ChatCompletionTokenLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes,omitempty"` + Logprob float64 `json:"logprob,omitempty"` + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` +} + +type ChatCompletionTokenLogprobTopLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes"` + Logprob float64 `json:"logprob"` } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` - ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type PromptFilterResult struct { diff --git a/chat_stream_test.go b/chat_stream_test.go index 2e7c99b45..14684146c 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -358,6 +358,271 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithRefusal(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":" World"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT4oMini20240718, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: " World", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":{"content":[{"token":"Hello","logprob":-0.000020458236,"bytes":[72,101,108,108,111],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":" World"},"logprobs":{"content":[{"token":" World","logprob":-0.00055303273,"bytes":[32,87,111,114,108,100],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{}, + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: "Hello", + Logprob: -0.000020458236, + Bytes: []int64{72, 101, 108, 108, 111}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " World", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: " World", + Logprob: -0.00055303273, + Bytes: []int64{32, 87, 111, 114, 108, 100}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { wantCode := "429" wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + From f5e6e0e4fed1284bafa4805f6487e5b5f8a4ccd1 Mon Sep 17 00:00:00 2001 From: Matt Davis Date: Fri, 8 Nov 2024 08:53:02 -0500 Subject: [PATCH 181/242] Added Vector Store File List properties that allow for pagination (#891) --- vector_store.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vector_store.go b/vector_store.go index 5c364362a..682bb1cf9 100644 --- a/vector_store.go +++ b/vector_store.go @@ -83,6 +83,9 @@ type VectorStoreFileRequest struct { type VectorStoreFilesList struct { VectorStoreFiles []VectorStoreFile `json:"data"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` httpHeader } From 6d066bb12dfbaa3cefa83f204c431fb0d0ef02fa Mon Sep 17 00:00:00 2001 From: Denny Depok <61371551+kodernubie@users.noreply.github.com> Date: Fri, 8 Nov 2024 20:54:27 +0700 Subject: [PATCH 182/242] Support Attachments in MessageRequest (#890) * add attachments in MessageRequest * Move tools const to message * remove const, just use assistanttool const --- messages.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/messages.go b/messages.go index eefc29a36..902363938 100644 --- a/messages.go +++ b/messages.go @@ -52,10 +52,11 @@ type ImageFile struct { } type MessageRequest struct { - Role string `json:"role"` - Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility - Metadata map[string]any `json:"metadata,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` } type MessageFile struct { From b3ece4d32e9416105bc2427b735448e82abd448b Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Wed, 20 Nov 2024 02:07:10 +0530 Subject: [PATCH 183/242] Updated client_test to solve lint error (#900) * updated client_test to solve lint error * modified golangci yml to solve linter issues * minor change --- .golangci.yml | 6 +++--- client_test.go | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 58fab4a20..724cb7375 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -57,7 +57,7 @@ linters-settings: # Default: true skipRecvDeref: false - gomnd: + mnd: # List of function patterns to exclude from analysis. # Values always ignored: `time.Date` # Default: [] @@ -167,7 +167,7 @@ linters: - durationcheck # check for two durations multiplied together - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds + # Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers @@ -180,7 +180,6 @@ linters: - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomnd # An analyzer to detect magic numbers. - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with f at the end @@ -188,6 +187,7 @@ linters: - lll # Reports long lines - makezero # Finds slice declarations with non-zero initial length # - nakedret # Finds naked returns in functions greater than a specified function length + - mnd # An analyzer to detect magic numbers. - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. diff --git a/client_test.go b/client_test.go index 354a6b3f5..2ed82f13c 100644 --- a/client_test.go +++ b/client_test.go @@ -513,8 +513,14 @@ func TestClient_suffixWithAPIVersion(t *testing.T) { } defer func() { if r := recover(); r != nil { - if r.(string) != tt.wantPanic { - t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic) + // Check if the panic message matches the expected panic message + if rStr, ok := r.(string); ok { + if rStr != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", rStr, tt.wantPanic) + } + } else { + // If the panic is not a string, log it + t.Errorf("suffixWithAPIVersion() panicked with non-string value: %v", r) } } }() From 168761616567a1cf2645c98f6f19329877f0beaa Mon Sep 17 00:00:00 2001 From: LinYushen Date: Thu, 21 Nov 2024 04:26:10 +0800 Subject: [PATCH 184/242] o1 model support stream (#904) --- chat_stream_test.go | 21 --------------------- completion.go | 7 ------- 2 files changed, 28 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 14684146c..28a9acf67 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -36,27 +36,6 @@ func TestChatCompletionsStreamWrongModel(t *testing.T) { } } -func TestChatCompletionsStreamWithO1BetaLimitations(t *testing.T) { - config := openai.DefaultConfig("whatever") - config.BaseURL = "http://localhost/v1/chat/completions" - client := openai.NewClientWithConfig(config) - ctx := context.Background() - - req := openai.ChatCompletionRequest{ - Model: openai.O1Preview, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - } - _, err := client.CreateChatCompletionStream(ctx, req) - if !errors.Is(err, openai.ErrO1BetaLimitationsStreaming) { - t.Fatalf("CreateChatCompletion should return ErrO1BetaLimitationsStreaming, but returned: %v", err) - } -} - func TestCreateChatCompletionStream(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/completion.go b/completion.go index 77ea8c3ab..9e3073694 100644 --- a/completion.go +++ b/completion.go @@ -15,7 +15,6 @@ var ( var ( ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll - ErrO1BetaLimitationsStreaming = errors.New("this model has beta-limitations, streaming not supported") //nolint:lll ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll @@ -199,12 +198,6 @@ func validateRequestForO1Models(request ChatCompletionRequest) error { return ErrO1MaxTokensDeprecated } - // Beta Limitations - // refs:https://platform.openai.com/docs/guides/reasoning/beta-limitations - // Streaming: not supported - if request.Stream { - return ErrO1BetaLimitationsStreaming - } // Logprobs: not supported. if request.LogProbs { return ErrO1BetaLimitationsLogprobs From 74ed75f291f8f55d1104a541090d46c021169115 Mon Sep 17 00:00:00 2001 From: nagar-ajay Date: Thu, 21 Nov 2024 02:09:44 +0530 Subject: [PATCH 185/242] Make user field optional in embedding request (#899) * make user optional in embedding request * fix unit test --- batch_test.go | 2 +- embeddings.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/batch_test.go b/batch_test.go index 4b2261e0e..f4714f4eb 100644 --- a/batch_test.go +++ b/batch_test.go @@ -211,7 +211,7 @@ func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { Input: []string{"Hello", "World"}, }, }, - }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/embeddings.go b/embeddings.go index 74eb8aa57..4a0e682da 100644 --- a/embeddings.go +++ b/embeddings.go @@ -155,7 +155,7 @@ const ( type EmbeddingRequest struct { Input any `json:"input"` Model EmbeddingModel `json:"model"` - User string `json:"user"` + User string `json:"user,omitempty"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. From 21fa42c18dbafef43977ab73c403eef6d694b14a Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Sat, 30 Nov 2024 17:39:47 +0800 Subject: [PATCH 186/242] feat: add gpt-4o-2024-11-20 model (#905) --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 9e3073694..f11566081 100644 --- a/completion.go +++ b/completion.go @@ -37,6 +37,7 @@ const ( GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4o20241120 = "gpt-4o-2024-11-20" GPT4oLatest = "chatgpt-4o-latest" GPT4oMini = "gpt-4o-mini" GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" @@ -119,6 +120,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4o: true, GPT4o20240513: true, GPT4o20240806: true, + GPT4o20241120: true, GPT4oLatest: true, GPT4oMini: true, GPT4oMini20240718: true, From c203ca001fecd40210cfcf9923ab69235c92e321 Mon Sep 17 00:00:00 2001 From: Qiying Wang <781345688@qq.com> Date: Sat, 30 Nov 2024 18:29:05 +0800 Subject: [PATCH 187/242] feat: add RecvRaw (#896) --- stream_reader.go | 39 ++++++++++++++++++++++----------------- stream_reader_test.go | 13 +++++++++++++ 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..ecfa26807 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -32,17 +32,28 @@ type streamReader[T streamable] struct { } func (stream *streamReader[T]) Recv() (response T, err error) { - if stream.isFinished { - err = io.EOF + rawLine, err := stream.RecvRaw() + if err != nil { return } - response, err = stream.processLines() - return + err = stream.unmarshaler.Unmarshal(rawLine, &response) + if err != nil { + return + } + return response, nil +} + +func (stream *streamReader[T]) RecvRaw() ([]byte, error) { + if stream.isFinished { + return nil, io.EOF + } + + return stream.processLines() } //nolint:gocognit -func (stream *streamReader[T]) processLines() (T, error) { +func (stream *streamReader[T]) processLines() ([]byte, error) { var ( emptyMessagesCount uint hasErrorPrefix bool @@ -53,9 +64,9 @@ func (stream *streamReader[T]) processLines() (T, error) { if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { - return *new(T), fmt.Errorf("error, %w", respErr.Error) + return nil, fmt.Errorf("error, %w", respErr.Error) } - return *new(T), readErr + return nil, readErr } noSpaceLine := bytes.TrimSpace(rawLine) @@ -68,11 +79,11 @@ func (stream *streamReader[T]) processLines() (T, error) { } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { - return *new(T), writeErr + return nil, writeErr } emptyMessagesCount++ if emptyMessagesCount > stream.emptyMessagesLimit { - return *new(T), ErrTooManyEmptyStreamMessages + return nil, ErrTooManyEmptyStreamMessages } continue @@ -81,16 +92,10 @@ func (stream *streamReader[T]) processLines() (T, error) { noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true - return *new(T), io.EOF - } - - var response T - unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response) - if unmarshalErr != nil { - return *new(T), unmarshalErr + return nil, io.EOF } - return response, nil + return noPrefixLine, nil } } diff --git a/stream_reader_test.go b/stream_reader_test.go index cd6e46eff..449a14b43 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { _, err := stream.Recv() checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) } + +func TestStreamReaderRecvRaw(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))), + } + rawLine, err := stream.RecvRaw() + if err != nil { + t.Fatalf("Did not return raw line: %v", err) + } + if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) { + t.Fatalf("Did not return raw line: %v", string(rawLine)) + } +} From af5355f5b1a7701f891109e8a17b7b245ac5363b Mon Sep 17 00:00:00 2001 From: Tim Misiak Date: Sun, 8 Dec 2024 05:12:05 -0800 Subject: [PATCH 188/242] Fix ID field to be optional (#911) The ID field is not always present for streaming responses. Without omitempty, the entire ToolCall struct will be missing. --- chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chat.go b/chat.go index 2b13f8dd7..fcaf79cf7 100644 --- a/chat.go +++ b/chat.go @@ -179,7 +179,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` - ID string `json:"id"` + ID string `json:"id,omitempty"` Type ToolType `json:"type"` Function FunctionCall `json:"function"` } From 56a9acf86fc3ce0e9030feafa346d64bade94027 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sun, 8 Dec 2024 13:16:48 +0000 Subject: [PATCH 189/242] Ignore test.mp3 (#913) --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 99b40bf17..b0ac1605c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ # Auth token for tests .openai-token -.idea \ No newline at end of file +.idea + +# Generated by tests +test.mp3 \ No newline at end of file From 2a0ff5ac63e460cbe44cccd0d4199d51bf8682a4 Mon Sep 17 00:00:00 2001 From: Sabuhi Gurbani <51547928+sabuhigr@users.noreply.github.com> Date: Fri, 27 Dec 2024 14:01:16 +0400 Subject: [PATCH 190/242] Added additional_messages (#914) --- run.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/run.go b/run.go index d3e755f05..9c51aaf8d 100644 --- a/run.go +++ b/run.go @@ -83,12 +83,13 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model string `json:"model,omitempty"` - Instructions string `json:"instructions,omitempty"` - AdditionalInstructions string `json:"additional_instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. From 7a2915a37dae714f40a4b5575fbf98430fe1d6aa Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Fri, 31 Jan 2025 20:55:41 +0200 Subject: [PATCH 191/242] Simplify tests with T.TempDir (#929) --- .golangci.yml | 1 + audio_api_test.go | 10 ++------- audio_test.go | 8 ++----- image_api_test.go | 42 +++++++++++------------------------ internal/form_builder_test.go | 17 ++++---------- internal/test/helpers.go | 10 --------- openai_test.go | 2 +- speech_test.go | 4 +--- 8 files changed, 24 insertions(+), 70 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 724cb7375..9d22d9bd3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -206,6 +206,7 @@ linters: - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters + - usetesting # Reports uses of functions with replacement inside the testing package - wastedassign # wastedassign finds wasted assignment statements. - whitespace # Tool for detection of leading and trailing whitespace ## you may want to enable diff --git a/audio_api_test.go b/audio_api_test.go index c24598443..6c6a35643 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -40,12 +40,9 @@ func TestAudio(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ @@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ diff --git a/audio_test.go b/audio_test.go index 235931f36..9f32d5468 100644 --- a/audio_test.go +++ b/audio_test.go @@ -13,9 +13,7 @@ import ( ) func TestAudioWithFailingFormBuilder(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ @@ -63,9 +61,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { func TestCreateFileField(t *testing.T) { t.Run("createFileField failing file", func(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ diff --git a/image_api_test.go b/image_api_test.go index 48416b1e2..f6057b77d 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "path/filepath" "testing" "time" @@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } + defer origin.Close() - mask, err := os.Create("mask.png") + mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png")) if err != nil { - t.Error("open mask file error") - return + t.Fatalf("open mask file error: %v", err) } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() + defer mask.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index d3faf9982..8df989e3b 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,7 +1,6 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -20,15 +19,11 @@ func (*failingWriter) Write([]byte) (int, error) { } func TestFormBuilderWithFailingWriter(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } defer file.Close() - defer os.Remove(file.Name()) builder := NewFormBuilder(&failingWriter{}) err = builder.CreateFormFile("file", file) @@ -36,15 +31,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { } func TestFormBuilderWithClosedFile(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } file.Close() - defer os.Remove(file.Name()) body := &bytes.Buffer{} builder := NewFormBuilder(body) diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 0e63ae82f..dc5fa6646 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) { file.Close() } -// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called. -func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { - t.Helper() - - path, err := os.MkdirTemp(os.TempDir(), "") - checks.NoError(t, err) - - return path, func() { os.RemoveAll(path) } -} - // TokenRoundTripper is a struct that implements the RoundTripper // interface, specifically to handle the authentication token by adding a token // to the request header. We need this because the API requires that each diff --git a/openai_test.go b/openai_test.go index 729d8880c..48a00b9fc 100644 --- a/openai_test.go +++ b/openai_test.go @@ -31,7 +31,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // This function approximates based on the rule of thumb stated by OpenAI: // https://beta.openai.com/tokenizer // -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { return int(float32(len(s)) / 4) } diff --git a/speech_test.go b/speech_test.go index f1e405c39..67a3feabc 100644 --- a/speech_test.go +++ b/speech_test.go @@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) { defer teardown() server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { - dir, cleanup := test.CreateTestDirectory(t) - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) - defer cleanup() // audio endpoints only accept POST requests if r.Method != "POST" { From 9823a8bbbdc00871c1d569ed2b90111af94a4fb2 Mon Sep 17 00:00:00 2001 From: Trevor Creech Date: Fri, 31 Jan 2025 10:57:57 -0800 Subject: [PATCH 192/242] Chat Completion API: add ReasoningEffort and new o1 models (#928) * add reasoning_effort param * add o1 model * fix lint --- chat.go | 2 ++ completion.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/chat.go b/chat.go index fcaf79cf7..7a44fd831 100644 --- a/chat.go +++ b/chat.go @@ -258,6 +258,8 @@ type ChatCompletionRequest struct { // Store can be set to true to store the output of this completion request for use in distillations and evals. // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store Store bool `json:"store,omitempty"` + // Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high". + ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/completion.go b/completion.go index f11566081..62724688a 100644 --- a/completion.go +++ b/completion.go @@ -29,6 +29,8 @@ const ( O1Mini20240912 = "o1-mini-2024-09-12" O1Preview = "o1-preview" O1Preview20240912 = "o1-preview-2024-09-12" + O1 = "o1" + O120241217 = "o1-2024-12-17" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" From 45aa99607be0b4c225af57c36fb5cff7328957de Mon Sep 17 00:00:00 2001 From: saileshd1402 Date: Sat, 1 Feb 2025 00:35:29 +0530 Subject: [PATCH 193/242] Make "Content" field in "ChatCompletionMessage" omitempty (#926) --- chat.go | 6 +++--- chat_test.go | 2 +- openai_test.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/chat.go b/chat.go index 7a44fd831..8ea7238fe 100644 --- a/chat.go +++ b/chat.go @@ -93,7 +93,7 @@ type ChatMessagePart struct { type ChatCompletionMessage struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart @@ -132,7 +132,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { msg := struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart `json:"-"` Name string `json:"name,omitempty"` @@ -146,7 +146,7 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart Name string `json:"name,omitempty"` diff --git a/chat_test.go b/chat_test.go index 134026cdb..cea549cbd 100644 --- a/chat_test.go +++ b/chat_test.go @@ -631,7 +631,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) { t.Fatalf("Unexpected error") } res = strings.ReplaceAll(string(s), " ", "") - if res != `{"role":"user","content":""}` { + if res != `{"role":"user"}` { t.Fatalf("invalid message: %s", string(s)) } } diff --git a/openai_test.go b/openai_test.go index 48a00b9fc..6c26eebd1 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer +// https://beta.openai.com/tokenizer/ // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { From 2054db016c335136eba471aebf49cc78981dd502 Mon Sep 17 00:00:00 2001 From: rory malcolm Date: Thu, 6 Feb 2025 14:53:19 +0000 Subject: [PATCH 194/242] Add support for O3-mini (#930) * Add support for O3-mini - Add support for the o3 mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini). * Deprecate and refactor - Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther` - Implement `validationRequestForReasoningModels`, which works on both o1 & o3, and has per-model-type restrictions on functionality (eg, o3 class are allowed function calls and system messages, o1 isn't) * Move reasoning validation to `reasoning_validator.go` - Add a `NewReasoningValidator` which exposes a `Validate()` method for a given request - Also adds a test for chat streams * Final nits --- chat.go | 3 +- chat_stream.go | 3 +- chat_stream_test.go | 167 +++++++++++++++++++++++++++++++++++++++++ chat_test.go | 153 +++++++++++++++++++++++++++++++++++-- completion.go | 86 +-------------------- reasoning_validator.go | 111 +++++++++++++++++++++++++++ 6 files changed, 431 insertions(+), 92 deletions(-) create mode 100644 reasoning_validator.go diff --git a/chat.go b/chat.go index 8ea7238fe..ce24fa34a 100644 --- a/chat.go +++ b/chat.go @@ -392,7 +392,8 @@ func (c *Client) CreateChatCompletion( return } - if err = validateRequestForO1Models(request); err != nil { + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 58b2651c0..525b4457a 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - if err = validateRequestForO1Models(request); err != nil { + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { return } diff --git a/chat_stream_test.go b/chat_stream_test.go index 28a9acf67..4d992e4d1 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { return true } +func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxCompletionTokens: 2000, + Model: openai.O3Mini20250131, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " from", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " O3Mini", + }, + }, + }, + }, + { + ID: "5", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git a/chat_test.go b/chat_test.go index cea549cbd..fc6c4a936 100644 --- a/chat_test.go +++ b/chat_test.go @@ -64,7 +64,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { MaxTokens: 5, Model: openai.O1Preview, }, - expectedError: openai.ErrO1MaxTokensDeprecated, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, }, { name: "o1-mini_MaxTokens_deprecated", @@ -72,7 +72,7 @@ func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { MaxTokens: 5, Model: openai.O1Mini, }, - expectedError: openai.ErrO1MaxTokensDeprecated, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, }, } @@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { LogProbs: true, Model: openai.O1Preview, }, - expectedError: openai.ErrO1BetaLimitationsLogprobs, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, }, { name: "message_type_unsupported", @@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, Temperature: float32(2), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_top_unsupported", @@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { Temperature: float32(1), TopP: float32(0.1), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_n_unsupported", @@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { TopP: float32(1), N: 2, }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_presence_penalty_unsupported", @@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, PresencePenalty: float32(1), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, }, { name: "set_frequency_penalty_unsupported", @@ -226,7 +226,127 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, FrequencyPenalty: float32(0.1), }, - expectedError: openai.ErrO1BetaLimitationsOther, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O3Mini, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, }, } @@ -308,6 +428,23 @@ func TestO1ModelChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestO3ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O3Mini, + MaxCompletionTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() diff --git a/completion.go b/completion.go index 62724688a..1985293f8 100644 --- a/completion.go +++ b/completion.go @@ -2,24 +2,9 @@ package openai import ( "context" - "errors" "net/http" ) -var ( - ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll - ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll - ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll - ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll -) - -var ( - ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll - ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll - ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll - ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll -) - // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific @@ -31,6 +16,8 @@ const ( O1Preview20240912 = "o1-preview-2024-09-12" O1 = "o1" O120241217 = "o1-2024-12-17" + O3Mini = "o3-mini" + O3Mini20250131 = "o3-mini-2025-01-31" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -96,21 +83,14 @@ const ( CodexCodeDavinci001 = "code-davinci-001" ) -// O1SeriesModels List of new Series of OpenAI models. -// Some old api attributes not supported. -var O1SeriesModels = map[string]struct{}{ - O1Mini: {}, - O1Mini20240912: {}, - O1Preview: {}, - O1Preview20240912: {}, -} - var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { O1Mini: true, O1Mini20240912: true, O1Preview: true, O1Preview20240912: true, + O3Mini: true, + O3Mini20250131: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, @@ -183,64 +163,6 @@ func checkPromptType(prompt any) bool { return true // all items in the slice are string, so it is []string } -var unsupportedToolsForO1Models = map[ToolType]struct{}{ - ToolTypeFunction: {}, -} - -var availableMessageRoleForO1Models = map[string]struct{}{ - ChatMessageRoleUser: {}, - ChatMessageRoleAssistant: {}, -} - -// validateRequestForO1Models checks for deprecated fields of OpenAI models. -func validateRequestForO1Models(request ChatCompletionRequest) error { - if _, found := O1SeriesModels[request.Model]; !found { - return nil - } - - if request.MaxTokens > 0 { - return ErrO1MaxTokensDeprecated - } - - // Logprobs: not supported. - if request.LogProbs { - return ErrO1BetaLimitationsLogprobs - } - - // Message types: user and assistant messages only, system messages are not supported. - for _, m := range request.Messages { - if _, found := availableMessageRoleForO1Models[m.Role]; !found { - return ErrO1BetaLimitationsMessageTypes - } - } - - // Tools: tools, function calling, and response format parameters are not supported - for _, t := range request.Tools { - if _, found := unsupportedToolsForO1Models[t.Type]; found { - return ErrO1BetaLimitationsTools - } - } - - // Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0. - if request.Temperature > 0 && request.Temperature != 1 { - return ErrO1BetaLimitationsOther - } - if request.TopP > 0 && request.TopP != 1 { - return ErrO1BetaLimitationsOther - } - if request.N > 0 && request.N != 1 { - return ErrO1BetaLimitationsOther - } - if request.PresencePenalty > 0 { - return ErrO1BetaLimitationsOther - } - if request.FrequencyPenalty > 0 { - return ErrO1BetaLimitationsOther - } - - return nil -} - // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { Model string `json:"model"` diff --git a/reasoning_validator.go b/reasoning_validator.go new file mode 100644 index 000000000..42a9fbd2e --- /dev/null +++ b/reasoning_validator.go @@ -0,0 +1,111 @@ +package openai + +import ( + "errors" + "strings" +) + +var ( + // Deprecated: use ErrReasoningModelMaxTokensDeprecated instead. + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll + ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll +) + +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + // Deprecated: use ErrReasoningModelLimitations* instead. + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +var ( + //nolint:lll + ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") + ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +var unsupportedToolsForO1Models = map[ToolType]struct{}{ + ToolTypeFunction: {}, +} + +var availableMessageRoleForO1Models = map[string]struct{}{ + ChatMessageRoleUser: {}, + ChatMessageRoleAssistant: {}, +} + +// ReasoningValidator handles validation for o-series model requests. +type ReasoningValidator struct{} + +// NewReasoningValidator creates a new validator for o-series models. +func NewReasoningValidator() *ReasoningValidator { + return &ReasoningValidator{} +} + +// Validate performs all validation checks for o-series models. +func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { + o1Series := strings.HasPrefix(request.Model, "o1") + o3Series := strings.HasPrefix(request.Model, "o3") + + if !o1Series && !o3Series { + return nil + } + + if err := v.validateReasoningModelParams(request); err != nil { + return err + } + + if o1Series { + if err := v.validateO1Specific(request); err != nil { + return err + } + } + + return nil +} + +// validateReasoningModelParams checks reasoning model parameters. +func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error { + if request.MaxTokens > 0 { + return ErrReasoningModelMaxTokensDeprecated + } + if request.LogProbs { + return ErrReasoningModelLimitationsLogprobs + } + if request.Temperature > 0 && request.Temperature != 1 { + return ErrReasoningModelLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrReasoningModelLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrReasoningModelLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrReasoningModelLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrReasoningModelLimitationsOther + } + + return nil +} + +// validateO1Specific checks O1-specific limitations. +func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error { + for _, m := range request.Messages { + if _, found := availableMessageRoleForO1Models[m.Role]; !found { + return ErrO1BetaLimitationsMessageTypes + } + } + + for _, t := range request.Tools { + if _, found := unsupportedToolsForO1Models[t.Type]; found { + return ErrO1BetaLimitationsTools + } + } + return nil +} From a62919e8c66e35db125c129e8a9d2566a73e1e1f Mon Sep 17 00:00:00 2001 From: Mazyar Yousefiniyae shad Date: Sun, 9 Feb 2025 22:06:44 +0330 Subject: [PATCH 195/242] ref: add image url support to messages (#933) * ref: add image url support to messages * fix linter error * fix linter error --- messages.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/messages.go b/messages.go index 902363938..3852d2e37 100644 --- a/messages.go +++ b/messages.go @@ -41,6 +41,7 @@ type MessageContent struct { Type string `json:"type"` Text *MessageText `json:"text,omitempty"` ImageFile *ImageFile `json:"image_file,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` } type MessageText struct { Value string `json:"value"` @@ -51,6 +52,11 @@ type ImageFile struct { FileID string `json:"file_id"` } +type ImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + type MessageRequest struct { Role string `json:"role"` Content string `json:"content"` From c0a9a75fe01dbefb16f87d69bab042516009184f Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Wed, 12 Feb 2025 23:05:44 +0800 Subject: [PATCH 196/242] feat: add developer role (#936) --- chat.go | 1 + reasoning_validator.go | 1 + 2 files changed, 2 insertions(+) diff --git a/chat.go b/chat.go index ce24fa34a..995860c40 100644 --- a/chat.go +++ b/chat.go @@ -14,6 +14,7 @@ const ( ChatMessageRoleAssistant = "assistant" ChatMessageRoleFunction = "function" ChatMessageRoleTool = "tool" + ChatMessageRoleDeveloper = "developer" ) const chatCompletionsSuffix = "/chat/completions" diff --git a/reasoning_validator.go b/reasoning_validator.go index 42a9fbd2e..4d4671b17 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -35,6 +35,7 @@ var unsupportedToolsForO1Models = map[ToolType]struct{}{ var availableMessageRoleForO1Models = map[string]struct{}{ ChatMessageRoleUser: {}, ChatMessageRoleAssistant: {}, + ChatMessageRoleDeveloper: {}, } // ReasoningValidator handles validation for o-series model requests. From 85f578b865a6ea12ab24307f3bc68c97f85b6580 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Mon, 17 Feb 2025 19:29:18 +0800 Subject: [PATCH 197/242] fix: remove validateO1Specific (#939) * fix: remove validateO1Specific * update golangci-lint-action version * fix actions * fix actions * fix actions * fix actions * remove some o1 test --- .github/workflows/pr.yml | 4 ++-- chat_test.go | 34 ---------------------------------- reasoning_validator.go | 32 -------------------------------- 3 files changed, 2 insertions(+), 68 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index a41fff92f..ea0c327f1 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -18,9 +18,9 @@ jobs: run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v4 + uses: golangci/golangci-lint-action@v6 with: - version: latest + version: v1.63.4 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . - name: Upload coverage reports to Codecov diff --git a/chat_test.go b/chat_test.go index fc6c4a936..e90142da6 100644 --- a/chat_test.go +++ b/chat_test.go @@ -106,40 +106,6 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { }, expectedError: openai.ErrReasoningModelLimitationsLogprobs, }, - { - name: "message_type_unsupported", - in: openai.ChatCompletionRequest{ - MaxCompletionTokens: 1000, - Model: openai.O1Mini, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleSystem, - }, - }, - }, - expectedError: openai.ErrO1BetaLimitationsMessageTypes, - }, - { - name: "tool_unsupported", - in: openai.ChatCompletionRequest{ - MaxCompletionTokens: 1000, - Model: openai.O1Mini, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - }, - { - Role: openai.ChatMessageRoleAssistant, - }, - }, - Tools: []openai.Tool{ - { - Type: openai.ToolTypeFunction, - }, - }, - }, - expectedError: openai.ErrO1BetaLimitationsTools, - }, { name: "set_temperature_unsupported", in: openai.ChatCompletionRequest{ diff --git a/reasoning_validator.go b/reasoning_validator.go index 4d4671b17..040d6b495 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -28,16 +28,6 @@ var ( ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll ) -var unsupportedToolsForO1Models = map[ToolType]struct{}{ - ToolTypeFunction: {}, -} - -var availableMessageRoleForO1Models = map[string]struct{}{ - ChatMessageRoleUser: {}, - ChatMessageRoleAssistant: {}, - ChatMessageRoleDeveloper: {}, -} - // ReasoningValidator handles validation for o-series model requests. type ReasoningValidator struct{} @@ -59,12 +49,6 @@ func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { return err } - if o1Series { - if err := v.validateO1Specific(request); err != nil { - return err - } - } - return nil } @@ -94,19 +78,3 @@ func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletion return nil } - -// validateO1Specific checks O1-specific limitations. -func (v *ReasoningValidator) validateO1Specific(request ChatCompletionRequest) error { - for _, m := range request.Messages { - if _, found := availableMessageRoleForO1Models[m.Role]; !found { - return ErrO1BetaLimitationsMessageTypes - } - } - - for _, t := range request.Tools { - if _, found := unsupportedToolsForO1Models[t.Type]; found { - return ErrO1BetaLimitationsTools - } - } - return nil -} From be2e2387d4dcb15593ae5d0094e6f7b023ab3f53 Mon Sep 17 00:00:00 2001 From: Dan Ackerson Date: Tue, 25 Feb 2025 12:03:38 +0100 Subject: [PATCH 198/242] feat: add Anthropic API support with custom version header (#934) * feat: add Anthropic API support with custom version header * refactor: use switch statement for API type header handling * refactor: add OpenAI & AzureAD types to be exhaustive * Update client.go need explicit fallthrough in empty case statements * constant for APIVersion; addtl tests --- client.go | 18 +++++++++++++----- client_test.go | 15 +++++++++++++++ config.go | 22 +++++++++++++++++++++- config_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index ed8595e0b..cef375348 100644 --- a/client.go +++ b/client.go @@ -182,13 +182,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure || c.config.APIType == APITypeCloudflareAzure { + switch c.config.APIType { + case APITypeAzure, APITypeCloudflareAzure: + // Azure API Key authentication req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else if c.config.authToken != "" { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + case APITypeAnthropic: + // https://docs.anthropic.com/en/api/versioning + req.Header.Set("anthropic-version", c.config.APIVersion) + case APITypeOpenAI, APITypeAzureAD: + fallthrough + default: + if c.config.authToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } } + if c.config.OrgID != "" { req.Header.Set("OpenAI-Organization", c.config.OrgID) } diff --git a/client_test.go b/client_test.go index 2ed82f13c..321971445 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,21 @@ func TestClient(t *testing.T) { } } +func TestSetCommonHeadersAnthropic(t *testing.T) { + config := DefaultAnthropicConfig("mock-token", "") + client := NewClientWithConfig(config) + req, err := http.NewRequest("GET", "http://example.com/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client.setCommonHeaders(req) + + if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion { + t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got) + } +} + func TestDecodeResponse(t *testing.T) { stringInput := "" diff --git a/config.go b/config.go index 8a9183558..4788ba62a 100644 --- a/config.go +++ b/config.go @@ -11,6 +11,8 @@ const ( azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" + + AnthropicAPIVersion = "2023-06-01" ) type APIType string @@ -20,6 +22,7 @@ const ( APITypeAzure APIType = "AZURE" APITypeAzureAD APIType = "AZURE_AD" APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" + APITypeAnthropic APIType = "ANTHROPIC" ) const AzureAPIKeyHeader = "api-key" @@ -37,7 +40,7 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func HTTPClient HTTPDoer @@ -76,6 +79,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { } } +func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig { + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAnthropic, + APIVersion: AnthropicAPIVersion, + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + func (ClientConfig) String() string { return "" } diff --git a/config_test.go b/config_test.go index 3e528c3e9..145c26066 100644 --- a/config_test.go +++ b/config_test.go @@ -60,3 +60,43 @@ func TestGetAzureDeploymentByModel(t *testing.T) { }) } } + +func TestDefaultAnthropicConfig(t *testing.T) { + apiKey := "test-key" + baseURL := "https://api.anthropic.com/v1" + + config := openai.DefaultAnthropicConfig(apiKey, baseURL) + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion) + } + + if config.BaseURL != baseURL { + t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL) + } + + if config.EmptyMessagesLimit != 300 { + t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit) + } +} + +func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { + config := openai.DefaultAnthropicConfig("", "") + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion) + } + + expectedBaseURL := "https://api.anthropic.com/v1" + if config.BaseURL != expectedBaseURL { + t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) + } +} From 261721bfdbeb2edc495f24189b75f2c151f186a7 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 25 Feb 2025 16:56:35 +0000 Subject: [PATCH 199/242] Fix linter (#943) * fix lint * remove linters --- .github/workflows/pr.yml | 4 ++-- .golangci.yml | 14 -------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index ea0c327f1..818a8842b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -13,14 +13,14 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: '1.21' + go-version: '1.24' - name: Run vet run: | go vet . - name: Run golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.63.4 + version: v1.64.5 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v . - name: Upload coverage reports to Codecov diff --git a/.golangci.yml b/.golangci.yml index 9d22d9bd3..9f2ba52e0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -139,11 +139,6 @@ linters-settings: # Default: false all: true - varcheck: - # Check usage of exported fields and variables. - # Default: false - exported-fields: false # default false # TODO: enable after fixing false positives - linters: disable-all: true @@ -167,9 +162,7 @@ linters: - durationcheck # check for two durations multiplied together - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - # Removed execinquery (deprecated). execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - exhaustive # check exhaustiveness of enum switch statements - - exportloopref # checks for pointers to enclosing loop variables - forbidigo # Forbids identifiers - funlen # Tool for detection of long functions # - gochecknoglobals # check that no global variables exist @@ -201,7 +194,6 @@ linters: - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - stylecheck # Stylecheck is a replacement for golint - - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions @@ -239,12 +231,6 @@ linters: #- tagliatelle # Checks the struct tags. #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - ## deprecated - #- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized - #- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes - #- interfacer # [deprecated] Linter that suggests narrower interface types - #- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted - #- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs issues: From 74d6449f22dd8bf668ebaeb181263b675b9a668b Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 4 Mar 2025 16:26:59 +0800 Subject: [PATCH 200/242] feat: add gpt-4.5-preview models (#947) --- completion.go | 138 ++++++++++++++++++++++++++------------------------ 1 file changed, 71 insertions(+), 67 deletions(-) diff --git a/completion.go b/completion.go index 1985293f8..015fa2a9f 100644 --- a/completion.go +++ b/completion.go @@ -10,41 +10,43 @@ import ( // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( - O1Mini = "o1-mini" - O1Mini20240912 = "o1-mini-2024-09-12" - O1Preview = "o1-preview" - O1Preview20240912 = "o1-preview-2024-09-12" - O1 = "o1" - O120241217 = "o1-2024-12-17" - O3Mini = "o3-mini" - O3Mini20250131 = "o3-mini-2025-01-31" - GPT432K0613 = "gpt-4-32k-0613" - GPT432K0314 = "gpt-4-32k-0314" - GPT432K = "gpt-4-32k" - GPT40613 = "gpt-4-0613" - GPT40314 = "gpt-4-0314" - GPT4o = "gpt-4o" - GPT4o20240513 = "gpt-4o-2024-05-13" - GPT4o20240806 = "gpt-4o-2024-08-06" - GPT4o20241120 = "gpt-4o-2024-11-20" - GPT4oLatest = "chatgpt-4o-latest" - GPT4oMini = "gpt-4o-mini" - GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" - GPT4Turbo = "gpt-4-turbo" - GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" - GPT4Turbo0125 = "gpt-4-0125-preview" - GPT4Turbo1106 = "gpt-4-1106-preview" - GPT4TurboPreview = "gpt-4-turbo-preview" - GPT4VisionPreview = "gpt-4-vision-preview" - GPT4 = "gpt-4" - GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" - GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" - GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" - GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" - GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" - GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" + O1 = "o1" + O120241217 = "o1-2024-12-17" + O3Mini = "o3-mini" + O3Mini20250131 = "o3-mini-2025-01-31" + GPT432K0613 = "gpt-4-32k-0613" + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" + GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4o20241120 = "gpt-4o-2024-11-20" + GPT4oLatest = "chatgpt-4o-latest" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" + GPT4VisionPreview = "gpt-4-vision-preview" + GPT4 = "gpt-4" + GPT4Dot5Preview = "gpt-4.5-preview" + GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. @@ -85,38 +87,40 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { - O1Mini: true, - O1Mini20240912: true, - O1Preview: true, - O1Preview20240912: true, - O3Mini: true, - O3Mini20250131: true, - GPT3Dot5Turbo: true, - GPT3Dot5Turbo0301: true, - GPT3Dot5Turbo0613: true, - GPT3Dot5Turbo1106: true, - GPT3Dot5Turbo0125: true, - GPT3Dot5Turbo16K: true, - GPT3Dot5Turbo16K0613: true, - GPT4: true, - GPT4o: true, - GPT4o20240513: true, - GPT4o20240806: true, - GPT4o20241120: true, - GPT4oLatest: true, - GPT4oMini: true, - GPT4oMini20240718: true, - GPT4TurboPreview: true, - GPT4VisionPreview: true, - GPT4Turbo1106: true, - GPT4Turbo0125: true, - GPT4Turbo: true, - GPT4Turbo20240409: true, - GPT40314: true, - GPT40613: true, - GPT432K: true, - GPT432K0314: true, - GPT432K0613: true, + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, + O3Mini: true, + O3Mini20250131: true, + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, + GPT3Dot5Turbo16K: true, + GPT3Dot5Turbo16K0613: true, + GPT4: true, + GPT4Dot5Preview: true, + GPT4Dot5Preview20250227: true, + GPT4o: true, + GPT4o20240513: true, + GPT4o20240806: true, + GPT4o20241120: true, + GPT4oLatest: true, + GPT4oMini: true, + GPT4oMini20240718: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, + GPT40314: true, + GPT40613: true, + GPT432K: true, + GPT432K0314: true, + GPT432K0613: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, From e99eb54c9d81cc102683921f4952a6d0c1964cbf Mon Sep 17 00:00:00 2001 From: "JT A." Date: Sun, 13 Apr 2025 12:00:48 -0600 Subject: [PATCH 201/242] add enum tag to jsonschema (#962) * fix jsonschema tests * ensure all run during PR Github Action * add test for struct to schema * add support for enum tag * support nullable tag --- .github/workflows/pr.yml | 2 +- jsonschema/json.go | 12 ++ jsonschema/json_test.go | 310 ++++++++++++++++++++++++++++++--------- 3 files changed, 252 insertions(+), 72 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 818a8842b..f4cbe7c8b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -22,6 +22,6 @@ jobs: with: version: v1.64.5 - name: Run tests - run: go test -race -covermode=atomic -coverprofile=coverage.out -v . + run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4 diff --git a/jsonschema/json.go b/jsonschema/json.go index bcb253fae..d458418f3 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -46,6 +46,8 @@ type Definition struct { // additionalProperties: false // additionalProperties: jsonschema.Definition{Type: jsonschema.String} AdditionalProperties any `json:"additionalProperties,omitempty"` + // Whether the schema is nullable or not. + Nullable bool `json:"nullable,omitempty"` } func (d *Definition) MarshalJSON() ([]byte, error) { @@ -139,6 +141,16 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { if description != "" { item.Description = description } + enum := field.Tag.Get("enum") + if enum != "" { + item.Enum = strings.Split(enum, ",") + } + + if n := field.Tag.Get("nullable"); n != "" { + nullable, _ := strconv.ParseBool(n) + item.Nullable = nullable + } + properties[jsonTag] = *item if s := field.Tag.Get("required"); s != "" { diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 744706082..17f0aba8a 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) { { name: "Test with empty Definition", def: jsonschema.Definition{}, - want: `{"properties":{}}`, + want: `{}`, }, { name: "Test with Definition properties set", @@ -31,15 +31,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"string", - "description":"A string type", - "properties":{ - "name":{ - "type":"string", - "properties":{} - } - } -}`, + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string" + } + } + }`, }, { name: "Test with nested Definition properties", @@ -60,23 +59,21 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"object", - "properties":{ - "user":{ - "type":"object", - "properties":{ - "name":{ - "type":"string", - "properties":{} - }, - "age":{ - "type":"integer", - "properties":{} - } - } - } - } -}`, + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + } + } + } + } + }`, }, { name: "Test with complex nested Definition", @@ -108,36 +105,32 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"object", - "properties":{ - "user":{ - "type":"object", - "properties":{ - "name":{ - "type":"string", - "properties":{} - }, - "age":{ - "type":"integer", - "properties":{} - }, - "address":{ - "type":"object", - "properties":{ - "city":{ - "type":"string", - "properties":{} - }, - "country":{ - "type":"string", - "properties":{} - } - } - } - } - } - } -}`, + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string" + }, + "country":{ + "type":"string" + } + } + } + } + } + } + }`, }, { name: "Test with Array type Definition", @@ -153,20 +146,16 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"array", - "items":{ - "type":"string", - "properties":{ - - } - }, - "properties":{ - "name":{ - "type":"string", - "properties":{} - } - } -}`, + "type":"array", + "items":{ + "type":"string" + }, + "properties":{ + "name":{ + "type":"string" + } + } + }`, }, } @@ -193,6 +182,185 @@ func TestDefinition_MarshalJSON(t *testing.T) { } } +func TestStructToSchema(t *testing.T) { + tests := []struct { + name string + in any + want string + }{ + { + name: "Test with empty struct", + in: struct{}{}, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with struct containing many fields", + in: struct { + Name string `json:"name"` + Age int `json:"age"` + Active bool `json:"active"` + Height float64 `json:"height"` + Cities []struct { + Name string `json:"name"` + State string `json:"state"` + } `json:"cities"` + }{ + Name: "John Doe", + Age: 30, + Cities: []struct { + Name string `json:"name"` + State string `json:"state"` + }{ + {Name: "New York", State: "NY"}, + {Name: "Los Angeles", State: "CA"}, + }, + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "active":{ + "type":"boolean" + }, + "height":{ + "type":"number" + }, + "cities":{ + "type":"array", + "items":{ + "additionalProperties":false, + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "state":{ + "type":"string" + } + }, + "required":["name","state"] + } + } + }, + "required":["name","age","active","height","cities"], + "additionalProperties":false + }`, + }, + { + name: "Test with description tag", + in: struct { + Name string `json:"name" description:"The name of the person"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "description":"The name of the person" + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + { + name: "Test with required tag", + in: struct { + Name string `json:"name" required:"false"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, + { + name: "Test with enum tag", + in: struct { + Color string `json:"color" enum:"red,green,blue"` + }{ + Color: "red", + }, + want: `{ + "type":"object", + "properties":{ + "color":{ + "type":"string", + "enum":["red","green","blue"] + } + }, + "required":["color"], + "additionalProperties":false + }`, + }, + { + name: "Test with nullable tag", + in: struct { + Name *string `json:"name" nullable:"true"` + }{ + Name: nil, + }, + want: `{ + + "type":"object", + "properties":{ + "name":{ + "type":"string", + "nullable":true + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantBytes := []byte(tt.want) + + schema, err := jsonschema.GenerateSchemaForType(tt.in) + if err != nil { + t.Errorf("Failed to generate schema: error = %v", err) + return + } + + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + got := structToMap(t, schema) + gotPtr := structToMap(t, &schema) + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } + }) + } +} + func structToMap(t *testing.T, v any) map[string]any { t.Helper() gotBytes, err := json.Marshal(v) From d68a6838156049ada8c25d3f4b8fa3befb3b4d1b Mon Sep 17 00:00:00 2001 From: Takahiro Ikeuchi Date: Thu, 24 Apr 2025 06:50:47 +0900 Subject: [PATCH 202/242] feat: add new GPT-4.1 model variants to completion.go (#966) * feat: add new GPT-4.1 model variants to completion.go * feat: add tests for unsupported models in completion endpoint * fix: add missing periods to test function comments in completion_test.go --- completion.go | 13 ++++++++ completion_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/completion.go b/completion.go index 015fa2a9f..0d0c1a8f4 100644 --- a/completion.go +++ b/completion.go @@ -37,6 +37,12 @@ const ( GPT4TurboPreview = "gpt-4-turbo-preview" GPT4VisionPreview = "gpt-4-vision-preview" GPT4 = "gpt-4" + GPT4Dot1 = "gpt-4.1" + GPT4Dot120250414 = "gpt-4.1-2025-04-14" + GPT4Dot1Mini = "gpt-4.1-mini" + GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14" + GPT4Dot1Nano = "gpt-4.1-nano" + GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" GPT4Dot5Preview = "gpt-4.5-preview" GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" @@ -121,6 +127,13 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT432K: true, GPT432K0314: true, GPT432K0613: true, + O1: true, + GPT4Dot1: true, + GPT4Dot120250414: true, + GPT4Dot1Mini: true, + GPT4Dot1Mini20250414: true, + GPT4Dot1Nano: true, + GPT4Dot1Nano20250414: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, diff --git a/completion_test.go b/completion_test.go index 935bbe864..83bd899a1 100644 --- a/completion_test.go +++ b/completion_test.go @@ -181,3 +181,86 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { } return completion, nil } + +// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint. +func TestCompletionWithO1Model(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O1, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err) + } +} + +// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint. +func TestCompletionWithGPT4DotModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4Dot1, + openai.GPT4Dot120250414, + openai.GPT4Dot1Mini, + openai.GPT4Dot1Mini20250414, + openai.GPT4Dot1Nano, + openai.GPT4Dot1Nano20250414, + openai.GPT4Dot5Preview, + openai.GPT4Dot5Preview20250227, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint. +func TestCompletionWithGPT4oModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4o, + openai.GPT4o20240513, + openai.GPT4o20240806, + openai.GPT4o20241120, + openai.GPT4oLatest, + openai.GPT4oMini, + openai.GPT4oMini20240718, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} From 658beda2ba8be4d155bc62208224a5766e0640c0 Mon Sep 17 00:00:00 2001 From: netr Date: Sat, 26 Apr 2025 03:13:43 -0700 Subject: [PATCH 203/242] feat: Add missing TTS models and voices (#958) * feat: Add missing TTS models and voices * feat: Add new instruction field to create speech request - From docs: Control the voice of your generated audio with additional instructions. Does not work with tts-1 or tts-1-hd. * fix: add canary-tts back to SpeechModel --- speech.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/speech.go b/speech.go index 20b52e334..60e7694fd 100644 --- a/speech.go +++ b/speech.go @@ -8,20 +8,25 @@ import ( type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTSModel1HD SpeechModel = "tts-1-hd" - TTSModelCanary SpeechModel = "canary-tts" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" + TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts" ) type SpeechVoice string const ( VoiceAlloy SpeechVoice = "alloy" + VoiceAsh SpeechVoice = "ash" + VoiceBallad SpeechVoice = "ballad" + VoiceCoral SpeechVoice = "coral" VoiceEcho SpeechVoice = "echo" VoiceFable SpeechVoice = "fable" VoiceOnyx SpeechVoice = "onyx" VoiceNova SpeechVoice = "nova" VoiceShimmer SpeechVoice = "shimmer" + VoiceVerse SpeechVoice = "verse" ) type SpeechResponseFormat string @@ -39,6 +44,7 @@ type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` Voice SpeechVoice `json:"voice"` + Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd. ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } From 306fbbbe6f09ff7bd718e11cd322e88b442f4496 Mon Sep 17 00:00:00 2001 From: goodenough Date: Tue, 29 Apr 2025 21:24:45 +0800 Subject: [PATCH 204/242] Add support for reasoning_content field in chat completion messages for DeepSeek R1 (#925) * support deepseek field "reasoning_content" * support deepseek field "reasoning_content" * Comment ends in a period (godot) * add comment on field reasoning_content * fix go lint error * chore: trigger CI * make field "content" in MarshalJSON function omitempty * remove reasoning_content in TestO1ModelChatCompletions func * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. * feat: Add test and handler for deepseek-reasoner chat model completions, including support for reasoning content in responses. --- chat.go | 74 ++++++++++++++++++++++++++-------------------- chat_stream.go | 6 ++++ chat_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++ openai_test.go | 2 +- 4 files changed, 128 insertions(+), 33 deletions(-) diff --git a/chat.go b/chat.go index 995860c40..7112dc7b7 100644 --- a/chat.go +++ b/chat.go @@ -104,6 +104,12 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. @@ -119,41 +125,44 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Role string `json:"role"` - Content string `json:"-"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { - Role string `json:"role"` - Content string `json:"content,omitempty"` - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &msg); err == nil { @@ -161,14 +170,15 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { return nil } multiMsg := struct { - Role string `json:"role"` - Content string - Refusal string `json:"refusal,omitempty"` - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err diff --git a/chat_stream.go b/chat_stream.go index 525b4457a..80d16cc63 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -11,6 +11,12 @@ type ChatCompletionStreamChoiceDelta struct { FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` Refusal string `json:"refusal,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` } type ChatCompletionStreamChoiceLogprobs struct { diff --git a/chat_test.go b/chat_test.go index e90142da6..514706c96 100644 --- a/chat_test.go +++ b/chat_test.go @@ -411,6 +411,23 @@ func TestO3ModelChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +func TestDeepseekR1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "deepseek-reasoner", + MaxCompletionTokens: 100, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -822,6 +839,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq openai.ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + if completionReq.MaxCompletionTokens == 0 { + completionReq.MaxCompletionTokens = 1000 + } + for i := 0; i < n; i++ { + reasoningContent := "User says hello! And I need to reply" + completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent)) + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ReasoningContent: reasoningContent, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + fmt.Fprintln(w, string(resBytes)) +} + // getChatCompletionBody Returns the body of the request to create a completion. func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { completion := openai.ChatCompletionRequest{} diff --git a/openai_test.go b/openai_test.go index 6c26eebd1..a55f3a858 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,7 +29,7 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer/ +// https://beta.openai.com/tokenizer. // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { From 4cccc6c93455d2ff2cf52660c916cfa5907ddbd3 Mon Sep 17 00:00:00 2001 From: Zhongxian Pan Date: Tue, 29 Apr 2025 21:29:15 +0800 Subject: [PATCH 205/242] Adapt different stream data prefix, with or without space (#945) --- stream_reader.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/stream_reader.go b/stream_reader.go index ecfa26807..6faefe0a7 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,13 +6,14 @@ import ( "fmt" "io" "net/http" + "regexp" utils "github.com/sashabaranov/go-openai/internal" ) var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) + headerData = regexp.MustCompile(`^data:\s*`) + errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) type streamable interface { @@ -70,12 +71,12 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { } noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { + if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if !headerData.Match(noSpaceLine) || hasErrorPrefix { if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { @@ -89,7 +90,7 @@ func (stream *streamReader[T]) processLines() ([]byte, error) { continue } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true return nil, io.EOF From bb5bc275678767d410d25d307959e1c45bc89c90 Mon Sep 17 00:00:00 2001 From: rory malcolm Date: Tue, 29 Apr 2025 14:34:33 +0100 Subject: [PATCH 206/242] Add support for `4o-mini` and `3o` (#968) - This adds supports, and tests, for the 3o and 4o-mini class of models --- chat_stream_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++ completion.go | 8 +++++++ completion_test.go | 36 ++++++++++++++++++++++++++++++ models_test.go | 18 +++++++++++++++ reasoning_validator.go | 3 ++- 5 files changed, 114 insertions(+), 1 deletion(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 4d992e4d1..eabb0f3a2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -959,6 +959,56 @@ func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { } } +func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O4Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git a/completion.go b/completion.go index 0d0c1a8f4..7a6de3033 100644 --- a/completion.go +++ b/completion.go @@ -16,8 +16,12 @@ const ( O1Preview20240912 = "o1-preview-2024-09-12" O1 = "o1" O120241217 = "o1-2024-12-17" + O3 = "o3" + O320250416 = "o3-2025-04-16" O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" + O4Mini = "o4-mini" + O4Mini2020416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -99,6 +103,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O1Preview20240912: true, O3Mini: true, O3Mini20250131: true, + O4Mini: true, + O4Mini2020416: true, + O3: true, + O320250416: true, GPT3Dot5Turbo: true, GPT3Dot5Turbo0301: true, GPT3Dot5Turbo0613: true, diff --git a/completion_test.go b/completion_test.go index 83bd899a1..27e2d150e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) { } } +// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported. +func TestCompletionsWrongModelO3(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O3, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err) + } +} + +// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported. +func TestCompletionsWrongModelO4Mini(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O4Mini, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err) + } +} + func TestCompletionWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") client := openai.NewClientWithConfig(config) diff --git a/models_test.go b/models_test.go index 24a28ed23..7fd010c34 100644 --- a/models_test.go +++ b/models_test.go @@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) { checks.NoError(t, err, "GetModel error") } +// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server. +func TestGetModelO3(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o3") + checks.NoError(t, err, "GetModel error for O3") +} + +// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server. +func TestGetModelO4Mini(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o4-mini") + checks.NoError(t, err, "GetModel error for O4Mini") +} + func TestAzureGetModel(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() diff --git a/reasoning_validator.go b/reasoning_validator.go index 040d6b495..2910b1395 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -40,8 +40,9 @@ func NewReasoningValidator() *ReasoningValidator { func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") + o4Series := strings.HasPrefix(request.Model, "o4") - if !o1Series && !o3Series { + if !o1Series && !o3Series && !o4Series { return nil } From da5f9bc9bc40537a0c2b451fffa9364efa94dbe1 Mon Sep 17 00:00:00 2001 From: Sean McGinnis Date: Tue, 29 Apr 2025 08:35:26 -0500 Subject: [PATCH 207/242] Add CompletionRequest.StreamOptions (#959) The legacy completion API supports a `stream_options` object when `stream` is set to true [0]. This adds a StreamOptions property to the CompletionRequest struct to support this setting. [0] https://platform.openai.com/docs/api-reference/completions/create#completions-create-stream_options Signed-off-by: Sean McGinnis --- completion.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/completion.go b/completion.go index 7a6de3033..9c3a64dd5 100644 --- a/completion.go +++ b/completion.go @@ -215,6 +215,8 @@ type CompletionRequest struct { Temperature float32 `json:"temperature,omitempty"` TopP float32 `json:"top_p,omitempty"` User string `json:"user,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` } // CompletionChoice represents one of possible completions. From 6836cf6a6fd0027ea21f8d31bff5d023040d9db4 Mon Sep 17 00:00:00 2001 From: Oleksandr Redko Date: Tue, 29 Apr 2025 16:36:38 +0300 Subject: [PATCH 208/242] Remove redundant typecheck linter (#955) --- .golangci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index 9f2ba52e0..a5988825b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -149,7 +149,6 @@ linters: - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - ineffassign # Detects when assignments to existing variables are not used - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unused # Checks Go code for unused constants, variables, functions and types ## disabled by default # - asasalint # Check for pass []any as any in variadic func(...any) From 93a611cf4f1d227963a4f28a2c4a3422a0d37bfd Mon Sep 17 00:00:00 2001 From: Daniel Peng Date: Tue, 29 Apr 2025 06:38:27 -0700 Subject: [PATCH 209/242] Add Prediction field (#970) * Add Prediction field to ChatCompletionRequest * Include prediction tokens in response --- chat.go | 7 +++++++ common.go | 6 ++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index 7112dc7b7..0f91d481c 100644 --- a/chat.go +++ b/chat.go @@ -273,6 +273,8 @@ type ChatCompletionRequest struct { ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` + // Configuration for a predicted output. + Prediction *Prediction `json:"prediction,omitempty"` } type StreamOptions struct { @@ -340,6 +342,11 @@ type LogProbs struct { Content []LogProb `json:"content"` } +type Prediction struct { + Content string `json:"content"` + Type string `json:"type"` +} + type FinishReason string const ( diff --git a/common.go b/common.go index 8cc7289c0..d1936d656 100644 --- a/common.go +++ b/common.go @@ -13,8 +13,10 @@ type Usage struct { // CompletionTokensDetails Breakdown of tokens used in a completion. type CompletionTokensDetails struct { - AudioTokens int `json:"audio_tokens"` - ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` } // PromptTokensDetails Breakdown of tokens used in the prompt. From d65f0cb54e8c9c91f4340fad14243beeb38f5a08 Mon Sep 17 00:00:00 2001 From: Ben Katz Date: Sun, 4 May 2025 03:44:48 +0700 Subject: [PATCH 210/242] Fix: Corrected typo in O4Mini20250416 model name and endpoint map. (#981) --- completion.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 9c3a64dd5..21d4897c4 100644 --- a/completion.go +++ b/completion.go @@ -21,7 +21,7 @@ const ( O3Mini = "o3-mini" O3Mini20250131 = "o3-mini-2025-01-31" O4Mini = "o4-mini" - O4Mini2020416 = "o4-mini-2025-04-16" + O4Mini20250416 = "o4-mini-2025-04-16" GPT432K0613 = "gpt-4-32k-0613" GPT432K0314 = "gpt-4-32k-0314" GPT432K = "gpt-4-32k" @@ -104,7 +104,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ O3Mini: true, O3Mini20250131: true, O4Mini: true, - O4Mini2020416: true, + O4Mini20250416: true, O3: true, O320250416: true, GPT3Dot5Turbo: true, From 5ea214a188a7751bddb9f6c899632509d388f643 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 3 May 2025 22:25:14 +0100 Subject: [PATCH 211/242] Improve unit test coverage (#984) * add tests for config * add audio tests * lint * lint * lint --- audio_test.go | 132 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 21 ++++++++ 2 files changed, 153 insertions(+) diff --git a/audio_test.go b/audio_test.go index 9f32d5468..51b3f465d 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,12 +2,16 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" + "errors" "fmt" "io" + "net/http" "os" "path/filepath" "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -107,3 +111,131 @@ func TestCreateFileField(t *testing.T) { checks.HasError(t, err, "createFileField using file should return error when open file fails") }) } + +// failingFormBuilder always returns an error when creating form files. +type failingFormBuilder struct{ err error } + +func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error { + return f.err +} + +func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error { + return f.err +} + +func (f *failingFormBuilder) WriteField(_, _ string) error { + return nil +} + +func (f *failingFormBuilder) Close() error { + return nil +} + +func (f *failingFormBuilder) FormDataContentType() string { + return "multipart/form-data" +} + +// failingAudioRequestBuilder simulates an error during HTTP request construction. +type failingAudioRequestBuilder struct{ err error } + +func (f *failingAudioRequestBuilder) Build( + _ context.Context, + _, _ string, + _ any, + _ http.Header, +) (*http.Request, error) { + return nil, f.err +} + +// errorHTTPClient always returns an error when making HTTP calls. +type errorHTTPClient struct{ err error } + +func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return nil, e.err +} + +func TestCallAudioAPIMultipartFormError(t *testing.T) { + client := NewClient("test-token") + errForm := errors.New("mock create form file failure") + // Override form builder to force an error during multipart form creation. + client.createFormBuilder = func(_ io.Writer) utils.FormBuilder { + return &failingFormBuilder{err: errForm} + } + + // Provide a reader so createFileField uses the reader path (no file open). + req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errForm) { + t.Errorf("expected error %v, got %v", errForm, err) + } +} + +func TestCallAudioAPINewRequestError(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errBuild := errors.New("mock build failure") + client.requestBuilder = &failingAudioRequestBuilder{err: errBuild} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errBuild) { + t.Errorf("expected error %v, got %v", errBuild, err) + } +} + +func TestCallAudioAPISendRequestErrorJSON(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + // Override HTTP client to simulate a network error. + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} + +func TestCallAudioAPISendRequestErrorText(t *testing.T) { + client := NewClient("test-token") + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + // Use a non-JSON response format to exercise the text path. + req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} diff --git a/config_test.go b/config_test.go index 145c26066..960230804 100644 --- a/config_test.go +++ b/config_test.go @@ -100,3 +100,24 @@ func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) } } + +func TestClientConfigString(t *testing.T) { + // String() should always return the constant value + conf := openai.DefaultConfig("dummy-token") + expected := "" + got := conf.String() + if got != expected { + t.Errorf("ClientConfig.String() = %q; want %q", got, expected) + } +} + +func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) { + // On a zero-value or DefaultConfig, AzureModelMapperFunc is nil, + // so GetAzureDeploymentByModel should just return the input model. + conf := openai.DefaultConfig("dummy-token") + model := "some-model" + got := conf.GetAzureDeploymentByModel(model) + if got != model { + t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model) + } +} From 77ccac8d342f7704b020e63b23cec2d6f009bff9 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 3 May 2025 22:39:47 +0100 Subject: [PATCH 212/242] Upgrade golangci-lint to 2.1.5 (#986) * Upgrade golangci-lint to 2.1.5 * update action --- .github/workflows/pr.yml | 4 +- .golangci.yml | 418 +++++++++++++++------------------------ 2 files changed, 166 insertions(+), 256 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index f4cbe7c8b..268e3259b 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -18,9 +18,9 @@ jobs: run: | go vet . - name: Run golangci-lint - uses: golangci/golangci-lint-action@v6 + uses: golangci/golangci-lint-action@v7 with: - version: v1.64.5 + version: v2.1.5 - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov diff --git a/.golangci.yml b/.golangci.yml index a5988825b..6391ad76f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,258 +1,168 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - mnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - +version: "2" linters: - disable-all: true + default: none enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - exhaustive # check exhaustiveness of enum switch statements - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - mnd # An analyzer to detect magic numbers. - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - usetesting # Reports uses of functions with replacement inside the testing package - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - - + - asciicheck + - bidichk + - bodyclose + - contextcheck + - cyclop + - dupl + - durationcheck + - errcheck + - errname + - errorlint + - exhaustive + - forbidigo + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - govet + - ineffassign + - lll + - makezero + - mnd + - nestif + - nilerr + - nilnil + - nolintlint + - nosprintfhostport + - predeclared + - promlinter + - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - testpackage + - tparallel + - unconvert + - unparam + - unused + - usetesting + - wastedassign + - whitespace + settings: + cyclop: + max-complexity: 30 + package-average: 10 + errcheck: + check-type-assertions: true + funlen: + lines: 100 + statements: 50 + gocognit: + min-complexity: 20 + gocritic: + settings: + captLocal: + paramsOnly: false + underef: + skipRecvDeref: false + gomodguard: + blocked: + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: satori's package is not maintained + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw' + govet: + disable: + - fieldalignment + enable-all: true + settings: + shadow: + strict: true + mnd: + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + nakedret: + max-func-lines: 0 + nolintlint: + require-explanation: true + require-specific: true + allow-no-explanation: + - funlen + - gocognit + - lll + rowserrcheck: + packages: + - github.com/jmoiron/sqlx + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - forbidigo + - mnd + - revive + path : ^examples/.*\.go$ + - linters: + - lll + source: ^//\s*go:generate\s + - linters: + - godot + source: (noinspection|TODO) + - linters: + - gocritic + source: //noinspection + - linters: + - errorlint + source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok { + - linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck + - staticcheck + path: _test\.go + paths: + - third_party$ + - builtin$ + - examples$ issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ From 6181facea7e6e5525b6b8da42205d7cce822c249 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sun, 4 May 2025 15:45:40 +0100 Subject: [PATCH 213/242] update codecov action, pass token (#987) --- .github/workflows/pr.yml | 4 +- .golangci.bck.yml | 258 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 .golangci.bck.yml diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 268e3259b..18c720f03 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -24,4 +24,6 @@ jobs: - name: Run tests run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.golangci.bck.yml b/.golangci.bck.yml new file mode 100644 index 000000000..a5988825b --- /dev/null +++ b/.golangci.bck.yml @@ -0,0 +1,258 @@ +## Golden config for golangci-lint v1.47.3 +# +# This is the best config for golangci-lint based on my experience and opinion. +# It is very strict, but not extremely strict. +# Feel free to adopt and change it for your needs. + +run: + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 3m + + +# This file contains only configs which differ from defaults. +# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml +linters-settings: + cyclop: + # The maximal code complexity to report. + # Default: 10 + max-complexity: 30 + # The maximal average package complexity. + # If it's higher than 0.0 (float) the check is enabled + # Default: 0.0 + package-average: 10.0 + + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + + funlen: + # Checks the number of lines in a function. + # If lower than 0, disable the check. + # Default: 60 + lines: 100 + # Checks the number of statements in a function. + # If lower than 0, disable the check. + # Default: 40 + statements: 50 + + gocognit: + # Minimal code complexity to report + # Default: 30 (but we recommend 10-20) + min-complexity: 20 + + gocritic: + # Settings passed to gocritic. + # The settings key is the name of a supported gocritic checker. + # The list of supported checkers can be find in https://go-critic.github.io/overview. + settings: + captLocal: + # Whether to restrict checker to params only. + # Default: true + paramsOnly: false + underef: + # Whether to skip (*x).method() calls where x is a pointer receiver. + # Default: true + skipRecvDeref: false + + mnd: + # List of function patterns to exclude from analysis. + # Values always ignored: `time.Date` + # Default: [] + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + + gomodguard: + blocked: + # List of blocked modules. + # Default: [] + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: "satori's package is not maintained" + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" + + govet: + # Enable all analyzers. + # Default: false + enable-all: true + # Disable analyzers by name. + # Run `go tool vet help` to see all analyzers. + # Default: [] + disable: + - fieldalignment # too strict + # Settings per analyzer. + settings: + shadow: + # Whether to be strict about shadowing; can be noisy. + # Default: false + strict: true + + nakedret: + # Make an issue if func has more lines of code than this setting, and it has naked returns. + # Default: 30 + max-func-lines: 0 + + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [ funlen, gocognit, lll ] + # Enable to require an explanation of nonzero length after each nolint directive. + # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + + rowserrcheck: + # database/sql is always checked + # Default: [] + packages: + - github.com/jmoiron/sqlx + + tenv: + # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. + # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. + # Default: false + all: true + + +linters: + disable-all: true + enable: + ## enabled by default + - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - gosimple # Linter for Go source code that specializes in simplifying a code + - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - ineffassign # Detects when assignments to existing variables are not used + - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks + - unused # Checks Go code for unused constants, variables, functions and types + ## disabled by default + # - asasalint # Check for pass []any as any in variadic func(...any) + - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - contextcheck # check the function whether use a non-inherited context + - cyclop # checks function and package cyclomatic complexity + - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together + - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. + - exhaustive # check exhaustiveness of enum switch statements + - forbidigo # Forbids identifiers + - funlen # Tool for detection of long functions + # - gochecknoglobals # check that no global variables exist + - gochecknoinits # Checks that no init functions are present in Go code + - gocognit # Computes and checks the cognitive complexity of functions + - goconst # Finds repeated strings that could be replaced by a constant + - gocritic # Provides diagnostics that check for bugs, performance and style issues. + - gocyclo # Computes and checks the cyclomatic complexity of functions + - godot # Check if comments end in a period + - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. + - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. + - goprintffuncname # Checks that printf-like functions are named with f at the end + - gosec # Inspects source code for security problems + - lll # Reports long lines + - makezero # Finds slice declarations with non-zero initial length + # - nakedret # Finds naked returns in functions greater than a specified function length + - mnd # An analyzer to detect magic numbers. + - nestif # Reports deeply nested if statements + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. + # - noctx # noctx finds sending http request without context.Context + - nolintlint # Reports ill-formed or insufficient nolint directives + # - nonamedreturns # Reports all named returns + - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. + - predeclared # find code that shadows one of Go's predeclared identifiers + - promlinter # Check Prometheus metrics naming via promlint + - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. + - rowserrcheck # checks whether Err of rows is checked successfully + - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. + - stylecheck # Stylecheck is a replacement for golint + - testpackage # linter that makes you use a separate _test package + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # Remove unnecessary type conversions + - unparam # Reports unused function parameters + - usetesting # Reports uses of functions with replacement inside the testing package + - wastedassign # wastedassign finds wasted assignment statements. + - whitespace # Tool for detection of leading and trailing whitespace + ## you may want to enable + #- decorder # check declaration order and count of types, constants, variables and functions + #- exhaustruct # Checks if all structure fields are initialized + #- goheader # Checks is file header matches to pattern + #- ireturn # Accept Interfaces, Return Concrete Types + #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated + #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope + #- wrapcheck # Checks that errors returned from external packages are wrapped + ## disabled + #- containedctx # containedctx is a linter that detects struct contained context.Context field + #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages + #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. + #- forcetypeassert # [replaced by errcheck] finds forced type assertions + #- gci # Gci controls golang package import order and makes it always deterministic. + #- godox # Tool for detection of FIXME, TODO and other comment keywords + #- goerr113 # [too strict] Golang linter to check the errors handling expressions + #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification + #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. + #- grouper # An analyzer to analyze expression groups. + #- ifshort # Checks that your code uses short syntax for if-statements whenever possible + #- importas # Enforces consistent import aliases + #- maintidx # maintidx measures the maintainability index of each function. + #- misspell # [useless] Finds commonly misspelled English words in comments + #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity + #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 + #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test + #- tagliatelle # Checks the struct tags. + #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! + + +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. + # Default: 3 + max-same-issues: 50 + + exclude-rules: + - source: "^//\\s*go:generate\\s" + linters: [ lll ] + - source: "(noinspection|TODO)" + linters: [ godot ] + - source: "//noinspection" + linters: [ gocritic ] + - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" + linters: [ errorlint ] + - path: "_test\\.go" + linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck From 8ba38f6ba16264760d3fd88892d02b25ef742c24 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 13 May 2025 12:44:16 +0100 Subject: [PATCH 214/242] remove backup file (#996) --- .golangci.bck.yml | 258 ---------------------------------------------- 1 file changed, 258 deletions(-) delete mode 100644 .golangci.bck.yml diff --git a/.golangci.bck.yml b/.golangci.bck.yml deleted file mode 100644 index a5988825b..000000000 --- a/.golangci.bck.yml +++ /dev/null @@ -1,258 +0,0 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - mnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - -linters: - disable-all: true - enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - exhaustive # check exhaustiveness of enum switch statements - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - mnd # An analyzer to detect magic numbers. - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - usetesting # Reports uses of functions with replacement inside the testing package - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - - -issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 - max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck From 0116f2994de0960ab65598f60e55d98b43e5fc2d Mon Sep 17 00:00:00 2001 From: Pedro Chaparro <94259578+PChaparro@users.noreply.github.com> Date: Tue, 13 May 2025 06:51:08 -0500 Subject: [PATCH 215/242] feat: add support for image generation using `gpt-image-1` (#971) * feat: add gpt-image-1 support * feat: add example to generate image using gpt-image-1 model * style: missing period in comments * feat: add missing fields to example * docs: add GPT Image 1 to README * revert: keep `examples/images/main.go` unchanged * docs: remove unnecessary newline from example in README file --- README.md | 62 +++++++++++++++++++++++++++++++++- examples/images/main.go | 2 +- image.go | 75 +++++++++++++++++++++++++++++++++++------ 3 files changed, 126 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 57d1d35bf..77b85e519 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op * ChatGPT 4o, o1 * GPT-3, GPT-4 -* DALL·E 2, DALL·E 3 +* DALL·E 2, DALL·E 3, GPT Image 1 * Whisper ## Installation @@ -357,6 +357,66 @@ func main() { ``` +
+GPT Image 1 image generation + +```go +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.ImageRequest{ + Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.", + Background: openai.CreateImageBackgroundOpaque, + Model: openai.CreateImageModelGptImage1, + Size: openai.CreateImageSize1024x1024, + N: 1, + Quality: openai.CreateImageQualityLow, + OutputCompression: 100, + OutputFormat: openai.CreateImageOutputFormatJPEG, + // Moderation: openai.CreateImageModerationLow, + // User: "", + } + + resp, err := c.CreateImage(ctx, req) + if err != nil { + fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err) + return + } + + fmt.Println("Image Base64:", resp.Data[0].B64JSON) + + // Decode the base64 data + imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) + if err != nil { + fmt.Printf("Base64 decode error: %v\n", err) + return + } + + // Write image to file + outputPath := "generated_image.jpg" + err = os.WriteFile(outputPath, imgBytes, 0644) + if err != nil { + fmt.Printf("Failed to write image file: %v\n", err) + return + } + + fmt.Printf("The image was saved as %s\n", outputPath) +} +``` +
+
Configuring proxy diff --git a/examples/images/main.go b/examples/images/main.go index 5ee649d22..2bfeb7973 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -25,4 +25,4 @@ func main() { return } fmt.Println(respUrl.Data[0].URL) -} +} \ No newline at end of file diff --git a/image.go b/image.go index 577d7db95..d62622a35 100644 --- a/image.go +++ b/image.go @@ -13,51 +13,101 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. CreateImageSize1792x1024 = "1792x1024" CreateImageSize1024x1792 = "1024x1792" + + // gpt-image-1 supported only. + CreateImageSize1536x1024 = "1536x1024" // Landscape + CreateImageSize1024x1536 = "1024x1536" // Portrait ) const ( - CreateImageResponseFormatURL = "url" + // dall-e-2 and dall-e-3 only. CreateImageResponseFormatB64JSON = "b64_json" + CreateImageResponseFormatURL = "url" ) const ( - CreateImageModelDallE2 = "dall-e-2" - CreateImageModelDallE3 = "dall-e-3" + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" + CreateImageModelGptImage1 = "gpt-image-1" ) const ( CreateImageQualityHD = "hd" CreateImageQualityStandard = "standard" + + // gpt-image-1 only. + CreateImageQualityHigh = "high" + CreateImageQualityMedium = "medium" + CreateImageQualityLow = "low" ) const ( + // dall-e-3 only. CreateImageStyleVivid = "vivid" CreateImageStyleNatural = "natural" ) +const ( + // gpt-image-1 only. + CreateImageBackgroundTransparent = "transparent" + CreateImageBackgroundOpaque = "opaque" +) + +const ( + // gpt-image-1 only. + CreateImageModerationLow = "low" +) + +const ( + // gpt-image-1 only. + CreateImageOutputFormatPNG = "png" + CreateImageOutputFormatJPEG = "jpeg" + CreateImageOutputFormatWEBP = "webp" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` - Style string `json:"style,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } // ImageResponse represents a response structure for image API. type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + Usage ImageResponseUsage `json:"usage,omitempty"` httpHeader } +// ImageResponseInputTokensDetails represents the token breakdown for input tokens. +type ImageResponseInputTokensDetails struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` +} + +// ImageResponseUsage represents the token usage information for image API. +type ImageResponseUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"` +} + // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { URL string `json:"url,omitempty"` @@ -91,6 +141,8 @@ type ImageEditRequest struct { N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -159,6 +211,7 @@ type ImageVariRequest struct { N int `json:"n,omitempty"` Size string `json:"size,omitempty"` ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. From 6aaa7322960741a84da11ac360516e4ec813dfff Mon Sep 17 00:00:00 2001 From: Justa Date: Tue, 13 May 2025 19:52:44 +0800 Subject: [PATCH 216/242] add ChatTemplateKwargs to ChatCompletionRequest (#980) Co-authored-by: Justa --- chat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chat.go b/chat.go index 0f91d481c..c8a3e81b3 100644 --- a/chat.go +++ b/chat.go @@ -275,6 +275,11 @@ type ChatCompletionRequest struct { Metadata map[string]string `json:"metadata,omitempty"` // Configuration for a predicted output. Prediction *Prediction `json:"prediction,omitempty"` + // ChatTemplateKwargs provides a way to add non-standard parameters to the request body. + // Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} + // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes + ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` } type StreamOptions struct { From 4d2e7ab29d7bf853c740e9e63187fed960e6b0b4 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Tue, 13 May 2025 12:59:06 +0100 Subject: [PATCH 217/242] fix lint (#998) --- examples/images/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/images/main.go b/examples/images/main.go index 2bfeb7973..5ee649d22 100644 --- a/examples/images/main.go +++ b/examples/images/main.go @@ -25,4 +25,4 @@ func main() { return } fmt.Println(respUrl.Data[0].URL) -} \ No newline at end of file +} From 8c65b35c57ad4e9ba408def9bf9ff97817aab932 Mon Sep 17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 20 May 2025 21:45:40 +0800 Subject: [PATCH 218/242] update image api *os.File to io.Reader (#994) * update image api *os.File to io.Reader * update code style * add reader test * supplementary reader test * update the reader in the form builder test * add commnet * update comment * update code style --- image.go | 43 ++++++++++++++++++----------------- image_test.go | 8 +++---- internal/form_builder.go | 35 ++++++++++++++++++++++++++-- internal/form_builder_test.go | 29 +++++++++++++++++++++++ 4 files changed, 88 insertions(+), 27 deletions(-) diff --git a/image.go b/image.go index d62622a35..72077ce41 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,8 @@ package openai import ( "bytes" "context" + "io" "net/http" - "os" "strconv" ) @@ -134,15 +134,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons // ImageEditRequest represents the request structure for the image API. type ImageEditRequest struct { - Image *os.File `json:"image,omitempty"` - Mask *os.File `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Quality string `json:"quality,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -150,15 +150,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.CreateFormFile("mask", request.Mask) + // mask, filename is not required + err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return } @@ -206,12 +207,12 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // ImageVariRequest represents the request structure for the image API. type ImageVariRequest struct { - Image *os.File `json:"image,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Image io.Reader `json:"image,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. @@ -220,8 +221,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename is not required + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } diff --git a/image_test.go b/image_test.go index 9332dd5cd..644005515 100644 --- a/image_test.go +++ b/image_test.go @@ -54,13 +54,13 @@ func TestImageFormBuilderFailures(t *testing.T) { } mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateEditImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { + mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { if name == "mask" { return mockFailedErr } @@ -119,13 +119,13 @@ func TestVariImageFormBuilderFailures(t *testing.T) { req := ImageVariRequest{} mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } _, err := client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { + mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } diff --git a/internal/form_builder.go b/internal/form_builder.go index 2224fad45..1c2513dd9 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -4,8 +4,10 @@ import ( "fmt" "io" "mime/multipart" + "net/textproto" "os" - "path" + "path/filepath" + "strings" ) type FormBuilder interface { @@ -30,8 +32,37 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er return fb.createFormFile(fieldname, file, file.Name()) } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// CreateFormFileReader creates a form field with a file reader. +// The filename in parameters can be an empty string. +// The filename in Content-Disposition is required, But it can be an empty string. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { - return fb.createFormFile(fieldname, r, path.Base(filename)) + h := make(textproto.MIMEHeader) + h.Set( + "Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + escapeQuotes(fieldname), + escapeQuotes(filepath.Base(filename)), + ), + ) + + fieldWriter, err := fb.writer.CreatePart(h) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil } func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 8df989e3b..76922c1ba 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -43,3 +43,32 @@ func TestFormBuilderWithClosedFile(t *testing.T) { checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } + +type failingReader struct { +} + +var errMockFailingReaderError = errors.New("mock reader failed") + +func (*failingReader) Read([]byte) (int, error) { + return 0, errMockFailingReaderError +} + +func TestFormBuilderWithReader(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFileReader("file", file, file.Name()) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") + + builder = NewFormBuilder(&bytes.Buffer{}) + reader := &failingReader{} + err = builder.CreateFormFileReader("file", reader, "") + checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails") + + successReader := &bytes.Buffer{} + err = builder.CreateFormFileReader("file", successReader, "") + checks.NoError(t, err, "formbuilder should not return error") +} From ff9d83a4854790ecbd16e6328415a32c7497efaf Mon Sep 17 00:00:00 2001 From: "JT A." Date: Thu, 29 May 2025 04:31:35 -0600 Subject: [PATCH 219/242] skip json field (#1009) * skip json field * backfill some coverage and tests --- jsonschema/json.go | 7 ++++-- jsonschema/json_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index d458418f3..03bb68891 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -126,9 +126,12 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { } jsonTag := field.Tag.Get("json") var required = true - if jsonTag == "" { + switch { + case jsonTag == "-": + continue + case jsonTag == "": jsonTag = field.Name - } else if strings.HasSuffix(jsonTag, ",omitempty") { + case strings.HasSuffix(jsonTag, ",omitempty"): jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") required = false } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 17f0aba8a..84f25fa85 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -329,6 +329,53 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with exclude mark", + in: struct { + Name string `json:"-"` + }{ + Name: "Name", + }, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with no json tag", + in: struct { + Name string + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "Name":{ + "type":"string" + } + }, + "required":["Name"], + "additionalProperties":false + }`, + }, + { + name: "Test with omitempty tag", + in: struct { + Name string `json:"name,omitempty"` + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, } for _, tt := range tests { From d7dca83beda99528c974392a8295630775c1c197 Mon Sep 17 00:00:00 2001 From: Axb12 <67110563+Axb12@users.noreply.github.com> Date: Tue, 17 Jun 2025 03:08:14 +0800 Subject: [PATCH 220/242] fix image api missing filename bug (#1017) * fix image api missing filename bug * add test * add test * update test --- image.go | 32 +++++++++++++++++++++++++++++--- internal/form_builder.go | 17 +++++++++++++++-- internal/form_builder_test.go | 18 ++++++++++++++++++ 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/image.go b/image.go index 72077ce41..84b9daf02 100644 --- a/image.go +++ b/image.go @@ -132,7 +132,32 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons return } +// WrapReader wraps an io.Reader with filename and Content-type. +func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Name() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} + // ImageEditRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageEditRequest struct { Image io.Reader `json:"image,omitempty"` Mask io.Reader `json:"mask,omitempty"` @@ -150,7 +175,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return @@ -158,7 +183,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) // mask, it is optional if request.Mask != nil { - // mask, filename is not required + // filename verification can be postponed err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return @@ -206,6 +231,7 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) } // ImageVariRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageVariRequest struct { Image io.Reader `json:"image,omitempty"` Model string `json:"model,omitempty"` @@ -221,7 +247,7 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required + // image, filename verification can be postponed err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return diff --git a/internal/form_builder.go b/internal/form_builder.go index 1c2513dd9..5b382df20 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -39,9 +39,18 @@ func escapeQuotes(s string) string { } // CreateFormFileReader creates a form field with a file reader. -// The filename in parameters can be an empty string. -// The filename in Content-Disposition is required, But it can be an empty string. +// The filename in Content-Disposition is required. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { + if filename == "" { + if f, ok := r.(interface{ Name() string }); ok { + filename = f.Name() + } + } + var contentType string + if f, ok := r.(interface{ ContentType() string }); ok { + contentType = f.ContentType() + } + h := make(textproto.MIMEHeader) h.Set( "Content-Disposition", @@ -51,6 +60,10 @@ func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader escapeQuotes(filepath.Base(filename)), ), ) + // content type is optional, but it can be set + if contentType != "" { + h.Set("Content-Type", contentType) + } fieldWriter, err := fb.writer.CreatePart(h) if err != nil { diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 76922c1ba..f4958ad5e 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,6 +1,8 @@ package openai //nolint:testpackage // testing private field import ( + "io" + "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" @@ -53,6 +55,18 @@ func (*failingReader) Read([]byte) (int, error) { return 0, errMockFailingReaderError } +type readerWithNameAndContentType struct { + io.Reader +} + +func (*readerWithNameAndContentType) Name() string { + return "" +} + +func (*readerWithNameAndContentType) ContentType() string { + return "image/png" +} + func TestFormBuilderWithReader(t *testing.T) { file, err := os.CreateTemp(t.TempDir(), "") if err != nil { @@ -71,4 +85,8 @@ func TestFormBuilderWithReader(t *testing.T) { successReader := &bytes.Buffer{} err = builder.CreateFormFileReader("file", successReader, "") checks.NoError(t, err, "formbuilder should not return error") + + rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}} + err = builder.CreateFormFileReader("file", rnc, "") + checks.NoError(t, err, "formbuilder should not return error") } From a931bf7e85af9d39af414de21d907b6835281519 Mon Sep 17 00:00:00 2001 From: Whitea <145814986+Whitea029@users.noreply.github.com> Date: Tue, 17 Jun 2025 18:00:15 +0800 Subject: [PATCH 221/242] test: enhance error accumulator and form builder tests, add marshaller tests (#999) * test: enhance error accumulator and form builder tests, add marshaller tests * test: fix some issue form golangci-lint * test: gofmt form builder test * fix * fix * fix lint --- internal/error_accumulator_test.go | 43 +++++++----------- internal/form_builder.go | 3 ++ internal/form_builder_test.go | 73 +++++++++++++++++++++++++++++- internal/marshaller_test.go | 34 ++++++++++++++ internal/unmarshaler_test.go | 37 +++++++++++++++ 5 files changed, 163 insertions(+), 27 deletions(-) create mode 100644 internal/marshaller_test.go create mode 100644 internal/unmarshaler_test.go diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index d48f28177..3fa9d7714 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -1,41 +1,32 @@ package openai_test import ( - "bytes" - "errors" "testing" - utils "github.com/sashabaranov/go-openai/internal" - "github.com/sashabaranov/go-openai/internal/test" + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" ) -func TestErrorAccumulatorBytes(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &bytes.Buffer{}, +func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } + checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}"))) + checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}"))) - errBytes := accumulator.Bytes() - if len(errBytes) != 0 { - t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) - } - - err := accumulator.Write([]byte("{}")) - if err != nil { - t.Fatalf("%+v", err) - } - - errBytes = accumulator.Bytes() - if len(errBytes) == 0 { - t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}" + if string(ea.Bytes()) != expected { + t.Fatalf("Expected %q, got %q", expected, ea.Bytes()) } } -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, +func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } - err := accumulator.Write([]byte("{")) - if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) + if len(ea.Bytes()) != 0 { + t.Fatal("Buffer should be empty initially") } } diff --git a/internal/form_builder.go b/internal/form_builder.go index 5b382df20..a17e820c0 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -97,6 +97,9 @@ func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, file } func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + if fieldname == "" { + return fmt.Errorf("fieldname cannot be empty") + } return fb.writer.WriteField(fieldname, value) } diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index f4958ad5e..ddd6b8448 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,16 +1,57 @@ package openai //nolint:testpackage // testing private field import ( + "errors" "io" "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" - "errors" "os" "testing" ) +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockWriteField func(string, string) error + mockClose func() error +} + +func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return m.mockCreateFormFile(fieldname, file) +} + +func (m *mockFormBuilder) WriteField(fieldname, value string) error { + return m.mockWriteField(fieldname, value) +} + +func (m *mockFormBuilder) Close() error { + return m.mockClose() +} + +func (m *mockFormBuilder) FormDataContentType() string { + return "" +} + +func TestCloseMethod(t *testing.T) { + t.Run("NormalClose", func(t *testing.T) { + body := &bytes.Buffer{} + builder := NewFormBuilder(body) + checks.NoError(t, builder.Close(), "正常关闭应成功") + }) + + t.Run("ErrorPropagation", func(t *testing.T) { + errorMock := errors.New("mock close error") + mockBuilder := &mockFormBuilder{ + mockClose: func() error { + return errorMock + }, + } + err := mockBuilder.Close() + checks.ErrorIs(t, err, errorMock, "应传递关闭错误") + }) +} + type failingWriter struct { } @@ -90,3 +131,33 @@ func TestFormBuilderWithReader(t *testing.T) { err = builder.CreateFormFileReader("file", rnc, "") checks.NoError(t, err, "formbuilder should not return error") } + +func TestFormDataContentType(t *testing.T) { + t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + contentType := builder.FormDataContentType() + if contentType == "" { + t.Errorf("expected non-empty content type, got empty string") + } + }) +} + +func TestWriteField(t *testing.T) { + t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("", "some value") + checks.HasError(t, err, "fieldname is required") + }) + + t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("key", "value") + checks.NoError(t, err, "should write field without error") + }) +} diff --git a/internal/marshaller_test.go b/internal/marshaller_test.go new file mode 100644 index 000000000..70694faed --- /dev/null +++ b/internal/marshaller_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONMarshaller_Normal(t *testing.T) { + jm := &openai.JSONMarshaller{} + data := map[string]string{"key": "value"} + + b, err := jm.Marshal(data) + checks.NoError(t, err) + if len(b) == 0 { + t.Fatal("should return non-empty bytes") + } +} + +func TestJSONMarshaller_InvalidInput(t *testing.T) { + jm := &openai.JSONMarshaller{} + _, err := jm.Marshal(make(chan int)) + checks.HasError(t, err, "should return error for unsupported type") +} + +func TestJSONMarshaller_EmptyValue(t *testing.T) { + jm := &openai.JSONMarshaller{} + b, err := jm.Marshal(nil) + checks.NoError(t, err) + if string(b) != "null" { + t.Fatalf("unexpected marshaled value: %s", string(b)) + } +} diff --git a/internal/unmarshaler_test.go b/internal/unmarshaler_test.go new file mode 100644 index 000000000..d63eac779 --- /dev/null +++ b/internal/unmarshaler_test.go @@ -0,0 +1,37 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONUnmarshaler_Normal(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{"key":"value"}`) + var v map[string]string + + err := jm.Unmarshal(data, &v) + checks.NoError(t, err) + if v["key"] != "value" { + t.Fatal("unmarshal result mismatch") + } +} + +func TestJSONUnmarshaler_InvalidJSON(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{invalid}`) + var v map[string]interface{} + + err := jm.Unmarshal(data, &v) + checks.HasError(t, err, "should return error for invalid JSON") +} + +func TestJSONUnmarshaler_EmptyInput(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + var v interface{} + + err := jm.Unmarshal(nil, &v) + checks.HasError(t, err, "should return error for nil input") +} From c125ae2ad7b239355161c7f4260a3743ec0c182b Mon Sep 17 00:00:00 2001 From: Hritik Raj Date: Wed, 25 Jun 2025 15:34:45 +0530 Subject: [PATCH 222/242] Fix for removing usage in every stream chunk response. (#1022) * Fix for https://github.com/sashabaranov/go-openai/issues/1021: 1. Make Usage field in completions Response to pointer. * Fix for https://github.com/sashabaranov/go-openai/issues/1021: 1. Make Usage field in completions Response to pointer. 2. Add omitempty to json tag Signed-off-by: Hritik003 --------- Signed-off-by: Hritik003 --- completion.go | 2 +- completion_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/completion.go b/completion.go index 21d4897c4..02ce7b016 100644 --- a/completion.go +++ b/completion.go @@ -242,7 +242,7 @@ type CompletionResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + Usage *Usage `json:"usage,omitempty"` httpHeader } diff --git a/completion_test.go b/completion_test.go index 27e2d150e..f0ead0d63 100644 --- a/completion_test.go +++ b/completion_test.go @@ -192,7 +192,7 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } inputTokens *= n completionTokens := completionReq.MaxTokens * len(prompts) * n - res.Usage = openai.Usage{ + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, From 8e9b2ac83ab2d15a34982d689d97ec5db34cbb06 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Tue, 8 Jul 2025 19:19:49 +0800 Subject: [PATCH 223/242] fix: properly unmarshal JSON schema in ChatCompletionResponseFormatJSONSchema.schema (#1028) * feat: #1027 * add tests * feat: #1027 * feat: #1027 * feat: #1027 * update chat_test.go * feat: #1027 * add test cases --- chat.go | 27 ++++++++++ chat_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) diff --git a/chat.go b/chat.go index c8a3e81b3..b4a0ad90f 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -221,6 +223,31 @@ type ChatCompletionResponseFormatJSONSchema struct { Strict bool `json:"strict"` } +func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error { + type rawJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.RawMessage `json:"schema"` + Strict bool `json:"strict"` + } + var raw rawJSONSchema + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Name = raw.Name + r.Description = raw.Description + r.Strict = raw.Strict + if len(raw.Schema) > 0 && string(raw.Schema) != "null" { + var d jsonschema.Definition + err := json.Unmarshal(raw.Schema, &d) + if err != nil { + return err + } + r.Schema = &d + } + return nil +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model"` diff --git a/chat_test.go b/chat_test.go index 514706c96..172ce0740 100644 --- a/chat_test.go +++ b/chat_test.go @@ -946,3 +946,142 @@ func TestFinishReason(t *testing.T) { } } } + +func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": null + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`[123,456]`), + }, + true, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": 123456 + }`), + }, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r openai.ChatCompletionResponseFormatJSONSchema + err := r.UnmarshalJSON(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { + type args struct { + bs []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{bs: []byte(`{ + "model": "llama3-1b", + "messages": [ + { "role": "system", "content": "You are a helpful math tutor." }, + { "role": "user", "content": "solve 8x + 31 = 2" } + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + } + } +}`)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m openai.ChatCompletionRequest + err := json.Unmarshal(tt.args.bs, &m) + if err != nil { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} From bd612cebceb5d84c423bece9e2f2766c7567f8ed Mon Sep 17 00:00:00 2001 From: Matt Tinsley <68241446+mathewtinsley@users.noreply.github.com> Date: Tue, 8 Jul 2025 04:23:41 -0700 Subject: [PATCH 224/242] Add support for Chat Completion Service Tier (#1023) * Add support for Chat Completion Service Tier * Add priority service tier --- chat.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chat.go b/chat.go index b4a0ad90f..0f0c5b5d5 100644 --- a/chat.go +++ b/chat.go @@ -307,6 +307,8 @@ type ChatCompletionRequest struct { // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Specifies the latency tier to use for processing the request. + ServiceTier ServiceTier `json:"service_tier,omitempty"` } type StreamOptions struct { @@ -390,6 +392,15 @@ const ( FinishReasonNull FinishReason = "null" ) +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierPriority ServiceTier = "priority" +) + func (r FinishReason) MarshalJSON() ([]byte, error) { if r == FinishReasonNull || r == "" { return []byte("null"), nil @@ -422,6 +433,7 @@ type ChatCompletionResponse struct { Usage Usage `json:"usage"` SystemFingerprint string `json:"system_fingerprint"` PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` httpHeader } From c650976e492731e3fa758a2dbff2f743ae0e1c98 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 10 Jul 2025 18:35:16 +0800 Subject: [PATCH 225/242] Support $ref and $defs in JSON Schema (#1030) * support $ref and $defs in JSON Schema * update --- jsonschema/json.go | 36 +++++-- jsonschema/json_test.go | 70 ++++++++++++++ jsonschema/validate.go | 65 +++++++++++-- jsonschema/validate_test.go | 187 +++++++++++++++++++++++++++++++++++- 4 files changed, 340 insertions(+), 18 deletions(-) diff --git a/jsonschema/json.go b/jsonschema/json.go index 03bb68891..29d15b409 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -48,6 +48,11 @@ type Definition struct { AdditionalProperties any `json:"additionalProperties,omitempty"` // Whether the schema is nullable or not. Nullable bool `json:"nullable,omitempty"` + + // Ref Reference to a definition in $defs or external schema. + Ref string `json:"$ref,omitempty"` + // Defs A map of reusable schema definitions. + Defs map[string]Definition `json:"$defs,omitempty"` } func (d *Definition) MarshalJSON() ([]byte, error) { @@ -67,10 +72,16 @@ func (d *Definition) Unmarshal(content string, v any) error { } func GenerateSchemaForType(v any) (*Definition, error) { - return reflectSchema(reflect.TypeOf(v)) + var defs = make(map[string]Definition) + def, err := reflectSchema(reflect.TypeOf(v), defs) + if err != nil { + return nil, err + } + def.Defs = defs + return def, nil } -func reflectSchema(t reflect.Type) (*Definition, error) { +func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d Definition switch t.Kind() { case reflect.String: @@ -84,21 +95,32 @@ func reflectSchema(t reflect.Type) (*Definition, error) { d.Type = Boolean case reflect.Slice, reflect.Array: d.Type = Array - items, err := reflectSchema(t.Elem()) + items, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } d.Items = items case reflect.Struct: + if t.Name() != "" { + if _, ok := defs[t.Name()]; !ok { + defs[t.Name()] = Definition{} + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + defs[t.Name()] = *object + } + return &Definition{Ref: "#/$defs/" + t.Name()}, nil + } d.Type = Object d.AdditionalProperties = false - object, err := reflectSchemaObject(t) + object, err := reflectSchemaObject(t, defs) if err != nil { return nil, err } d = *object case reflect.Ptr: - definition, err := reflectSchema(t.Elem()) + definition, err := reflectSchema(t.Elem(), defs) if err != nil { return nil, err } @@ -112,7 +134,7 @@ func reflectSchema(t reflect.Type) (*Definition, error) { return &d, nil } -func reflectSchemaObject(t reflect.Type) (*Definition, error) { +func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) { var d = Definition{ Type: Object, AdditionalProperties: false, @@ -136,7 +158,7 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) { required = false } - item, err := reflectSchema(field.Type) + item, err := reflectSchema(field.Type, defs) if err != nil { return nil, err } diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 84f25fa85..31b54ed1a 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -183,6 +183,17 @@ func TestDefinition_MarshalJSON(t *testing.T) { } func TestStructToSchema(t *testing.T) { + type Tweet struct { + Text string `json:"text"` + } + + type Person struct { + Name string `json:"name,omitempty"` + Age int `json:"age,omitempty"` + Friends []Person `json:"friends,omitempty"` + Tweets []Tweet `json:"tweets,omitempty"` + } + tests := []struct { name string in any @@ -376,6 +387,65 @@ func TestStructToSchema(t *testing.T) { "additionalProperties":false }`, }, + { + name: "Test with $ref and $defs", + in: struct { + Person Person `json:"person"` + Tweets []Tweet `json:"tweets"` + }{}, + want: `{ + "type" : "object", + "properties" : { + "person" : { + "$ref" : "#/$defs/Person" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "required" : [ "person", "tweets" ], + "additionalProperties" : false, + "$defs" : { + "Person" : { + "type" : "object", + "properties" : { + "age" : { + "type" : "integer" + }, + "friends" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Person" + } + }, + "name" : { + "type" : "string" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "additionalProperties" : false + }, + "Tweet" : { + "type" : "object", + "properties" : { + "text" : { + "type" : "string" + } + }, + "required" : [ "text" ], + "additionalProperties" : false + } + } +}`, + }, } for _, tt := range tests { diff --git a/jsonschema/validate.go b/jsonschema/validate.go index 49f9b8859..1bd2f809c 100644 --- a/jsonschema/validate.go +++ b/jsonschema/validate.go @@ -5,26 +5,68 @@ import ( "errors" ) +func CollectDefs(def Definition) map[string]Definition { + result := make(map[string]Definition) + collectDefsRecursive(def, result, "#") + return result +} + +func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) { + for k, v := range def.Defs { + path := prefix + "/$defs/" + k + result[path] = v + collectDefsRecursive(v, result, path) + } + for k, sub := range def.Properties { + collectDefsRecursive(sub, result, prefix+"/properties/"+k) + } + if def.Items != nil { + collectDefsRecursive(*def.Items, result, prefix) + } +} + func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { var data any err := json.Unmarshal(content, &data) if err != nil { return err } - if !Validate(schema, data) { + if !Validate(schema, data, WithDefs(CollectDefs(schema))) { return errors.New("data validation failed against the provided schema") } return json.Unmarshal(content, &v) } -func Validate(schema Definition, data any) bool { +type validateArgs struct { + Defs map[string]Definition +} + +type ValidateOption func(*validateArgs) + +func WithDefs(defs map[string]Definition) ValidateOption { + return func(option *validateArgs) { + option.Defs = defs + } +} + +func Validate(schema Definition, data any, opts ...ValidateOption) bool { + args := validateArgs{} + for _, opt := range opts { + opt(&args) + } + if len(opts) == 0 { + args.Defs = CollectDefs(schema) + } switch schema.Type { case Object: - return validateObject(schema, data) + return validateObject(schema, data, args.Defs) case Array: - return validateArray(schema, data) + return validateArray(schema, data, args.Defs) case String: - _, ok := data.(string) + v, ok := data.(string) + if ok && len(schema.Enum) > 0 { + return contains(schema.Enum, v) + } return ok case Number: // float64 and int _, ok := data.(float64) @@ -45,11 +87,16 @@ func Validate(schema Definition, data any) bool { case Null: return data == nil default: + if schema.Ref != "" && args.Defs != nil { + if v, ok := args.Defs[schema.Ref]; ok { + return Validate(v, data, WithDefs(args.Defs)) + } + } return false } } -func validateObject(schema Definition, data any) bool { +func validateObject(schema Definition, data any, defs map[string]Definition) bool { dataMap, ok := data.(map[string]any) if !ok { return false @@ -61,7 +108,7 @@ func validateObject(schema Definition, data any) bool { } for key, valueSchema := range schema.Properties { value, exists := dataMap[key] - if exists && !Validate(valueSchema, value) { + if exists && !Validate(valueSchema, value, WithDefs(defs)) { return false } else if !exists && contains(schema.Required, key) { return false @@ -70,13 +117,13 @@ func validateObject(schema Definition, data any) bool { return true } -func validateArray(schema Definition, data any) bool { +func validateArray(schema Definition, data any, defs map[string]Definition) bool { dataArray, ok := data.([]any) if !ok { return false } for _, item := range dataArray { - if !Validate(*schema.Items, item) { + if !Validate(*schema.Items, item, WithDefs(defs)) { return false } } diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go index 6fa30ab0c..aefdf4069 100644 --- a/jsonschema/validate_test.go +++ b/jsonschema/validate_test.go @@ -1,6 +1,7 @@ package jsonschema_test import ( + "reflect" "testing" "github.com/sashabaranov/go-openai/jsonschema" @@ -70,6 +71,96 @@ func Test_Validate(t *testing.T) { }, Required: []string{"string"}, }}, false}, + { + "test schema with ref and defs", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "male", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, true}, + { + "test enum invalid value", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "other", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -156,8 +247,100 @@ func TestUnmarshal(t *testing.T) { err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) if (err != nil) != tt.wantErr { t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) - } else if err == nil { - t.Logf("Unmarshal() v = %+v\n", tt.args.v) + } + }) + } +} + +func TestCollectDefs(t *testing.T) { + type args struct { + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want map[string]jsonschema.Definition + }{ + { + "test collect defs", + args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + }, + map[string]jsonschema.Definition{ + "#/$defs/Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "#/$defs/Person/$defs/Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + "#/$defs/Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := jsonschema.CollectDefs(tt.args.schema) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CollectDefs() = %v, want %v", got, tt.want) } }) } From 1bf77f6fd6f10b9cd80e1dfdedd616f3f1712e20 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 10 Jul 2025 21:49:20 +0100 Subject: [PATCH 226/242] Improve unit test coverage (#1032) * Add unit tests to improve coverage * Fix type assertion checks in tests --- image_test.go | 50 ++++++++++++++++++++++++++++++ internal/error_accumulator_test.go | 7 +++++ internal/form_builder_test.go | 14 +++++++++ internal/request_builder_test.go | 27 ++++++++++++++++ 4 files changed, 98 insertions(+) diff --git a/image_test.go b/image_test.go index 644005515..bb9a086fd 100644 --- a/image_test.go +++ b/image_test.go @@ -4,6 +4,7 @@ import ( utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test/checks" + "bytes" "context" "fmt" "io" @@ -156,3 +157,52 @@ func TestVariImageFormBuilderFailures(t *testing.T) { _, err = client.CreateVariImage(ctx, req) checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") } + +type testNamedReader struct{ io.Reader } + +func (testNamedReader) Name() string { return "named.txt" } + +func TestWrapReader(t *testing.T) { + r := bytes.NewBufferString("data") + wrapped := WrapReader(r, "file.png", "image/png") + f, ok := wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped reader missing Name or ContentType") + } + if f.Name() != "file.png" { + t.Fatalf("expected name file.png, got %s", f.Name()) + } + if f.ContentType() != "image/png" { + t.Fatalf("expected content type image/png, got %s", f.ContentType()) + } + + // test name from underlying reader + nr := testNamedReader{Reader: bytes.NewBufferString("d")} + wrapped = WrapReader(nr, "", "text/plain") + f, ok = wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped named reader missing Name or ContentType") + } + if f.Name() != "named.txt" { + t.Fatalf("expected name named.txt, got %s", f.Name()) + } + if f.ContentType() != "text/plain" { + t.Fatalf("expected content type text/plain, got %s", f.ContentType()) + } + + // no name provided + wrapped = WrapReader(bytes.NewBuffer(nil), "", "") + f2, ok := wrapped.(interface{ Name() string }) + if !ok { + t.Fatal("wrapped anonymous reader missing Name") + } + if f2.Name() != "" { + t.Fatalf("expected empty name, got %s", f2.Name()) + } +} diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index 3fa9d7714..f6c226c5e 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -4,6 +4,7 @@ import ( "testing" openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -30,3 +31,9 @@ func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { t.Fatal("Buffer should be empty initially") } } + +func TestDefaultErrorAccumulator_WriteError(t *testing.T) { + ea := &openai.DefaultErrorAccumulator{Buffer: &test.FailingErrorBuffer{}} + err := ea.Write([]byte("fail")) + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Write should propagate buffer errors") +} diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index ddd6b8448..1cc82ab8a 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -161,3 +161,17 @@ func TestWriteField(t *testing.T) { checks.NoError(t, err, "should write field without error") }) } + +func TestCreateFormFile(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "") + if err == nil { + t.Fatal("expected error for empty filename") + } + + builder = NewFormBuilder(&failingWriter{}) + err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") + checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e26022a6b..1561b87fb 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "io" "net/http" "reflect" "testing" @@ -59,3 +60,29 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { t.Errorf("Build() got = %v, want %v", got, want) } } + +func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { + b := NewRequestBuilder() + ctx := context.Background() + method := http.MethodPost + url := "/reader" + bodyContent := "hello" + body := bytes.NewBufferString(bodyContent) + header := http.Header{"X-Test": []string{"val"}} + + req, err := b.Build(ctx, method, url, body, header) + if err != nil { + t.Fatalf("Build returned error: %v", err) + } + + gotBody, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("cannot read body: %v", err) + } + if string(gotBody) != bodyContent { + t.Fatalf("expected body %q, got %q", bodyContent, string(gotBody)) + } + if req.Header.Get("X-Test") != "val" { + t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) + } +} From a0c185f3628cd616fb9d1648700fc51c40c42dda Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Fri, 11 Jul 2025 19:17:33 +0800 Subject: [PATCH 227/242] Removed root $ref from GenerateSchemaForType (#1033) * support $ref and $defs in JSON Schema * update * removed root $ref from JSON Schema * Update json.go * Update json_test.go * Update jsonschema/json.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update jsonschema/json.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- jsonschema/json.go | 44 ++++++++++ jsonschema/json_test.go | 174 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 218 insertions(+) diff --git a/jsonschema/json.go b/jsonschema/json.go index 29d15b409..75e3b5173 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -77,6 +77,27 @@ func GenerateSchemaForType(v any) (*Definition, error) { if err != nil { return nil, err } + // If the schema has a root $ref, resolve it by: + // 1. Extracting the key from the $ref. + // 2. Detaching the referenced definition from $defs. + // 3. Checking for self-references in the detached definition. + // - If a self-reference is found, restore the original $defs structure. + // 4. Flattening the referenced definition into the root schema. + // 5. Clearing the $ref field in the root schema. + if def.Ref != "" { + origRef := def.Ref + key := strings.TrimPrefix(origRef, "#/$defs/") + if root, ok := defs[key]; ok { + delete(defs, key) + root.Defs = defs + if containsRef(root, origRef) { + root.Defs = nil + defs[key] = root + } + *def = root + } + def.Ref = "" + } def.Defs = defs return def, nil } @@ -189,3 +210,26 @@ func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definitio d.Properties = properties return &d, nil } + +func containsRef(def Definition, targetRef string) bool { + if def.Ref == targetRef { + return true + } + + for _, d := range def.Defs { + if containsRef(d, targetRef) { + return true + } + } + + for _, prop := range def.Properties { + if containsRef(prop, targetRef) { + return true + } + } + + if def.Items != nil && containsRef(*def.Items, targetRef) { + return true + } + return false +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 31b54ed1a..34f5d88eb 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -182,6 +182,18 @@ func TestDefinition_MarshalJSON(t *testing.T) { } } +type User struct { + ID int `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Orders []*Order `json:"orders,omitempty"` +} + +type Order struct { + ID int `json:"id,omitempty"` + Amount float64 `json:"amount,omitempty"` + Buyer *User `json:"buyer,omitempty"` +} + func TestStructToSchema(t *testing.T) { type Tweet struct { Text string `json:"text"` @@ -194,6 +206,13 @@ func TestStructToSchema(t *testing.T) { Tweets []Tweet `json:"tweets,omitempty"` } + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + tests := []struct { name string in any @@ -444,6 +463,161 @@ func TestStructToSchema(t *testing.T) { "additionalProperties" : false } } +}`, + }, + { + name: "Test Person", + in: Person{}, + want: `{ + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false, + "$defs": { + "Person": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false + }, + "Tweet": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "required": [ + "text" + ], + "additionalProperties": false + } + } +}`, + }, + { + name: "Test MyStructuredResponse", + in: MyStructuredResponse{}, + want: `{ + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "kebab_case", + "snake_case" + ], + "additionalProperties": false +}`, + }, + { + name: "Test User", + in: User{}, + want: `{ + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false, + "$defs": { + "Order": { + "type": "object", + "properties": { + "amount": { + "type": "number" + }, + "buyer": { + "$ref": "#/$defs/User" + }, + "id": { + "type": "integer" + } + }, + "additionalProperties": false + }, + "User": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false + } + } }`, }, } From 575aff439644c3b6460f2617b81b119044bfc35c Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 13:15:15 +0100 Subject: [PATCH 228/242] Add tests for form and request builders (#1036) --- internal/form_builder_test.go | 13 +++++++++++++ internal/request_builder_test.go | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index 1cc82ab8a..53ef11d23 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -8,6 +8,7 @@ import ( "bytes" "os" + "strings" "testing" ) @@ -175,3 +176,15 @@ func TestCreateFormFile(t *testing.T) { err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") } + +func TestCreateFormFileSuccess(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "foo.txt") + checks.NoError(t, err, "createFormFile should succeed") + + if !strings.Contains(buf.String(), "filename=\"foo.txt\"") { + t.Fatalf("expected filename header, got %q", buf.String()) + } +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index 1561b87fb..adccb158e 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -86,3 +86,11 @@ func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) } } + +func TestRequestBuilderInvalidURL(t *testing.T) { + b := NewRequestBuilder() + _, err := b.Build(context.Background(), http.MethodGet, ":", nil, nil) + if err == nil { + t.Fatal("expected error for invalid URL") + } +} From 88eb1df90bb6af6a1f4d471a08d9f8a13dd5a8e5 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:58:52 +0100 Subject: [PATCH 229/242] Ignore codecov coverage for examples and internal/test (#1038) --- .codecov.yml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 .codecov.yml diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..81773666c --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +coverage: + ignore: + - "examples/**" + - "internal/test/**" From 8d681e7f9a8f172168199f29e8e1f16701d6817a Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:18:23 +0100 Subject: [PATCH 230/242] Increase image.go test coverage to 100% (#1039) --- image_test.go | 303 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 209 insertions(+), 94 deletions(-) diff --git a/image_test.go b/image_test.go index bb9a086fd..c2c8f42dc 100644 --- a/image_test.go +++ b/image_test.go @@ -40,122 +40,237 @@ func (fb *mockFormBuilder) FormDataContentType() string { } func TestImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageEditRequest{ - Mask: &os.File{}, - } - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return mockFailedErr - } - _, err := client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { - if name == "mask" { - return mockFailedErr - } - return nil + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageEditRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "mask", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { + if name == "mask" { + return mockFailedErr + } + return nil + } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "prompt", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "prompt" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateEditImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") + }) } - failForField = "prompt" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "n" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "response_format" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateEditImage(ctx, ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateEditImage should return error if request builder fails") + }) } func TestVariImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageVariRequest{} - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return mockFailedErr - } - _, err := client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFileReader = func(string, io.Reader, string) error { - return nil + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageVariRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, } - failForField = "n" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateVariImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + }) + } - failForField = "response_format" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateVariImage(ctx, ImageVariRequest{Image: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateVariImage should return error if request builder fails") + }) } type testNamedReader struct{ io.Reader } From 8665ad7264b64cd2046c2a2d4addf8617c6fffea Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:31:09 +0100 Subject: [PATCH 231/242] Increase jsonschema test coverage (#1040) --- jsonschema/json_additional_test.go | 73 ++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 jsonschema/json_additional_test.go diff --git a/jsonschema/json_additional_test.go b/jsonschema/json_additional_test.go new file mode 100644 index 000000000..70cf37490 --- /dev/null +++ b/jsonschema/json_additional_test.go @@ -0,0 +1,73 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// Test Definition.Unmarshal, including success path, validation error, +// JSON syntax error and type mismatch during unmarshalling. +func TestDefinitionUnmarshal(t *testing.T) { + schema := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + }, + } + + var dst struct { + Name string `json:"name"` + } + if err := schema.Unmarshal(`{"name":"foo"}`, &dst); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Name != "foo" { + t.Errorf("expected name to be foo, got %q", dst.Name) + } + + if err := schema.Unmarshal(`{`, &dst); err == nil { + t.Error("expected error for malformed json") + } + + if err := schema.Unmarshal(`{"name":1}`, &dst); err == nil { + t.Error("expected validation error") + } + + numSchema := jsonschema.Definition{Type: jsonschema.Number} + var s string + if err := numSchema.Unmarshal(`123`, &s); err == nil { + t.Error("expected unmarshal type error") + } +} + +// Ensure GenerateSchemaForType returns an error when encountering unsupported types. +func TestGenerateSchemaForTypeUnsupported(t *testing.T) { + type Bad struct { + Ch chan int `json:"ch"` + } + if _, err := jsonschema.GenerateSchemaForType(Bad{}); err == nil { + t.Fatal("expected error for unsupported type") + } +} + +// Validate should fail when provided data does not match the expected container types. +func TestValidateInvalidContainers(t *testing.T) { + objSchema := jsonschema.Definition{Type: jsonschema.Object} + if jsonschema.Validate(objSchema, 1) { + t.Error("expected object validation to fail for non-map input") + } + + arrSchema := jsonschema.Definition{Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}} + if jsonschema.Validate(arrSchema, 1) { + t.Error("expected array validation to fail for non-slice input") + } +} + +// Validate should return false when $ref cannot be resolved. +func TestValidateRefNotFound(t *testing.T) { + refSchema := jsonschema.Definition{Ref: "#/$defs/Missing"} + if jsonschema.Validate(refSchema, "data", jsonschema.WithDefs(map[string]jsonschema.Definition{})) { + t.Error("expected validation to fail when reference is missing") + } +} From 1e912177886f37ae0c4159982a962d9ef8ff921b Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 16:48:49 +0100 Subject: [PATCH 232/242] test: cover CreateFile request builder failure (#1041) --- files_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/files_test.go b/files_test.go index 3c1b99fb4..486ef892e 100644 --- a/files_test.go +++ b/files_test.go @@ -121,3 +121,20 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } +func TestCreateFileRequestBuilderFailure(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return &mockFormBuilder{ + mockWriteField: func(string, string) error { return nil }, + mockCreateFormFile: func(string, *os.File) error { return nil }, + mockClose: func() error { return nil }, + } + } + + _, err := client.CreateFile(context.Background(), FileRequest{FilePath: "client.go"}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateFile should return error if request builder fails") +} From 3bb1014fa7e8d09fa6ff31bc6bce6172a3bf59ae Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:11:07 +0100 Subject: [PATCH 233/242] ci: enable version compatibility vet (#1042) --- .github/workflows/pr.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 18c720f03..2c9730656 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -16,7 +16,7 @@ jobs: go-version: '1.24' - name: Run vet run: | - go vet . + go vet -stdversion ./... - name: Run golangci-lint uses: golangci/golangci-lint-action@v7 with: From bd36c45dc505d3592ddfc0d53c06561fe8a3dacb Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Fri, 11 Jul 2025 21:45:53 +0530 Subject: [PATCH 234/242] Support for extra_body parameter for embeddings API (#906) * support for extra_body parameter for embeddings API * done linting * added unit tests * improved code coverage and removed unnecessary checks * test cleanup * updated body map creation code * code coverage * minor change * updated testcase comment --- client.go | 14 ++++++++++++++ embeddings.go | 32 +++++++++++++++++++++++++++++++- embeddings_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index cef375348..413b8db03 100644 --- a/client.go +++ b/client.go @@ -84,6 +84,20 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. + for key, value := range extraBody { + bodyMap[key] = value + } + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) diff --git a/embeddings.go b/embeddings.go index 4a0e682da..8593f8b5b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "math" "net/http" @@ -160,6 +161,9 @@ type EmbeddingRequest struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() + + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + req, err := c.newRequest( ctx, http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), - withBody(baseReq), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. ) if err != nil { return diff --git a/embeddings_test.go b/embeddings_test.go index 438978169..07f1262cb 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) { t.Fatalf("Expected embedding request to contain model field") } + // test embedding request with strings and extra_body param + embeddingReqWithExtraBody := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + } + marshaled, err = json.Marshal(embeddingReqWithExtraBody) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + // test embedding request with strings embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ @@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } - // test create embeddings with strings (simple embedding request) + // test create embeddings with strings (ExtraBody in request) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + Dimensions: 1, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (ExtraBody in request and ) + _, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Input: make(chan int), // Channels are not serializable + Model: "example_model", + }, + ) + checks.HasError(t, err, "CreateEmbeddings error") + + // test failed (Serialize JSON error) res, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ From e6c1d3e5bde0bae5966070ab4edee874b6c8c73f Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:39:39 +0100 Subject: [PATCH 235/242] Increase jsonschema test coverage (#1043) * test: expand jsonschema coverage * test: fix package name for containsref tests --- jsonschema/containsref_test.go | 48 ++++++++++++++++++++++++++++++++++ jsonschema/json_errors_test.go | 27 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 jsonschema/containsref_test.go create mode 100644 jsonschema/json_errors_test.go diff --git a/jsonschema/containsref_test.go b/jsonschema/containsref_test.go new file mode 100644 index 000000000..dc1842775 --- /dev/null +++ b/jsonschema/containsref_test.go @@ -0,0 +1,48 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// SelfRef struct used to produce a self-referential schema. +type SelfRef struct { + Friends []SelfRef `json:"friends"` +} + +// Address struct referenced by Person without self-reference. +type Address struct { + Street string `json:"street"` +} + +type Person struct { + Address Address `json:"address"` +} + +// TestGenerateSchemaForType_SelfRef ensures that self-referential types are not +// flattened during schema generation. +func TestGenerateSchemaForType_SelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(SelfRef{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["SelfRef"]; !ok { + t.Fatal("expected defs to contain SelfRef for self reference") + } +} + +// TestGenerateSchemaForType_NoSelfRef ensures that non-self-referential types +// are flattened and do not reappear in $defs. +func TestGenerateSchemaForType_NoSelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(Person{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["Person"]; ok { + t.Fatal("unexpected Person definition in defs") + } + if _, ok := schema.Defs["Address"]; !ok { + t.Fatal("expected Address definition in defs") + } +} diff --git a/jsonschema/json_errors_test.go b/jsonschema/json_errors_test.go new file mode 100644 index 000000000..3b864fc21 --- /dev/null +++ b/jsonschema/json_errors_test.go @@ -0,0 +1,27 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// TestGenerateSchemaForType_ErrorPaths verifies error handling for unsupported types. +func TestGenerateSchemaForType_ErrorPaths(t *testing.T) { + type anon struct{ Ch chan int } + tests := []struct { + name string + v any + }{ + {"slice", []chan int{}}, + {"anon struct", anon{}}, + {"pointer", (*chan int)(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := jsonschema.GenerateSchemaForType(tt.v); err == nil { + t.Errorf("expected error for %s", tt.name) + } + }) + } +} From 181c0e8fd7358b1f37f902262b246d513acd1e29 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 18:29:23 +0100 Subject: [PATCH 236/242] Add tests for internal utilities (#1044) * Add unit tests for internal test utilities * Fix lint issues in internal tests --- internal/test/checks/checks_test.go | 19 +++++++ internal/test/failer_test.go | 24 +++++++++ internal/test/helpers_test.go | 54 +++++++++++++++++++ internal/test/server.go | 12 +++++ internal/test/server_test.go | 80 +++++++++++++++++++++++++++++ 5 files changed, 189 insertions(+) create mode 100644 internal/test/checks/checks_test.go create mode 100644 internal/test/failer_test.go create mode 100644 internal/test/helpers_test.go create mode 100644 internal/test/server_test.go diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..0677054df --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,19 @@ +package checks_test + +import ( + "errors" + "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestChecksSuccessPaths(t *testing.T) { + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) + target := errors.New("x") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..fb1f4bf06 --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,24 @@ +//nolint:testpackage // need access to unexported fields and types for testing +package test + +import ( + "errors" + "testing" +) + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..aa177679b --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,54 @@ +package test_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + internaltest.CreateTestFile(t, path) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -23,6 +23,18 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. +func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { // to make the registered paths friendlier to a regex match in the route handler // in OpenAITestServer diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..f8ce731d1 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,80 @@ +package test_test + +import ( + "io" + "net/http" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestGetTestToken(t *testing.T) { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := internaltest.NewTestServer() + if ts == nil { + t.Fatalf("server not properly initialized") + } + if ts.HandlerCount() != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if !ts.HasHandler("/foo/*") { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} From 4f87294cebdd14457cc2e1013cdf5acf5db3d27d Mon Sep 17 00:00:00 2001 From: Ayush Sawant Date: Sun, 20 Jul 2025 01:29:02 +0530 Subject: [PATCH 237/242] Add GuidedChoice to ChatCompletionRequest (#1034) * Add GuidedChoice to ChatCompletionRequest * made separate NonOpenAIExtensions * fixed lint issue * renamed struct and removed inline json tag * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update chat.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- chat.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chat.go b/chat.go index 0f0c5b5d5..e14acd9d9 100644 --- a/chat.go +++ b/chat.go @@ -248,6 +248,16 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro return nil } +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions (e.g., vendor-specific implementations like vLLM). +type ChatCompletionRequestExtensions struct { + // GuidedChoice is a vLLM-specific extension that restricts the model's output + // to one of the predefined string choices provided in this field. This feature + // is used to constrain the model's responses to a controlled set of options, + // ensuring predictable and consistent outputs in scenarios where specific + // choices are required. + GuidedChoice []string `json:"guided_choice,omitempty"` +} + // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { Model string `json:"model"` @@ -309,6 +319,8 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Embedded struct for non-OpenAI extensions + ChatCompletionRequestExtensions } type StreamOptions struct { From c4273cb5f46031ee478601f4edd82d2fad401e77 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Sat, 19 Jul 2025 21:07:34 +0100 Subject: [PATCH 238/242] fix(chat): shorten comment to pass linter (#1050) --- chat.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chat.go b/chat.go index e14acd9d9..0bb2e98ee 100644 --- a/chat.go +++ b/chat.go @@ -248,7 +248,8 @@ func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) erro return nil } -// ChatCompletionRequestExtensions contains third-party OpenAI API extensions (e.g., vendor-specific implementations like vLLM). +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions +// (e.g., vendor-specific implementations like vLLM). type ChatCompletionRequestExtensions struct { // GuidedChoice is a vLLM-specific extension that restricts the model's output // to one of the predefined string choices provided in this field. This feature @@ -264,7 +265,7 @@ type ChatCompletionRequest struct { Messages []ChatCompletionMessage `json:"messages"` // MaxTokens The maximum number of tokens that can be generated in the chat completion. // This value can be used to control costs for text generated via API. - // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models. + // Deprecated: use MaxCompletionTokens. Not compatible with o1-series models. // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens MaxTokens int `json:"max_tokens,omitempty"` // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, From f7d6ece81065bde1e93576c89f7506d1b35cb205 Mon Sep 17 00:00:00 2001 From: Behzad Soltanpour Date: Mon, 11 Aug 2025 12:45:50 +0330 Subject: [PATCH 239/242] add GPT-5 model constants and reasoning validation (#1062) --- chat_test.go | 120 +++++++++++++++++++++++++++++++++++++++++ completion.go | 8 +++ completion_test.go | 29 ++++++++++ reasoning_validator.go | 9 ++-- 4 files changed, 162 insertions(+), 4 deletions(-) diff --git a/chat_test.go b/chat_test.go index 172ce0740..236cff736 100644 --- a/chat_test.go +++ b/chat_test.go @@ -331,6 +331,126 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { } } +func TestGPT5ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.GPT5, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Nano, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5ChatLatest, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense diff --git a/completion.go b/completion.go index 02ce7b016..27d69f587 100644 --- a/completion.go +++ b/completion.go @@ -49,6 +49,10 @@ const ( GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" GPT4Dot5Preview = "gpt-4.5-preview" GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT5 = "gpt-5" + GPT5Mini = "gpt-5-mini" + GPT5Nano = "gpt-5-nano" + GPT5ChatLatest = "gpt-5-chat-latest" GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" @@ -142,6 +146,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4Dot1Mini20250414: true, GPT4Dot1Nano: true, GPT4Dot1Nano20250414: true, + GPT5: true, + GPT5Mini: true, + GPT5Nano: true, + GPT5ChatLatest: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, diff --git a/completion_test.go b/completion_test.go index f0ead0d63..abfc3007e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -300,3 +300,32 @@ func TestCompletionWithGPT4oModels(t *testing.T) { }) } } + +// TestCompletionWithGPT5Models Tests that GPT5 models are not supported for completion endpoint. +func TestCompletionWithGPT5Models(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT5, + openai.GPT5Mini, + openai.GPT5Nano, + openai.GPT5ChatLatest, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} diff --git a/reasoning_validator.go b/reasoning_validator.go index 2910b1395..1d26ca047 100644 --- a/reasoning_validator.go +++ b/reasoning_validator.go @@ -28,21 +28,22 @@ var ( ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll ) -// ReasoningValidator handles validation for o-series model requests. +// ReasoningValidator handles validation for reasoning model requests. type ReasoningValidator struct{} -// NewReasoningValidator creates a new validator for o-series models. +// NewReasoningValidator creates a new validator for reasoning models. func NewReasoningValidator() *ReasoningValidator { return &ReasoningValidator{} } -// Validate performs all validation checks for o-series models. +// Validate performs all validation checks for reasoning models. func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { o1Series := strings.HasPrefix(request.Model, "o1") o3Series := strings.HasPrefix(request.Model, "o3") o4Series := strings.HasPrefix(request.Model, "o4") + gpt5Series := strings.HasPrefix(request.Model, "gpt-5") - if !o1Series && !o3Series && !o4Series { + if !o1Series && !o3Series && !o4Series && !gpt5Series { return nil } From f71d1a622abab7fa75159b02318cbcf6e2dcb0a0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 12 Aug 2025 18:03:57 +0800 Subject: [PATCH 240/242] feat: add safety_identifier params (#1066) --- chat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chat.go b/chat.go index 0bb2e98ee..9719f6b92 100644 --- a/chat.go +++ b/chat.go @@ -320,6 +320,11 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. + // The IDs should be a string that uniquely identifies each user. + // We recommend hashing their username or email address, in order to avoid sending us any identifying information. + // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier + SafetyIdentifier string `json:"safety_identifier,omitempty"` // Embedded struct for non-OpenAI extensions ChatCompletionRequestExtensions } From 8e5611cc5efdc2533b80e5667b69741c3fad875c Mon Sep 17 00:00:00 2001 From: Amady Azdaev Date: Fri, 29 Aug 2025 20:29:03 +0300 Subject: [PATCH 241/242] Add Verbosity parameter to Chat Completion Request (#1064) * add verbosity param to ChatCompletionRequest * edit comment about verbosity --- chat.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/chat.go b/chat.go index 9719f6b92..0aa018715 100644 --- a/chat.go +++ b/chat.go @@ -320,6 +320,12 @@ type ChatCompletionRequest struct { ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // Specifies the latency tier to use for processing the request. ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Verbosity determines how many output tokens are generated. Lowering the number of + // tokens reduces overall latency. It can be set to "low", "medium", or "high". + // Note: This field is only confirmed to work with gpt-5, gpt-5-mini and gpt-5-nano. + // Also, it is not in the API reference of chat completion at the time of writing, + // though it is supported by the API. + Verbosity string `json:"verbosity,omitempty"` // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. // The IDs should be a string that uniquely identifies each user. // We recommend hashing their username or email address, in order to avoid sending us any identifying information. From 5d7a276f4c0e48f97354fe555cb793b52e350e62 Mon Sep 17 00:00:00 2001 From: Christopher Petito <47751006+krissetto@users.noreply.github.com> Date: Tue, 21 Oct 2025 21:27:33 +0200 Subject: [PATCH 242/242] Stop stripping dots in azure model mapper for models that aren't 3.5 based (#1079) fixes #978 Signed-off-by: Christopher Petito --- config.go | 7 ++++++- config_test.go | 4 ++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 4788ba62a..4b8cfb6fb 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package openai import ( "net/http" "regexp" + "strings" ) const ( @@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { APIType: APITypeAzure, APIVersion: "2023-05-15", AzureModelMapperFunc: func(model string) string { - return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + // only 3.5 models have the "." stripped in their names + if strings.Contains(model, "3.5") { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + } + return strings.ReplaceAll(model, ":", "") }, HTTPClient: &http.Client{}, diff --git a/config_test.go b/config_test.go index 960230804..f44b80825 100644 --- a/config_test.go +++ b/config_test.go @@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) { Model: "gpt-3.5-turbo-0301", Expect: "gpt-35-turbo-0301", }, + { + Model: "gpt-4.1", + Expect: "gpt-4.1", + }, { Model: "text-embedding-ada-002", Expect: "text-embedding-ada-002",