diff --git a/api_internal_test.go b/api_internal_test.go index a590ec9ab..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org/", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,7 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) { cases := []struct { Name string BaseURL string + Suffix string Expect string }{ { "CloudflareAzureBaseURLWithSlashAutoStrip", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + "chat/completions?api-version=2023-05-15", }, { - "CloudflareAzureBaseURLWithoutSlashOK", + "", "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", - "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + - "chat/completions?api-version=2023-05-15", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + "/assistants?api-version=2023-05-15&limit=10", }, } @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) { cli := NewClientWithConfig(az) - actual := cli.fullURL("/chat/completions") + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/assistant.go b/assistant.go index 661681e83..44cb63659 100644 --- a/assistant.go +++ b/assistant.go @@ -14,17 +14,17 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + 
Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` httpHeader } @@ -34,6 +34,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -41,19 +42,36 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases diff --git a/audio.go b/audio.go index dbc26d154..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..3c1a9d0d7 --- /dev/null +++ b/batch.go @@ -0,0 +1,271 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem 
interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data []struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +// CreateBatch — API call to Create batch. 
+func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. 
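A minimal usage sketch for the batch helpers above; the token, model, and custom ID are placeholders, not part of this diff. AddChatCompletion builds the JSONL lines and CreateBatchWithUploadFile uploads them and creates the batch in one call:

```go
package main

import (
	"context"
	"fmt"
	"log"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient("your-token") // placeholder token

	req := openai.CreateBatchWithUploadFileRequest{
		Endpoint:         openai.BatchEndpointChatCompletions,
		CompletionWindow: "24h",
	}
	// Each added line becomes one request in the uploaded JSONL file.
	req.AddChatCompletion("task-1", openai.ChatCompletionRequest{
		Model: openai.GPT4oMini,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Say hello"},
		},
	})

	batch, err := client.CreateBatchWithUploadFile(context.Background(), req)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(batch.ID, batch.Status)
}
```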
+func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/chat.go b/chat.go index a1eb11720..56e99a78b 100644 --- a/chat.go +++ b/chat.go @@ -175,11 +175,20 @@ type ChatCompletionResponseFormatType string const ( ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` } // ChatCompletionRequest represents a request structure for chat completion API. @@ -218,6 +227,8 @@ type ChatCompletionRequest struct { ToolChoice any `json:"tool_choice,omitempty"` // Options for streaming response. Only set this when you set stream: true. StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` } type StreamOptions struct { @@ -251,6 +262,7 @@ type ToolFunction struct { type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` // Parameters is an object describing the function. // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. 
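The json_schema response format and the Strict flag pair naturally with jsonschema.GenerateSchemaForType, which this diff adds later in jsonschema/json.go. A sketch, assuming a caller-defined MathReply type and an existing client and ctx (imports omitted):

```go
func structuredSum(ctx context.Context, client *openai.Client) (string, error) {
	type MathReply struct {
		Steps       []string `json:"steps"`
		FinalAnswer string   `json:"final_answer"`
	}

	schema, err := jsonschema.GenerateSchemaForType(MathReply{})
	if err != nil {
		return "", err
	}

	resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
		Model: openai.GPT4o20240806,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "What is 3 + 5? Answer step by step."},
		},
		ResponseFormat: &openai.ChatCompletionResponseFormat{
			Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
			JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
				Name:   "math_reply",
				Schema: schema, // *jsonschema.Definition satisfies json.Marshaler (pointer receiver)
				Strict: true,
			},
		},
	})
	if err != nil {
		return "", err
	}
	return resp.Choices[0].Message.Content, nil
}
```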
@@ -345,7 +357,12 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index ffd512ff6..3f90bc019 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/client.go b/client.go index 7bc28e984..206a933a4 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" utils "github.com/sashabaranov/go-openai/internal" @@ -156,6 +157,28 @@ func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err er return } +func sendRequestStreamV2(client *Client, req *http.Request) (stream *StreamerV2, err error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) + if err != nil { + return + } + + // TODO: how to handle error? + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d , body:%s", resp.StatusCode, string(body)) + } + + return NewStreamerV2(resp.Body), nil +} + func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") @@ -212,6 +235,28 @@ func decodeResponse(body io.Reader, v any) error { } } +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + func decodeString(body io.Reader, output *string) error { b, err := io.ReadAll(body) if err != nil { @@ -222,38 +267,43 @@ func decodeString(body io.Reader, output *string) error { } // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. 
-func (c *Client) fullURL(suffix string, args ...any) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) - } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } - return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, c.config.APIVersion, - ) + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) + } + + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) } + return fmt.Sprintf("%s%s", baseURL, suffix) +} - // https://developers.cloudflare.com/ai-gateway/providers/azureopenai/ - if c.config.APIType == APITypeCloudflareAzure { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion) +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { diff --git a/completion.go b/completion.go index ced8e0606..0a26e278e 100644 --- a/completion.go +++ b/completion.go @@ -24,6 +24,10 @@ const ( GPT40314 = "gpt-4-0314" GPT4o = "gpt-4o" GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4oLatest = "chatgpt-4o-latest" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" GPT4Turbo = "gpt-4-turbo" GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" GPT4Turbo0125 = "gpt-4-0125-preview" @@ -86,6 +90,9 @@ var disabledModelsForEndpoints = map[string]map[string]bool{ GPT4: true, GPT4o: true, GPT4o20240513: true, + GPT4o20240806: true, + GPT4oMini: true, + GPT4oMini20240718: true, GPT4TurboPreview: true, GPT4VisionPreview: true, GPT4Turbo1106: true, @@ -203,7 +210,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( 
+ ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/config.go b/config.go index bb437c97f..584ec1812 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,11 @@ const ( const AzureAPIKeyHeader = "api-key" -const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. +const defaultAssistantVersion = "v2" // This will be deprecated by the end of 2024. + +type HTTPDoer interface { + Do(req *http.Request) (*http.Response, error) +} // ClientConfig is a configuration of a client. type ClientConfig struct { @@ -36,7 +40,7 @@ type ClientConfig struct { APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func - HTTPClient *http.Client + HTTPClient HTTPDoer EmptyMessagesLimit uint } diff --git a/edits.go b/edits.go index 97d026029..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index c5633a313..efff2c288 100644 --- a/embeddings.go +++ b/embeddings.go @@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(baseReq), + ) if err != nil { return } diff --git a/files.go b/files.go index b40a44f15..26ad6bd70 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..74b47bf3f 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. 
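Because ClientConfig.HTTPClient is now the HTTPDoer interface instead of *http.Client, any type with a matching Do method can be injected. A sketch with a hypothetical logging wrapper (loggingDoer is not part of this diff):

```go
package main

import (
	"log"
	"net/http"

	openai "github.com/sashabaranov/go-openai"
)

// loggingDoer is a hypothetical wrapper that logs every request before delegating.
type loggingDoer struct {
	inner *http.Client
}

func (d *loggingDoer) Do(req *http.Request) (*http.Response, error) {
	log.Printf("%s %s", req.Method, req.URL)
	return d.inner.Do(req)
}

func main() {
	cfg := openai.DefaultConfig("your-token")
	cfg.HTTPClient = &loggingDoer{inner: &http.Client{}}
	client := openai.NewClientWithConfig(cfg)
	_ = client
}
```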
func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated if err != nil { return } diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 9dcb49de1..5a9f54a92 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -26,7 +26,9 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs any `json:"n_epochs,omitempty"` + Epochs any `json:"n_epochs,omitempty"` + LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` + BatchSize any `json:"batch_size,omitempty"` } type FineTuningJobRequest struct { diff --git a/image.go b/image.go index 665de1a74..577d7db95 100644 --- a/image.go +++ b/image.go @@ -68,7 +68,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -132,8 +137,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -183,8 +193,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..b9d86845b 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -46,3 +46,10 @@ func ErrorIsNotf(t *testing.T, err, target error, format string, msg ...string) t.Fatalf(format, msg) } } + +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} diff --git a/jsonschema/json.go b/jsonschema/json.go index cb941eb75..bcb253fae 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,13 @@ // and/or pass in the schema in []byte format. package jsonschema -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) type DataType string @@ -29,14 +35,20 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. 
- Properties map[string]Definition `json:"properties"` + Properties map[string]Definition `json:"properties,omitempty"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` + // AdditionalProperties is used to control the handling of properties in an object + // that are not explicitly defined in the properties section of the schema. example: + // additionalProperties: true + // additionalProperties: false + // additionalProperties: jsonschema.Definition{Type: jsonschema.String} + AdditionalProperties any `json:"additionalProperties,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -44,6 +56,99 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + return reflectSchema(reflect.TypeOf(v)) +} + +func reflectSchema(t reflect.Type) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem()) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + if jsonTag == "" { + jsonTag = field.Name + } else if strings.HasSuffix(jsonTag, ",omitempty") { + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 
000000000..f14ffd4c4 --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,89 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +func Validate(schema Definition, data any) bool { + switch schema.Type { + case Object: + return validateObject(schema, data) + case Array: + return validateArray(schema, data) + case String: + _, ok := data.(string) + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + return false + } +} + +func validateObject(schema Definition, data any) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/messages.go b/messages.go index 6af118445..ad91ce3b6 100644 --- a/messages.go +++ b/messages.go @@ -43,8 +43,20 @@ type MessageContent struct { ImageFile *ImageFile `json:"image_file,omitempty"` } type MessageText struct { - Value string `json:"value"` - Annotations []any `json:"annotations"` + Value string `json:"value"` + Annotations []*Annotation `json:"annotations"` +} + +type Annotation struct { + StartIndex int `json:"start_index"` + EndIndex int `json:"end_index"` + FileCitation *FileCitation `json:"file_citation,omitempty"` + Text string `json:"text,omitempty"` + Type string `json:"type,omitempty"` +} + +type FileCitation struct { + FileID string `json:"file_id"` } type ImageFile struct { @@ -67,6 +79,14 @@ type MessageFile struct { httpHeader } +type MessageDeletionStatus struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + type MessageFilesList struct { MessageFiles []MessageFile `json:"data"` @@ -86,6 +106,22 @@ func (c *Client) CreateMessage(ctx context.Context, threadID string, request Mes return } +// DeleteMessage deletes a message.. +func (c *Client) DeleteMessage( + ctx context.Context, + threadID, messageID string, +) (status MessageDeletionStatus, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &status) + return +} + // ListMessage fetches all messages in the thread. 
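The validator added in jsonschema/validate.go lets a generated Definition check model output before decoding it. A sketch; the Step type and the JSON literal are illustrative only:

```go
package main

import (
	"fmt"
	"log"

	"github.com/sashabaranov/go-openai/jsonschema"
)

type Step struct {
	Explanation string `json:"explanation"`
	Output      string `json:"output"`
}

func main() {
	schema, err := jsonschema.GenerateSchemaForType([]Step{})
	if err != nil {
		log.Fatal(err)
	}

	var steps []Step
	content := `[{"explanation":"add 3 and 5","output":"8"}]` // assumed model output
	// Unmarshal validates content against the generated schema before decoding.
	if err := schema.Unmarshal(content, &steps); err != nil {
		log.Fatal(err)
	}
	fmt.Println(steps[0].Output)
}
```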
func (c *Client) ListMessage(ctx context.Context, threadID string, limit *int, diff --git a/moderation.go b/moderation.go index ae285ef83..c8652efc8 100644 --- a/moderation.go +++ b/moderation.go @@ -88,7 +88,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/run.go b/run.go index 5598f1dfb..ec025c860 100644 --- a/run.go +++ b/run.go @@ -83,12 +83,15 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model string `json:"model,omitempty"` - Instructions string `json:"instructions,omitempty"` - AdditionalInstructions string `json:"additional_instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []*AdditionalMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + + ToolResources map[string]any `json:"tool_resources,omitempty"` // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. @@ -112,6 +115,14 @@ type RunRequest struct { ResponseFormat any `json:"response_format,omitempty"` } +type AdditionalMessage struct { + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. // https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy. type ThreadTruncationStrategy struct { @@ -389,6 +400,12 @@ func (c *Client) CreateThreadAndRun( return } +// ChatCompletionStream +// Note: Perhaps it is more elegant to abstract Stream using generics. +type RunCompletionStream struct { + *streamReader[ChatCompletionStreamResponse] +} + // RetrieveRunStep retrieves a run step. 
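RunRequest.AdditionalMessages, added above, appends messages to the thread at run-creation time. A sketch, assuming an existing client plus placeholder thread and assistant IDs (imports omitted):

```go
func runWithExtraMessage(ctx context.Context, client *openai.Client, threadID, assistantID string) (openai.Run, error) {
	return client.CreateRun(ctx, threadID, openai.RunRequest{
		AssistantID: assistantID,
		AdditionalMessages: []*openai.AdditionalMessage{
			{
				Role:    openai.ThreadMessageRoleUser,
				Content: "Also include a one-sentence summary.",
			},
		},
	})
}
```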
func (c *Client) RetrieveRunStep( ctx context.Context, @@ -449,3 +466,141 @@ func (c *Client) ListRunSteps( err = c.sendRequest(req, &response) return } + +type StreamMessageDelta struct { + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIDs []string `json:"file_ids"` +} + +type AssistantStreamEvent struct { + ID string `json:"id"` + Object string `json:"object"` + Delta StreamMessageDelta `json:"delta,omitempty"` + + // Run + CreatedAt int64 `json:"created_at,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + AssistantID string `json:"assistant_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Usage Usage `json:"usage,omitempty"` + + // ThreadMessage.Completed + Role string `json:"role,omitempty"` + Content []MessageContent `json:"content,omitempty"` + // IncompleteDetails + // IncompleteAt + + // Run steps + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + StepDetails StepDetails `json:"step_details"` + ExpiredAt *int64 `json:"expired_at,omitempty"` +} + +type AssistantStream struct { + *streamReader[AssistantStreamEvent] +} + +func (c *Client) CreateThreadAndRunStream( + ctx context.Context, + request CreateThreadAndRunRequest) (stream *StreamerV2, err error) { + type createThreadAndStreamRequest struct { + CreateThreadAndRunRequest + Stream bool `json:"stream"` + } + + urlSuffix := "/threads/runs" + sr := createThreadAndStreamRequest{ + CreateThreadAndRunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(sr), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + if err != nil { + return nil, err + } + + return sendRequestStreamV2(c, req) +} + +type RunRequestStreaming struct { + RunRequest + Stream bool `json:"stream"` +} + +func (c *Client) CreateRunStream( + ctx context.Context, + threadID string, + request RunRequest) (stream *StreamerV2, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + + r := RunRequestStreaming{ + RunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + return sendRequestStreamV2(c, req) +} + +type SubmitToolOutputsRequestStreaming struct { + SubmitToolOutputsRequest + Stream bool `json:"stream"` +} + +// SubmitToolOutputs submits tool outputs. 
+func (c *Client) SubmitToolOutputsStream( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest) (stream *StreamerV2, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + + r := SubmitToolOutputsRequestStreaming{ + SubmitToolOutputsRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + return sendRequestStreamV2(c, req) +} diff --git a/speech.go b/speech.go index 7e22e755c..20b52e334 100644 --- a/speech.go +++ b/speech.go @@ -2,7 +2,6 @@ package openai import ( "context" - "errors" "net/http" ) @@ -36,11 +35,6 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` @@ -49,33 +43,11 @@ type CreateSpeechRequest struct { Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/speech_test.go b/speech_test.go index d9ba58b13..8cb5012fb 100644 --- a/speech_test.go +++ b/speech_test.go @@ -95,21 +95,5 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) } diff --git a/sse.go b/sse.go new file mode 100644 index 000000000..fe5a5c5f3 --- /dev/null +++ b/sse.go @@ -0,0 +1,165 @@ +package openai + +import ( + "bufio" + "io" + "strconv" + "strings" +) + +// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance. +func NewEOLSplitterFunc() bufio.SplitFunc { + splitter := NewEOLSplitter() + return splitter.Split +} + +// EOLSplitter is the custom split function to handle CR LF, CR, and LF as end-of-line. 
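A quick sketch of the new splitter in isolation, showing that CR LF, bare CR, and bare LF are all treated as end-of-line:

```go
package main

import (
	"bufio"
	"fmt"
	"strings"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	sc := bufio.NewScanner(strings.NewReader("alpha\r\nbeta\rgamma\n"))
	sc.Split(openai.NewEOLSplitterFunc())
	for sc.Scan() {
		fmt.Println(sc.Text()) // alpha, beta, gamma
	}
}
```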
+type EOLSplitter struct { + prevCR bool +} + +// NewEOLSplitter creates a new EOLSplitter instance. +func NewEOLSplitter() *EOLSplitter { + return &EOLSplitter{prevCR: false} +} + +const crlfLen = 2 + +// Split function to handle CR LF, CR, and LF as end-of-line. +func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Check if the previous data ended with a CR + if s.prevCR { + s.prevCR = false + if len(data) > 0 && data[0] == '\n' { + return 1, nil, nil // Skip the LF following the previous CR + } + } + + // Search for the first occurrence of CR LF, CR, or LF + for i := 0; i < len(data); i++ { + if data[i] == '\r' { + if i+1 < len(data) && data[i+1] == '\n' { + // Found CR LF + return i + crlfLen, data[:i], nil + } + // Found CR + if !atEOF && i == len(data)-1 { + // If CR is the last byte, and not EOF, then need to check if + // the next byte is LF. + // + // save the state and request more data + s.prevCR = true + return 0, nil, nil + } + return i + 1, data[:i], nil + } + if data[i] == '\n' { + // Found LF + return i + 1, data[:i], nil + } + } + + // If at EOF, we have a final, non-terminated line. Return it. + if atEOF && len(data) > 0 { + return len(data), data, nil + } + + // Request more data. + return 0, nil, nil +} + +type ServerSentEvent struct { + ID string // ID of the event + Data string // Data of the event + Event string // Type of the event + Retry int // Retry time in milliseconds + Comment string // Comment +} + +type SSEScanner struct { + scanner *bufio.Scanner + next ServerSentEvent + err error + readComment bool +} + +func NewSSEScanner(r io.Reader, readComment bool) *SSEScanner { + scanner := bufio.NewScanner(r) + + // N.B. The bufio.ScanLines handles `\r?\n``, but not `\r` itself as EOL, as + // the SSE spec requires + // + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream + // + // scanner.Split(bufio.ScanLines) + scanner.Split(NewEOLSplitterFunc()) + + return &SSEScanner{ + scanner: scanner, + readComment: readComment, + } +} + +func (s *SSEScanner) Next() bool { + // Zero the next event before scanning a new one + var event ServerSentEvent + s.next = event + + var dataLines []string + + var seenNonEmptyLine bool + + for s.scanner.Scan() { + line := strings.TrimSpace(s.scanner.Text()) + + if line == "" { + if seenNonEmptyLine { + break + } + + continue + } + + seenNonEmptyLine = true + switch { + case strings.HasPrefix(line, "id: "): + event.ID = strings.TrimPrefix(line, "id: ") + case strings.HasPrefix(line, "data: "): + dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) + case strings.HasPrefix(line, "event: "): + event.Event = strings.TrimPrefix(line, "event: ") + case strings.HasPrefix(line, "retry: "): + retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) + if err == nil { + event.Retry = retry + } + // ignore invalid retry values + case strings.HasPrefix(line, ":"): + if s.readComment { + event.Comment = strings.TrimPrefix(line, ":") + } + // ignore comment line + default: + // ignore unknown lines + } + } + + s.err = s.scanner.Err() + + if !seenNonEmptyLine { + return false + } + + event.Data = strings.Join(dataLines, "\n") + s.next = event + + return true +} + +func (s *SSEScanner) Scan() ServerSentEvent { + return s.next +} + +func (s *SSEScanner) Err() error { + return s.err +} diff --git a/stream.go b/stream.go index b277f3c29..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + 
"net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..433548794 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -16,7 +16,7 @@ var ( ) type streamable interface { - ChatCompletionStreamResponse | CompletionResponse + ChatCompletionStreamResponse | CompletionResponse | AssistantStreamEvent } type streamReader[T streamable] struct { diff --git a/stream_v2.go b/stream_v2.go new file mode 100644 index 000000000..f028766cf --- /dev/null +++ b/stream_v2.go @@ -0,0 +1,298 @@ +package openai + +import ( + "encoding/json" + "io" +) + +type StreamRawEvent struct { + streamEvent + Data json.RawMessage +} + +type StreamDone struct { + streamEvent +} + +type StreamThreadMessageCompleted struct { + Message + streamEvent +} + +type StreamThreadMessageDelta struct { + ID string `json:"id"` + Object string `json:"object"` + Delta Delta `json:"delta"` + + streamEvent +} + +type Delta struct { + // DeltaText | DeltaImageFile + Content []DeltaContent `json:"content"` +} + +type DeltaContent struct { + Index int `json:"index"` + Type string `json:"type"` + + Text *DeltaText `json:"text"` + ImageFile *DeltaImageFile `json:"image_file"` + ImageURL *DeltaImageURL `json:"image_url"` +} + +type DeltaText struct { + Value string `json:"value"` + // Annotations []any `json:"annotations"` +} + +type DeltaImageFile struct { + FileID string `json:"file_id"` + Detail string `json:"detail"` +} + +type DeltaImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + +func NewStreamerV2(r io.Reader) *StreamerV2 { + var rc io.ReadCloser + + if closer, ok := r.(io.ReadCloser); ok { + rc = closer + } else { + rc = io.NopCloser(r) + } + + return &StreamerV2{ + readCloser: rc, + scanner: NewSSEScanner(r, false), + } +} + +type StreamerV2 struct { + // readCloser is only used for closing the stream + readCloser io.ReadCloser + + scanner *SSEScanner + next StreamEvent + + // buffer for implementing io.Reader + buffer []byte +} + +// TeeSSE tees the stream data with a io.TeeReader +func (s *StreamerV2) TeeSSE(w io.Writer) { + // readCloser is a helper struct that implements io.ReadCloser by combining an io.Reader and an io.Closer + type readCloser struct { + io.Reader + io.Closer + } + + s.readCloser = &readCloser{ + Reader: io.TeeReader(s.readCloser, w), + Closer: s.readCloser, + } + + s.scanner = NewSSEScanner(s.readCloser, false) +} + +// Close closes the underlying io.ReadCloser. 
+func (s *StreamerV2) Close() error { + return s.readCloser.Close() +} + +type StreamThreadCreated struct { + Thread + streamEvent +} + +type StreamThreadRunCreated struct { + Run + streamEvent +} + +type StreamThreadRunRequiresAction struct { + Run + streamEvent +} + +type StreamThreadRunCompleted struct { + Run + streamEvent +} + +type StreamRunStepCompleted struct { + RunStep + streamEvent +} + +type StreamEvent interface { + Event() string + JSON() json.RawMessage +} + +type streamEvent struct { + event string + data json.RawMessage +} + +// Event returns the event name +func (s *streamEvent) Event() string { + return s.event +} + +// JSON returns the raw JSON data +func (s *streamEvent) JSON() json.RawMessage { + return s.data +} + +func (s *StreamerV2) Next() bool { + if !s.scanner.Next() { + return false + } + + event := s.scanner.Scan() + + streamEvent := streamEvent{ + event: event.Event, + data: json.RawMessage(event.Data), + } + + switch event.Event { + case "thread.created": + var thread Thread + if err := json.Unmarshal([]byte(event.Data), &thread); err == nil { + s.next = &StreamThreadCreated{ + Thread: thread, + streamEvent: streamEvent, + } + } + case "thread.run.created": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunCreated{ + Run: run, + streamEvent: streamEvent, + } + } + + case "thread.run.requires_action": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunRequiresAction{ + Run: run, + streamEvent: streamEvent, + } + } + case "thread.run.completed": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunCompleted{ + Run: run, + streamEvent: streamEvent, + } + } + case "thread.message.delta": + var delta StreamThreadMessageDelta + if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { + delta.streamEvent = streamEvent + s.next = &delta + } + case "thread.run.step.completed": + var runStep RunStep + if err := json.Unmarshal([]byte(event.Data), &runStep); err == nil { + s.next = &StreamRunStepCompleted{ + RunStep: runStep, + streamEvent: streamEvent, + } + } + case "thread.message.completed": + var msg Message + if err := json.Unmarshal([]byte(event.Data), &msg); err == nil { + s.next = &StreamThreadMessageCompleted{ + Message: msg, + streamEvent: streamEvent, + } + } + case "done": + streamEvent.data = nil + s.next = &StreamDone{ + streamEvent: streamEvent, + } + default: + s.next = &StreamRawEvent{ + streamEvent: streamEvent, + } + } + + return true +} + +// Read implements io.Reader of the text deltas of thread.message.delta events. +func (s *StreamerV2) Read(p []byte) (int, error) { + // If we have data in the buffer, copy it to p first. + if len(s.buffer) > 0 { + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] + return n, nil + } + + for s.Next() { + // Read only text deltas + text, ok := s.MessageDeltaText() + if !ok { + continue + } + + s.buffer = []byte(text) + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] + return n, nil + } + + // Check for streamer error + if err := s.Err(); err != nil { + return 0, err + } + + return 0, io.EOF +} + +func (s *StreamerV2) Event() StreamEvent { + return s.next +} + +// Text returns text delta if the current event is a "thread.message.delta". Alias of MessageDeltaText. 
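Putting the V2 streaming pieces together: a sketch that starts a streaming run and prints message text deltas, assuming an existing client and placeholder IDs (imports omitted). Since StreamerV2 also implements io.Reader over the text deltas, io.Copy(os.Stdout, stream) is an equivalent shortcut.

```go
func streamRun(ctx context.Context, client *openai.Client, threadID, assistantID string) error {
	stream, err := client.CreateRunStream(ctx, threadID, openai.RunRequest{AssistantID: assistantID})
	if err != nil {
		return err
	}
	defer stream.Close()

	for stream.Next() {
		if text, ok := stream.MessageDeltaText(); ok {
			fmt.Print(text)
		}
		if stream.Done() {
			break
		}
	}
	return stream.Err()
}
```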
+func (s *StreamerV2) Text() (string, bool) { + return s.MessageDeltaText() +} + +func (s *StreamerV2) Done() bool { + _, ok := s.next.(*StreamDone) + return ok +} + +// MessageDeltaText returns text delta if the current event is a "thread.message.delta". +func (s *StreamerV2) MessageDeltaText() (string, bool) { + event, ok := s.next.(*StreamThreadMessageDelta) + if !ok { + return "", false + } + + var text string + for _, content := range event.Delta.Content { + if content.Text != nil { + // Can we return the first text we find? Does OpenAI stream ever + // return multiple text contents in a delta? + text += content.Text.Value + } + } + + return text, true +} + +func (s *StreamerV2) Err() error { + return s.scanner.Err() +} diff --git a/thread.go b/thread.go index 900e3f2ea..6b86a5750 100644 --- a/thread.go +++ b/thread.go @@ -10,34 +10,97 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` +} +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` } +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string const ( - ThreadMessageRoleUser ThreadMessageRole = "user" + ThreadMessageRoleAssistant ThreadMessageRole = "assistant" + ThreadMessageRoleUser ThreadMessageRole = "user" ) type ThreadMessage struct { - Role 
ThreadMessageRole `json:"role"` - Content string `json:"content"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadAttachment struct { + FileID string `json:"file_id"` + Tools []ThreadAttachmentTool `json:"tools"` +} + +type ThreadAttachmentTool struct { + Type string `json:"type"` } type ThreadDeleteResponse struct { diff --git a/vector.go b/vector.go new file mode 100644 index 000000000..5c364362a --- /dev/null +++ b/vector.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. +type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. 
+func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) {
+	req, _ := c.newRequest(
+		ctx,
+		http.MethodPost,
+		c.fullURL(vectorStoresSuffix),
+		withBody(request),
+		withBetaAssistantVersion(c.config.AssistantVersion),
+	)
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// RetrieveVectorStore retrieves a vector store.
+func (c *Client) RetrieveVectorStore(
+	ctx context.Context,
+	vectorStoreID string,
+) (response VectorStore, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID)
+	req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// ModifyVectorStore modifies a vector store.
+func (c *Client) ModifyVectorStore(
+	ctx context.Context,
+	vectorStoreID string,
+	request VectorStoreRequest,
+) (response VectorStore, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID)
+	req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// DeleteVectorStore deletes a vector store.
+func (c *Client) DeleteVectorStore(
+	ctx context.Context,
+	vectorStoreID string,
+) (response VectorStoreDeleteResponse, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID)
+	req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// ListVectorStores lists the currently available vector stores.
+func (c *Client) ListVectorStores(
+	ctx context.Context,
+	pagination Pagination,
+) (response VectorStoresList, err error) {
+	urlValues := url.Values{}
+
+	if pagination.After != nil {
+		urlValues.Add("after", *pagination.After)
+	}
+	if pagination.Order != nil {
+		urlValues.Add("order", *pagination.Order)
+	}
+	if pagination.Limit != nil {
+		urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit))
+	}
+	if pagination.Before != nil {
+		urlValues.Add("before", *pagination.Before)
+	}
+
+	encodedValues := ""
+	if len(urlValues) > 0 {
+		encodedValues = "?" + urlValues.Encode()
+	}
+
+	urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues)
+	req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// CreateVectorStoreFile creates a new vector store file.
+func (c *Client) CreateVectorStoreFile(
+	ctx context.Context,
+	vectorStoreID string,
+	request VectorStoreFileRequest,
+) (response VectorStoreFile, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix)
+	req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix),
+		withBody(request),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
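For reference, a minimal usage sketch of the store-level calls above, written as if it lived in this package (imports: context). It is not part of the patch; the "last_active_at" expiration anchor and the pre-uploaded fileID are assumptions.

// addFileToNewVectorStore creates a vector store that expires after seven idle
// days and attaches an already-uploaded file to it. Indexing is asynchronous,
// so the store is re-fetched to observe its file counts and status.
func addFileToNewVectorStore(ctx context.Context, c *Client, fileID string) (VectorStore, error) {
	vs, err := c.CreateVectorStore(ctx, VectorStoreRequest{
		Name:         "knowledge-base",
		ExpiresAfter: &VectorStoreExpires{Anchor: "last_active_at", Days: 7}, // anchor value is an assumption
	})
	if err != nil {
		return VectorStore{}, err
	}

	if _, err := c.CreateVectorStoreFile(ctx, vs.ID, VectorStoreFileRequest{FileID: fileID}); err != nil {
		return VectorStore{}, err
	}
	return c.RetrieveVectorStore(ctx, vs.ID)
}
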
+// RetrieveVectorStoreFile retrieves a vector store file.
+func (c *Client) RetrieveVectorStoreFile(
+	ctx context.Context,
+	vectorStoreID string,
+	fileID string,
+) (response VectorStoreFile, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID)
+	req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// DeleteVectorStoreFile deletes an existing vector store file.
+func (c *Client) DeleteVectorStoreFile(
+	ctx context.Context,
+	vectorStoreID string,
+	fileID string,
+) (err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID)
+	req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, nil)
+	return
+}
+
+// ListVectorStoreFiles lists the currently available files for a vector store.
+func (c *Client) ListVectorStoreFiles(
+	ctx context.Context,
+	vectorStoreID string,
+	pagination Pagination,
+) (response VectorStoreFilesList, err error) {
+	urlValues := url.Values{}
+	if pagination.After != nil {
+		urlValues.Add("after", *pagination.After)
+	}
+	if pagination.Limit != nil {
+		urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit))
+	}
+	if pagination.Before != nil {
+		urlValues.Add("before", *pagination.Before)
+	}
+	if pagination.Order != nil {
+		urlValues.Add("order", *pagination.Order)
+	}
+
+	encodedValues := ""
+	if len(urlValues) > 0 {
+		encodedValues = "?" + urlValues.Encode()
+	}
+
+	urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues)
+	req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// CreateVectorStoreFileBatch creates a new vector store file batch.
+func (c *Client) CreateVectorStoreFileBatch(
+	ctx context.Context,
+	vectorStoreID string,
+	request VectorStoreFileBatchRequest,
+) (response VectorStoreFileBatch, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix)
+	req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix),
+		withBody(request),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// RetrieveVectorStoreFileBatch retrieves a vector store file batch.
+func (c *Client) RetrieveVectorStoreFileBatch(
+	ctx context.Context,
+	vectorStoreID string,
+	batchID string,
+) (response VectorStoreFileBatch, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID)
+	req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
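A sketch of cursor pagination over ListVectorStoreFiles, again written as if inside this package (imports: context) and not part of the patch. VectorStoreFilesList does not expose has_more or last_id, so this version keys off the page size; treat that termination rule as an assumption.

// allVectorStoreFiles pages through every file attached to a vector store,
// using the ID of the last item on each page as the `after` cursor.
func allVectorStoreFiles(ctx context.Context, c *Client, vectorStoreID string) ([]VectorStoreFile, error) {
	var (
		files []VectorStoreFile
		after *string
	)
	limit := 100
	for {
		page, err := c.ListVectorStoreFiles(ctx, vectorStoreID, Pagination{Limit: &limit, After: after})
		if err != nil {
			return nil, err
		}
		files = append(files, page.VectorStoreFiles...)
		if len(page.VectorStoreFiles) < limit {
			return files, nil // short page: assume no further results
		}
		last := page.VectorStoreFiles[len(page.VectorStoreFiles)-1].ID
		after = &last
	}
}
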
+// CancelVectorStoreFileBatch cancels a vector store file batch.
+func (c *Client) CancelVectorStoreFileBatch(
+	ctx context.Context,
+	vectorStoreID string,
+	batchID string,
+) (response VectorStoreFileBatch, err error) {
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix,
+		vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel")
+	req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
+
+// ListVectorStoreFilesInBatch lists the currently available files in a vector store file batch.
+func (c *Client) ListVectorStoreFilesInBatch(
+	ctx context.Context,
+	vectorStoreID string,
+	batchID string,
+	pagination Pagination,
+) (response VectorStoreFilesList, err error) {
+	urlValues := url.Values{}
+	if pagination.After != nil {
+		urlValues.Add("after", *pagination.After)
+	}
+	if pagination.Limit != nil {
+		urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit))
+	}
+	if pagination.Before != nil {
+		urlValues.Add("before", *pagination.Before)
+	}
+	if pagination.Order != nil {
+		urlValues.Add("order", *pagination.Order)
+	}
+
+	encodedValues := ""
+	if len(urlValues) > 0 {
+		encodedValues = "?" + urlValues.Encode()
+	}
+
+	urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix,
+		vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues)
+	req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
+		withBetaAssistantVersion(c.config.AssistantVersion))
+
+	err = c.sendRequest(req, &response)
+	return
+}
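
Finally, a sketch tying the batch endpoints above into one ingest-and-wait helper, written as if inside this package (imports: context, time) and not part of the patch. The "in_progress" status value mirrors the file_counts JSON keys but is an assumption; the API may also report values such as "completed", "failed", or "cancelled".

// ingestBatch adds several files to a vector store as a single batch, polls the
// batch until it is no longer in progress, then lists the files it created.
func ingestBatch(ctx context.Context, c *Client, vectorStoreID string, fileIDs []string) (VectorStoreFilesList, error) {
	batch, err := c.CreateVectorStoreFileBatch(ctx, vectorStoreID, VectorStoreFileBatchRequest{FileIDs: fileIDs})
	if err != nil {
		return VectorStoreFilesList{}, err
	}

	for batch.Status == "in_progress" { // status string is an assumption
		select {
		case <-ctx.Done():
			return VectorStoreFilesList{}, ctx.Err()
		case <-time.After(time.Second):
		}
		batch, err = c.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, batch.ID)
		if err != nil {
			return VectorStoreFilesList{}, err
		}
	}

	return c.ListVectorStoreFilesInBatch(ctx, vectorStoreID, batch.ID, Pagination{})
}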