diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 000000000..81773666c --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,4 @@ +coverage: + ignore: + - "examples/**" + - "internal/test/**" diff --git a/.github/workflows/close-inactive-issues.yml b/.github/workflows/close-inactive-issues.yml index bfe9b5c96..32723c4e9 100644 --- a/.github/workflows/close-inactive-issues.yml +++ b/.github/workflows/close-inactive-issues.yml @@ -10,7 +10,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v9 with: days-before-issue-stale: 30 days-before-issue-close: 14 diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 000000000..7260b00b4 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,21 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} + run: go test -v -tags=integration ./api_integration_test.go diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 8df721f0f..2c9730656 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -9,19 +9,21 @@ jobs: name: Sanity check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v5 with: - go-version: '1.19' + go-version: '1.24' - name: Run vet run: | - go vet . + go vet -stdversion ./... - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v7 with: - version: latest + version: v2.1.5 - name: Run tests - run: go test -race -covermode=atomic -coverprofile=coverage.out -v . + run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./... - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 99b40bf17..b0ac1605c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,7 @@ # Auth token for tests .openai-token -.idea \ No newline at end of file +.idea + +# Generated by tests +test.mp3 \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index 58fab4a20..6391ad76f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,272 +1,168 @@ -## Golden config for golangci-lint v1.47.3 -# -# This is the best config for golangci-lint based on my experience and opinion. -# It is very strict, but not extremely strict. -# Feel free to adopt and change it for your needs. - -run: - # Timeout for analysis, e.g. 30s, 5m. - # Default: 1m - timeout: 3m - - -# This file contains only configs which differ from defaults. -# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml -linters-settings: - cyclop: - # The maximal code complexity to report. - # Default: 10 - max-complexity: 30 - # The maximal average package complexity. - # If it's higher than 0.0 (float) the check is enabled - # Default: 0.0 - package-average: 10.0 - - errcheck: - # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. - # Such cases aren't reported by default. - # Default: false - check-type-assertions: true - - funlen: - # Checks the number of lines in a function. - # If lower than 0, disable the check. - # Default: 60 - lines: 100 - # Checks the number of statements in a function. - # If lower than 0, disable the check. - # Default: 40 - statements: 50 - - gocognit: - # Minimal code complexity to report - # Default: 30 (but we recommend 10-20) - min-complexity: 20 - - gocritic: - # Settings passed to gocritic. - # The settings key is the name of a supported gocritic checker. - # The list of supported checkers can be find in https://go-critic.github.io/overview. - settings: - captLocal: - # Whether to restrict checker to params only. - # Default: true - paramsOnly: false - underef: - # Whether to skip (*x).method() calls where x is a pointer receiver. - # Default: true - skipRecvDeref: false - - gomnd: - # List of function patterns to exclude from analysis. - # Values always ignored: `time.Date` - # Default: [] - ignored-functions: - - os.Chmod - - os.Mkdir - - os.MkdirAll - - os.OpenFile - - os.WriteFile - - prometheus.ExponentialBuckets - - prometheus.ExponentialBucketsRange - - prometheus.LinearBuckets - - strconv.FormatFloat - - strconv.FormatInt - - strconv.FormatUint - - strconv.ParseFloat - - strconv.ParseInt - - strconv.ParseUint - - gomodguard: - blocked: - # List of blocked modules. - # Default: [] - modules: - - github.com/golang/protobuf: - recommendations: - - google.golang.org/protobuf - reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" - - github.com/satori/go.uuid: - recommendations: - - github.com/google/uuid - reason: "satori's package is not maintained" - - github.com/gofrs/uuid: - recommendations: - - github.com/google/uuid - reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw" - - govet: - # Enable all analyzers. - # Default: false - enable-all: true - # Disable analyzers by name. - # Run `go tool vet help` to see all analyzers. - # Default: [] - disable: - - fieldalignment # too strict - # Settings per analyzer. - settings: - shadow: - # Whether to be strict about shadowing; can be noisy. - # Default: false - strict: true - - nakedret: - # Make an issue if func has more lines of code than this setting, and it has naked returns. - # Default: 30 - max-func-lines: 0 - - nolintlint: - # Exclude following linters from requiring an explanation. - # Default: [] - allow-no-explanation: [ funlen, gocognit, lll ] - # Enable to require an explanation of nonzero length after each nolint directive. - # Default: false - require-explanation: true - # Enable to require nolint directives to mention the specific linter being suppressed. - # Default: false - require-specific: true - - rowserrcheck: - # database/sql is always checked - # Default: [] - packages: - - github.com/jmoiron/sqlx - - tenv: - # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. - # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. - # Default: false - all: true - - varcheck: - # Check usage of exported fields and variables. - # Default: false - exported-fields: false # default false # TODO: enable after fixing false positives - - +version: "2" linters: - disable-all: true + default: none enable: - ## enabled by default - - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - - gosimple # Linter for Go source code that specializes in simplifying a code - - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - - ineffassign # Detects when assignments to existing variables are not used - - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - - unused # Checks Go code for unused constants, variables, functions and types - ## disabled by default - # - asasalint # Check for pass []any as any in variadic func(...any) - - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - - bidichk # Checks for dangerous unicode character sequences - - bodyclose # checks whether HTTP response body is closed successfully - - contextcheck # check the function whether use a non-inherited context - - cyclop # checks function and package cyclomatic complexity - - dupl # Tool for code clone detection - - durationcheck # check for two durations multiplied together - - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error. - - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - - execinquery # execinquery is a linter about query string checker in Query function which reads your Go src files and warning it finds - - exhaustive # check exhaustiveness of enum switch statements - - exportloopref # checks for pointers to enclosing loop variables - - forbidigo # Forbids identifiers - - funlen # Tool for detection of long functions - # - gochecknoglobals # check that no global variables exist - - gochecknoinits # Checks that no init functions are present in Go code - - gocognit # Computes and checks the cognitive complexity of functions - - goconst # Finds repeated strings that could be replaced by a constant - - gocritic # Provides diagnostics that check for bugs, performance and style issues. - - gocyclo # Computes and checks the cyclomatic complexity of functions - - godot # Check if comments end in a period - - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt. - - gomnd # An analyzer to detect magic numbers. - - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - - goprintffuncname # Checks that printf-like functions are named with f at the end - - gosec # Inspects source code for security problems - - lll # Reports long lines - - makezero # Finds slice declarations with non-zero initial length - # - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements - - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - - nilnil # Checks that there is no simultaneous return of nil error and an invalid value. - # - noctx # noctx finds sending http request without context.Context - - nolintlint # Reports ill-formed or insufficient nolint directives - # - nonamedreturns # Reports all named returns - - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL. - - predeclared # find code that shadows one of Go's predeclared identifiers - - promlinter # Check Prometheus metrics naming via promlint - - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - - rowserrcheck # checks whether Err of rows is checked successfully - - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - - stylecheck # Stylecheck is a replacement for golint - - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - testpackage # linter that makes you use a separate _test package - - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - - unconvert # Remove unnecessary type conversions - - unparam # Reports unused function parameters - - wastedassign # wastedassign finds wasted assignment statements. - - whitespace # Tool for detection of leading and trailing whitespace - ## you may want to enable - #- decorder # check declaration order and count of types, constants, variables and functions - #- exhaustruct # Checks if all structure fields are initialized - #- goheader # Checks is file header matches to pattern - #- ireturn # Accept Interfaces, Return Concrete Types - #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated - #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope - #- wrapcheck # Checks that errors returned from external packages are wrapped - ## disabled - #- containedctx # containedctx is a linter that detects struct contained context.Context field - #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages - #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - #- forcetypeassert # [replaced by errcheck] finds forced type assertions - #- gci # Gci controls golang package import order and makes it always deterministic. - #- godox # Tool for detection of FIXME, TODO and other comment keywords - #- goerr113 # [too strict] Golang linter to check the errors handling expressions - #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed. - #- grouper # An analyzer to analyze expression groups. - #- ifshort # Checks that your code uses short syntax for if-statements whenever possible - #- importas # Enforces consistent import aliases - #- maintidx # maintidx measures the maintainability index of each function. - #- misspell # [useless] Finds commonly misspelled English words in comments - #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity - #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14 - #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test - #- tagliatelle # Checks the struct tags. - #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines! - ## deprecated - #- exhaustivestruct # [deprecated, replaced by exhaustruct] Checks if all struct's fields are initialized - #- golint # [deprecated, replaced by revive] Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes - #- interfacer # [deprecated] Linter that suggests narrower interface types - #- maligned # [deprecated, replaced by govet fieldalignment] Tool to detect Go structs that would take less memory if their fields were sorted - #- scopelint # [deprecated, replaced by exportloopref] Scopelint checks for unpinned variables in go programs - - + - asciicheck + - bidichk + - bodyclose + - contextcheck + - cyclop + - dupl + - durationcheck + - errcheck + - errname + - errorlint + - exhaustive + - forbidigo + - funlen + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godot + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - govet + - ineffassign + - lll + - makezero + - mnd + - nestif + - nilerr + - nilnil + - nolintlint + - nosprintfhostport + - predeclared + - promlinter + - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - testpackage + - tparallel + - unconvert + - unparam + - unused + - usetesting + - wastedassign + - whitespace + settings: + cyclop: + max-complexity: 30 + package-average: 10 + errcheck: + check-type-assertions: true + funlen: + lines: 100 + statements: 50 + gocognit: + min-complexity: 20 + gocritic: + settings: + captLocal: + paramsOnly: false + underef: + skipRecvDeref: false + gomodguard: + blocked: + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: satori's package is not maintained + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw' + govet: + disable: + - fieldalignment + enable-all: true + settings: + shadow: + strict: true + mnd: + ignored-functions: + - os.Chmod + - os.Mkdir + - os.MkdirAll + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets + - prometheus.ExponentialBucketsRange + - prometheus.LinearBuckets + - strconv.FormatFloat + - strconv.FormatInt + - strconv.FormatUint + - strconv.ParseFloat + - strconv.ParseInt + - strconv.ParseUint + nakedret: + max-func-lines: 0 + nolintlint: + require-explanation: true + require-specific: true + allow-no-explanation: + - funlen + - gocognit + - lll + rowserrcheck: + packages: + - github.com/jmoiron/sqlx + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - forbidigo + - mnd + - revive + path : ^examples/.*\.go$ + - linters: + - lll + source: ^//\s*go:generate\s + - linters: + - godot + source: (noinspection|TODO) + - linters: + - gocritic + source: //noinspection + - linters: + - errorlint + source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok { + - linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck + - staticcheck + path: _test\.go + paths: + - third_party$ + - builtin$ + - examples$ issues: - # Maximum count of issues with the same text. - # Set to 0 to disable. - # Default: 3 max-same-issues: 50 - - exclude-rules: - - source: "^//\\s*go:generate\\s" - linters: [ lll ] - - source: "(noinspection|TODO)" - linters: [ godot ] - - source: "//noinspection" - linters: [ gocritic ] - - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {" - linters: [ errorlint ] - - path: "_test\\.go" - linters: - - bodyclose - - dupl - - funlen - - goconst - - gosec - - noctx - - wrapcheck +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/Makefile b/Makefile deleted file mode 100644 index 2e608aa0c..000000000 --- a/Makefile +++ /dev/null @@ -1,35 +0,0 @@ -##@ General - -# The help target prints out all targets with their descriptions organized -# beneath their categories. The categories are represented by '##@' and the -# target descriptions by '##'. The awk commands is responsible for reading the -# entire set of makefiles included in this invocation, looking for lines of the -# file as xyz: ## something, and then pretty-format the target and help. Then, -# if there's a line with ##@ something, that gets pretty-printed as a category. -# More info on the usage of ANSI control characters for terminal formatting: -# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters -# More info on the awk command: -# http://linuxcommand.org/lc3_adv_awk.php - -.PHONY: help -help: ## Display this help. - @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) - - -##@ Development - -.PHONY: test -TEST_ARGS ?= -v -TEST_TARGETS ?= ./... -test: ## Test the Go modules within this package. - @ echo ▶️ go test $(TEST_ARGS) $(TEST_TARGETS) - go test $(TEST_ARGS) $(TEST_TARGETS) - @ echo ✅ success! - - -.PHONY: lint -LINT_TARGETS ?= ./... -lint: ## Lint Go code with the installed golangci-lint - @ echo "▶️ golangci-lint run" - golangci-lint run $(LINT_TARGETS) - @ echo "✅ golangci-lint run" diff --git a/README.md b/README.md index 9a479c0a0..77b85e519 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.openai.com/). We support: -* ChatGPT +* ChatGPT 4o, o1 * GPT-3, GPT-4 -* DALL·E 2 +* DALL·E 2, DALL·E 3, GPT Image 1 * Whisper ## Installation @@ -141,7 +141,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", } @@ -174,7 +174,7 @@ func main() { ctx := context.Background() req := openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, @@ -357,6 +357,66 @@ func main() { ``` +
+GPT Image 1 image generation + +```go +package main + +import ( + "context" + "encoding/base64" + "fmt" + "os" + + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.ImageRequest{ + Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.", + Background: openai.CreateImageBackgroundOpaque, + Model: openai.CreateImageModelGptImage1, + Size: openai.CreateImageSize1024x1024, + N: 1, + Quality: openai.CreateImageQualityLow, + OutputCompression: 100, + OutputFormat: openai.CreateImageOutputFormatJPEG, + // Moderation: openai.CreateImageModerationLow, + // User: "", + } + + resp, err := c.CreateImage(ctx, req) + if err != nil { + fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err) + return + } + + fmt.Println("Image Base64:", resp.Data[0].B64JSON) + + // Decode the base64 data + imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) + if err != nil { + fmt.Printf("Base64 decode error: %v\n", err) + return + } + + // Write image to file + outputPath := "generated_image.jpg" + err = os.WriteFile(outputPath, imgBytes, 0644) + if err != nil { + fmt.Printf("Failed to write image file: %v\n", err) + return + } + + fmt.Printf("The image was saved as %s\n", outputPath) +} +``` +
+
Configuring proxy @@ -636,7 +696,7 @@ FunctionDefinition{ }, "unit": { Type: jsonschema.String, - Enum: []string{"celcius", "fahrenheit"}, + Enum: []string{"celsius", "fahrenheit"}, }, }, Required: []string{"location"}, @@ -743,6 +803,70 @@ func main() { } ```
+ +
+Structured Outputs + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" +) + +func main() { + client := openai.NewClient("your token") + ctx := context.Background() + + type Result struct { + Steps []struct { + Explanation string `json:"explanation"` + Output string `json:"output"` + } `json:"steps"` + FinalAnswer string `json:"final_answer"` + } + var result Result + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + log.Fatalf("GenerateSchemaForType error: %v", err) + } + resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "You are a helpful math tutor. Guide the user through the solution step by step.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "how can I solve 8x + 7 = -23", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "math_reasoning", + Schema: schema, + Strict: true, + }, + }, + }) + if err != nil { + log.Fatalf("CreateChatCompletion error: %v", err) + } + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + if err != nil { + log.Fatalf("Unmarshal schema error: %v", err) + } + fmt.Println(result) +} +``` +
See the `examples/` folder for more. ## Frequently Asked Questions diff --git a/api_integration_test.go b/api_integration_test.go index 736040c50..7828d9451 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -4,6 +4,7 @@ package openai_test import ( "context" + "encoding/json" "errors" "io" "os" @@ -26,7 +27,7 @@ func TestAPI(t *testing.T) { _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") - _, err = c.GetEngine(ctx, "davinci") + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) @@ -42,7 +43,7 @@ func TestAPI(t *testing.T) { "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: openai.AdaEmbeddingV2, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") @@ -77,31 +78,6 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: openai.GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } - _, err = c.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -134,6 +110,41 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + func TestAPIError(t *testing.T) { apiToken := os.Getenv("OPENAI_TOKEN") if apiToken == "" { @@ -168,3 +179,136 @@ func TestAPIError(t *testing.T) { t.Fatal("Empty error message occurred") } } + +func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + var result MyStructuredResponse + schema, err := jsonschema.GenerateSchemaForType(result) + if err != nil { + t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error") + } + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + ResponseFormat: &openai.ChatCompletionResponseFormat{ + Type: openai.ChatCompletionResponseFormatTypeJSONSchema, + JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{ + Name: "cases", + Schema: schema, + Strict: true, + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error") + if err == nil { + err = schema.Unmarshal(resp.Choices[0].Message.Content, &result) + checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error") + } +} + +func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + var err error + c := openai.NewClient(apiToken) + ctx := context.Background() + + resp, err := c.CreateChatCompletion( + ctx, + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleSystem, + Content: "Please enter a string, and we will convert it into the following naming conventions:" + + "1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." + + "2. CamelCase: The first word starts with a lowercase letter, " + + "and subsequent words start with an uppercase letter, with no spaces or separators." + + "3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." + + "4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.", + }, + { + Role: openai.ChatMessageRoleUser, + Content: "Hello World", + }, + }, + Tools: []openai.Tool{ + { + Type: openai.ToolTypeFunction, + Function: &openai.FunctionDefinition{ + Name: "display_cases", + Strict: true, + Parameters: &jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "PascalCase": { + Type: jsonschema.String, + }, + "CamelCase": { + Type: jsonschema.String, + }, + "KebabCase": { + Type: jsonschema.String, + }, + "SnakeCase": { + Type: jsonschema.String, + }, + }, + Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"}, + AdditionalProperties: false, + }, + }, + }, + }, + ToolChoice: openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{ + Name: "display_cases", + }, + }, + }, + ) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error") + var result = make(map[string]string) + err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result) + checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error") + for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} { + if _, ok := result[key]; !ok { + t.Errorf("key:%s does not exist.", key) + } + } +} diff --git a/api_internal_test.go b/api_internal_test.go index 0fb0f8993..09677968a 100644 --- a/api_internal_test.go +++ b/api_internal_test.go @@ -112,6 +112,7 @@ func TestAzureFullURL(t *testing.T) { Name string BaseURL string AzureModelMapper map[string]string + Suffix string Model string Expect string }{ @@ -119,6 +120,7 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithSlashAutoStrip", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + @@ -128,11 +130,20 @@ func TestAzureFullURL(t *testing.T) { "AzureBaseURLWithoutSlashOK", "https://httpbin.org/", nil, + "/chat/completions", "chatgpt-demo", "https://httpbin.org/" + "openai/deployments/chatgpt-demo" + "/chat/completions?api-version=2023-05-15", }, + { + "", + "https://httpbin.org/", + nil, + "/assistants?limit=10", + "chatgpt-demo", + "https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10", + }, } for _, c := range cases { @@ -140,7 +151,46 @@ func TestAzureFullURL(t *testing.T) { az := DefaultAzureConfig("dummy", c.BaseURL) cli := NewClientWithConfig(az) // /openai/deployments/{engine}/chat/completions?api-version={api_version} - actual := cli.fullURL("/chat/completions", c.Model) + actual := cli.fullURL(c.Suffix, withModel(c.Model)) + if actual != c.Expect { + t.Errorf("Expected %s, got %s", c.Expect, actual) + } + t.Logf("Full URL: %s", actual) + }) + } +} + +func TestCloudflareAzureFullURL(t *testing.T) { + cases := []struct { + Name string + BaseURL string + Suffix string + Expect string + }{ + { + "CloudflareAzureBaseURLWithSlashAutoStrip", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/", + "/chat/completions", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" + + "chat/completions?api-version=2023-05-15", + }, + { + "", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo", + "/assistants?limit=10", + "https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" + + "/assistants?api-version=2023-05-15&limit=10", + }, + } + + for _, c := range cases { + t.Run(c.Name, func(t *testing.T) { + az := DefaultAzureConfig("dummy", c.BaseURL) + az.APIType = APITypeCloudflareAzure + + cli := NewClientWithConfig(az) + + actual := cli.fullURL(c.Suffix) if actual != c.Expect { t.Errorf("Expected %s, got %s", c.Expect, actual) } diff --git a/assistant.go b/assistant.go index 4ca2dda62..8aab5bcf0 100644 --- a/assistant.go +++ b/assistant.go @@ -11,20 +11,23 @@ import ( const ( assistantsSuffix = "/assistants" assistantsFilesSuffix = "/files" - openaiAssistantsV1 = "assistants=v1" ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` // Deprecated in v2 + Metadata map[string]any `json:"metadata,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` httpHeader } @@ -35,6 +38,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -42,19 +46,36 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` + ResponseFormat any `json:"response_format,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases @@ -116,7 +137,7 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -132,7 +153,7 @@ func (c *Client) RetrieveAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -149,7 +170,7 @@ func (c *Client) ModifyAssistant( ) (response Assistant, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -165,7 +186,7 @@ func (c *Client) DeleteAssistant( ) (response AssistantDeleteResponse, err error) { urlSuffix := fmt.Sprintf("%s/%s", assistantsSuffix, assistantID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -181,7 +202,7 @@ func (c *Client) ListAssistants( order *string, after *string, before *string, -) (reponse AssistantsList, err error) { +) (response AssistantsList, err error) { urlValues := url.Values{} if limit != nil { urlValues.Add("limit", fmt.Sprintf("%d", *limit)) @@ -203,12 +224,12 @@ func (c *Client) ListAssistants( urlSuffix := fmt.Sprintf("%s%s", assistantsSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } - err = c.sendRequest(req, &reponse) + err = c.sendRequest(req, &response) return } @@ -221,7 +242,7 @@ func (c *Client) CreateAssistantFile( urlSuffix := fmt.Sprintf("%s/%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -238,7 +259,7 @@ func (c *Client) RetrieveAssistantFile( ) (response AssistantFile, err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -255,7 +276,7 @@ func (c *Client) DeleteAssistantFile( ) (err error) { urlSuffix := fmt.Sprintf("%s/%s%s/%s", assistantsSuffix, assistantID, assistantsFilesSuffix, fileID) req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -294,7 +315,7 @@ func (c *Client) ListAssistantFiles( urlSuffix := fmt.Sprintf("%s/%s%s%s", assistantsSuffix, assistantID, assistantsFilesSuffix, encodedValues) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/audio.go b/audio.go index 4cbe4fe64..f321f93d6 100644 --- a/audio.go +++ b/audio.go @@ -27,8 +27,14 @@ const ( AudioResponseFormatVTT AudioResponseFormat = "vtt" ) +type TranscriptionTimestampGranularity string + +const ( + TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" + TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" +) + // AudioRequest represents a request structure for audio API. -// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { Model string @@ -38,10 +44,11 @@ type AudioRequest struct { // Reader is an optional io.Reader when you do not want to use an existing file. Reader io.Reader - Prompt string // For translation, it should be in English - Temperature float32 - Language string // For translation, just do not use it. It seems "en" works, not confirmed... - Format AudioResponseFormat + Prompt string + Temperature float32 + Language string // Only for transcription. + Format AudioResponseFormat + TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. } // AudioResponse represents a response structure for audio API. @@ -62,6 +69,11 @@ type AudioResponse struct { NoSpeechProb float64 `json:"no_speech_prob"` Transient bool `json:"transient"` } `json:"segments"` + Words []struct { + Word string `json:"word"` + Start float64 `json:"start"` + End float64 `json:"end"` + } `json:"words"` Text string `json:"text"` httpHeader @@ -110,8 +122,13 @@ func (c *Client) callAudioAPI( } urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), - withBody(&formBody), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(&formBody), + withContentType(builder.FormDataContentType()), + ) if err != nil { return AudioResponse{}, err } @@ -179,6 +196,15 @@ func audioMultipartForm(request AudioRequest, b utils.FormBuilder) error { } } + if len(request.TimestampGranularities) > 0 { + for _, tg := range request.TimestampGranularities { + err = b.WriteField("timestamp_granularities[]", string(tg)) + if err != nil { + return fmt.Errorf("writing timestamp_granularities[]: %w", err) + } + } + } + // Close the multipart writer return b.Close() } diff --git a/audio_api_test.go b/audio_api_test.go index a0efc7921..6c6a35643 100644 --- a/audio_api_test.go +++ b/audio_api_test.go @@ -40,12 +40,9 @@ func TestAudio(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ @@ -90,12 +87,9 @@ func TestAudioWithOptionalArgs(t *testing.T) { ctx := context.Background() - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := openai.AudioRequest{ @@ -105,6 +99,10 @@ func TestAudioWithOptionalArgs(t *testing.T) { Temperature: 0.5, Language: "zh", Format: openai.AudioResponseFormatSRT, + TimestampGranularities: []openai.TranscriptionTimestampGranularity{ + openai.TranscriptionTimestampGranularitySegment, + openai.TranscriptionTimestampGranularityWord, + }, } _, err := tc.createFn(ctx, req) checks.NoError(t, err, "audio API error") diff --git a/audio_test.go b/audio_test.go index 5346244c8..51b3f465d 100644 --- a/audio_test.go +++ b/audio_test.go @@ -2,20 +2,22 @@ package openai //nolint:testpackage // testing private field import ( "bytes" + "context" + "errors" "fmt" "io" + "net/http" "os" "path/filepath" "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestAudioWithFailingFormBuilder(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ @@ -24,6 +26,10 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { Temperature: 0.5, Language: "en", Format: AudioResponseFormatSRT, + TimestampGranularities: []TranscriptionTimestampGranularity{ + TranscriptionTimestampGranularitySegment, + TranscriptionTimestampGranularityWord, + }, } mockFailedErr := fmt.Errorf("mock form builder fail") @@ -47,7 +53,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { return nil } - failOn := []string{"model", "prompt", "temperature", "language", "response_format"} + failOn := []string{"model", "prompt", "temperature", "language", "response_format", "timestamp_granularities[]"} for _, failingField := range failOn { failForField = failingField mockFailedErr = fmt.Errorf("mock form builder fail on field %s", failingField) @@ -59,9 +65,7 @@ func TestAudioWithFailingFormBuilder(t *testing.T) { func TestCreateFileField(t *testing.T) { t.Run("createFileField failing file", func(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) req := AudioRequest{ @@ -107,3 +111,131 @@ func TestCreateFileField(t *testing.T) { checks.HasError(t, err, "createFileField using file should return error when open file fails") }) } + +// failingFormBuilder always returns an error when creating form files. +type failingFormBuilder struct{ err error } + +func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error { + return f.err +} + +func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error { + return f.err +} + +func (f *failingFormBuilder) WriteField(_, _ string) error { + return nil +} + +func (f *failingFormBuilder) Close() error { + return nil +} + +func (f *failingFormBuilder) FormDataContentType() string { + return "multipart/form-data" +} + +// failingAudioRequestBuilder simulates an error during HTTP request construction. +type failingAudioRequestBuilder struct{ err error } + +func (f *failingAudioRequestBuilder) Build( + _ context.Context, + _, _ string, + _ any, + _ http.Header, +) (*http.Request, error) { + return nil, f.err +} + +// errorHTTPClient always returns an error when making HTTP calls. +type errorHTTPClient struct{ err error } + +func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) { + return nil, e.err +} + +func TestCallAudioAPIMultipartFormError(t *testing.T) { + client := NewClient("test-token") + errForm := errors.New("mock create form file failure") + // Override form builder to force an error during multipart form creation. + client.createFormBuilder = func(_ io.Writer) utils.FormBuilder { + return &failingFormBuilder{err: errForm} + } + + // Provide a reader so createFileField uses the reader path (no file open). + req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errForm) { + t.Errorf("expected error %v, got %v", errForm, err) + } +} + +func TestCallAudioAPINewRequestError(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errBuild := errors.New("mock build failure") + client.requestBuilder = &failingAudioRequestBuilder{err: errBuild} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errBuild) { + t.Errorf("expected error %v, got %v", errBuild, err) + } +} + +func TestCallAudioAPISendRequestErrorJSON(t *testing.T) { + client := NewClient("test-token") + // Create a real temp file so multipart form succeeds. + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + // Override HTTP client to simulate a network error. + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + req := AudioRequest{FilePath: path, Model: Whisper1} + _, err := client.callAudioAPI(context.Background(), req, "transcriptions") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} + +func TestCallAudioAPISendRequestErrorText(t *testing.T) { + client := NewClient("test-token") + tmp := t.TempDir() + path := filepath.Join(tmp, "file.mp3") + if err := os.WriteFile(path, []byte("content"), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + + errHTTP := errors.New("mock HTTPClient failure") + client.config.HTTPClient = &errorHTTPClient{err: errHTTP} + + // Use a non-JSON response format to exercise the text path. + req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText} + _, err := client.callAudioAPI(context.Background(), req, "translations") + if err == nil { + t.Fatal("expected error but got none") + } + if !errors.Is(err, errHTTP) { + t.Errorf("expected error %v, got %v", errHTTP, err) + } +} diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..3c1a9d0d7 --- /dev/null +++ b/batch.go @@ -0,0 +1,271 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data []struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. +func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 000000000..f4714f4eb --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/chat.go b/chat.go index efb14fd4c..0aa018715 100644 --- a/chat.go +++ b/chat.go @@ -5,6 +5,8 @@ import ( "encoding/json" "errors" "net/http" + + "github.com/sashabaranov/go-openai/jsonschema" ) // Chat message role defined by the OpenAI API. @@ -14,6 +16,7 @@ const ( ChatMessageRoleAssistant = "assistant" ChatMessageRoleFunction = "function" ChatMessageRoleTool = "tool" + ChatMessageRoleDeveloper = "developer" ) const chatCompletionsSuffix = "/chat/completions" @@ -41,11 +44,23 @@ type Violence struct { Severity string `json:"severity,omitempty"` } +type JailBreak struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + +type Profanity struct { + Filtered bool `json:"filtered"` + Detected bool `json:"detected"` +} + type ContentFilterResults struct { - Hate Hate `json:"hate,omitempty"` - SelfHarm SelfHarm `json:"self_harm,omitempty"` - Sexual Sexual `json:"sexual,omitempty"` - Violence Violence `json:"violence,omitempty"` + Hate Hate `json:"hate,omitempty"` + SelfHarm SelfHarm `json:"self_harm,omitempty"` + Sexual Sexual `json:"sexual,omitempty"` + Violence Violence `json:"violence,omitempty"` + JailBreak JailBreak `json:"jailbreak,omitempty"` + Profanity Profanity `json:"profanity,omitempty"` } type PromptAnnotation struct { @@ -81,7 +96,8 @@ type ChatMessagePart struct { type ChatCompletionMessage struct { Role string `json:"role"` - Content string `json:"content"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` MultiContent []ChatMessagePart // This property isn't in the official documentation, but it's in @@ -90,6 +106,12 @@ type ChatCompletionMessage struct { // - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb Name string `json:"name,omitempty"` + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` // For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. @@ -105,50 +127,60 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { } if len(m.MultiContent) > 0 { msg := struct { - Role string `json:"role"` - Content string `json:"-"` - MultiContent []ChatMessagePart `json:"content,omitempty"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"-"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content,omitempty"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } + msg := struct { - Role string `json:"role"` - Content string `json:"content"` - MultiContent []ChatMessagePart `json:"-"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content,omitempty"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"-"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }(m) return json.Marshal(msg) } func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { msg := struct { - Role string `json:"role"` - Content string `json:"content"` - MultiContent []ChatMessagePart - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} + if err := json.Unmarshal(bs, &msg); err == nil { *m = ChatCompletionMessage(msg) return nil } multiMsg := struct { - Role string `json:"role"` - Content string - MultiContent []ChatMessagePart `json:"content"` - Name string `json:"name,omitempty"` - FunctionCall *FunctionCall `json:"function_call,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string + Refusal string `json:"refusal,omitempty"` + MultiContent []ChatMessagePart `json:"content"` + Name string `json:"name,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` }{} if err := json.Unmarshal(bs, &multiMsg); err != nil { return err @@ -160,7 +192,7 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error { type ToolCall struct { // Index is not nil only in chat completion chunk object Index *int `json:"index,omitempty"` - ID string `json:"id"` + ID string `json:"id,omitempty"` Type ToolType `json:"type"` Function FunctionCall `json:"function"` } @@ -175,27 +207,79 @@ type ChatCompletionResponseFormatType string const ( ChatCompletionResponseFormatTypeJSONObject ChatCompletionResponseFormatType = "json_object" + ChatCompletionResponseFormatTypeJSONSchema ChatCompletionResponseFormatType = "json_schema" ChatCompletionResponseFormatTypeText ChatCompletionResponseFormatType = "text" ) type ChatCompletionResponseFormat struct { - Type ChatCompletionResponseFormatType `json:"type,omitempty"` + Type ChatCompletionResponseFormatType `json:"type,omitempty"` + JSONSchema *ChatCompletionResponseFormatJSONSchema `json:"json_schema,omitempty"` +} + +type ChatCompletionResponseFormatJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.Marshaler `json:"schema"` + Strict bool `json:"strict"` +} + +func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error { + type rawJSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.RawMessage `json:"schema"` + Strict bool `json:"strict"` + } + var raw rawJSONSchema + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.Name = raw.Name + r.Description = raw.Description + r.Strict = raw.Strict + if len(raw.Schema) > 0 && string(raw.Schema) != "null" { + var d jsonschema.Definition + err := json.Unmarshal(raw.Schema, &d) + if err != nil { + return err + } + r.Schema = &d + } + return nil +} + +// ChatCompletionRequestExtensions contains third-party OpenAI API extensions +// (e.g., vendor-specific implementations like vLLM). +type ChatCompletionRequestExtensions struct { + // GuidedChoice is a vLLM-specific extension that restricts the model's output + // to one of the predefined string choices provided in this field. This feature + // is used to constrain the model's responses to a controlled set of options, + // ensuring predictable and consistent outputs in scenarios where specific + // choices are required. + GuidedChoice []string `json:"guided_choice,omitempty"` } // ChatCompletionRequest represents a request structure for chat completion API. type ChatCompletionRequest struct { - Model string `json:"model"` - Messages []ChatCompletionMessage `json:"messages"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + // MaxTokens The maximum number of tokens that can be generated in the chat completion. + // This value can be used to control costs for text generated via API. + // Deprecated: use MaxCompletionTokens. Not compatible with o1-series models. + // refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens + MaxTokens int `json:"max_tokens,omitempty"` + // MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion, + // including visible output tokens and reasoning tokens https://platform.openai.com/docs/guides/reasoning + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + ResponseFormat *ChatCompletionResponseFormat `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/chat/create#chat/create-logit_bias @@ -216,6 +300,47 @@ type ChatCompletionRequest struct { Tools []Tool `json:"tools,omitempty"` // This can be either a string or an ToolChoice object. ToolChoice any `json:"tool_choice,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Controls effort on reasoning for reasoning models. It can be set to "low", "medium", or "high". + ReasoningEffort string `json:"reasoning_effort,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + // Configuration for a predicted output. + Prediction *Prediction `json:"prediction,omitempty"` + // ChatTemplateKwargs provides a way to add non-standard parameters to the request body. + // Additional kwargs to pass to the template renderer. Will be accessible by the chat template. + // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false} + // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes + ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` + // Specifies the latency tier to use for processing the request. + ServiceTier ServiceTier `json:"service_tier,omitempty"` + // Verbosity determines how many output tokens are generated. Lowering the number of + // tokens reduces overall latency. It can be set to "low", "medium", or "high". + // Note: This field is only confirmed to work with gpt-5, gpt-5-mini and gpt-5-nano. + // Also, it is not in the API reference of chat completion at the time of writing, + // though it is supported by the API. + Verbosity string `json:"verbosity,omitempty"` + // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. + // The IDs should be a string that uniquely identifies each user. + // We recommend hashing their username or email address, in order to avoid sending us any identifying information. + // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier + SafetyIdentifier string `json:"safety_identifier,omitempty"` + // Embedded struct for non-OpenAI extensions + ChatCompletionRequestExtensions +} + +type StreamOptions struct { + // If set, an additional chunk will be streamed before the data: [DONE] message. + // The usage field on this chunk shows the token usage statistics for the entire request, + // and the choices field will always be an empty array. + // All other chunks will also include a usage field, but with a null value. + IncludeUsage bool `json:"include_usage,omitempty"` } type ToolType string @@ -241,6 +366,7 @@ type ToolFunction struct { type FunctionDefinition struct { Name string `json:"name"` Description string `json:"description,omitempty"` + Strict bool `json:"strict,omitempty"` // Parameters is an object describing the function. // You can pass json.RawMessage to describe the schema, // or you can pass in a struct which serializes to the proper JSON schema. @@ -274,6 +400,11 @@ type LogProbs struct { Content []LogProb `json:"content"` } +type Prediction struct { + Content string `json:"content"` + Type string `json:"type"` +} + type FinishReason string const ( @@ -285,6 +416,15 @@ const ( FinishReasonNull FinishReason = "null" ) +type ServiceTier string + +const ( + ServiceTierAuto ServiceTier = "auto" + ServiceTierDefault ServiceTier = "default" + ServiceTierFlex ServiceTier = "flex" + ServiceTierPriority ServiceTier = "priority" +) + func (r FinishReason) MarshalJSON() ([]byte, error) { if r == FinishReasonNull || r == "" { return []byte("null"), nil @@ -302,19 +442,22 @@ type ChatCompletionChoice struct { // function_call: The model decided to call a function // content_filter: Omitted content due to a flag from our content filters // null: API response still in progress or incomplete - FinishReason FinishReason `json:"finish_reason"` - LogProbs *LogProbs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + LogProbs *LogProbs `json:"logprobs,omitempty"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } // ChatCompletionResponse represents a response structure for chat completion API. type ChatCompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionChoice `json:"choices"` - Usage Usage `json:"usage"` - SystemFingerprint string `json:"system_fingerprint"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage Usage `json:"usage"` + SystemFingerprint string `json:"system_fingerprint"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + ServiceTier ServiceTier `json:"service_tier,omitempty"` httpHeader } @@ -335,7 +478,17 @@ func (c *Client) CreateChatCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { + return + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/chat_stream.go b/chat_stream.go index 57cfa789f..80d16cc63 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -10,22 +10,59 @@ type ChatCompletionStreamChoiceDelta struct { Role string `json:"role,omitempty"` FunctionCall *FunctionCall `json:"function_call,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Refusal string `json:"refusal,omitempty"` + + // This property is used for the "reasoning" feature supported by deepseek-reasoner + // which is not in the official documentation. + // the doc from deepseek: + // - https://api-docs.deepseek.com/api/create-chat-completion#responses + ReasoningContent string `json:"reasoning_content,omitempty"` +} + +type ChatCompletionStreamChoiceLogprobs struct { + Content []ChatCompletionTokenLogprob `json:"content,omitempty"` + Refusal []ChatCompletionTokenLogprob `json:"refusal,omitempty"` +} + +type ChatCompletionTokenLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes,omitempty"` + Logprob float64 `json:"logprob,omitempty"` + TopLogprobs []ChatCompletionTokenLogprobTopLogprob `json:"top_logprobs"` +} + +type ChatCompletionTokenLogprobTopLogprob struct { + Token string `json:"token"` + Bytes []int64 `json:"bytes"` + Logprob float64 `json:"logprob"` } type ChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta ChatCompletionStreamChoiceDelta `json:"delta"` - FinishReason FinishReason `json:"finish_reason"` - ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` + Index int `json:"index"` + Delta ChatCompletionStreamChoiceDelta `json:"delta"` + Logprobs *ChatCompletionStreamChoiceLogprobs `json:"logprobs,omitempty"` + FinishReason FinishReason `json:"finish_reason"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` +} + +type PromptFilterResult struct { + Index int `json:"index"` + ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"` } type ChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionStreamChoice `json:"choices"` - PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` + SystemFingerprint string `json:"system_fingerprint"` + PromptAnnotations []PromptAnnotation `json:"prompt_annotations,omitempty"` + PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"` + // An optional field that will only be present when you set stream_options: {"include_usage": true} in your request. + // When present, it contains a null value except for the last chunk which contains the token usage statistics + // for the entire request. + Usage *Usage `json:"usage,omitempty"` } // ChatCompletionStream @@ -49,7 +86,17 @@ func (c *Client) CreateChatCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + reasoningValidator := NewReasoningValidator() + if err = reasoningValidator.Validate(request); err != nil { + return + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/chat_stream_test.go b/chat_stream_test.go index bd571cb48..eabb0f3a2 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -46,12 +46,12 @@ func TestCreateChatCompletionStream(t *testing.T) { dataBytes := []byte{} dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: message\n")...) //nolint:lll - data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) dataBytes = append(dataBytes, []byte("event: done\n")...) @@ -77,10 +77,11 @@ func TestCreateChatCompletionStream(t *testing.T) { expectedResponses := []openai.ChatCompletionStreamResponse{ { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: openai.GPT3Dot5Turbo, + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -91,10 +92,11 @@ func TestCreateChatCompletionStream(t *testing.T) { }, }, { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: openai.GPT3Dot5Turbo, + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", Choices: []openai.ChatCompletionStreamChoice{ { Delta: openai.ChatCompletionStreamChoiceDelta{ @@ -335,6 +337,271 @@ func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithRefusal(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"refusal":" World"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT4oMini20240718, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Refusal: " World", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamWithLogprobs(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":{"content":[],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":{"content":[{"token":"Hello","logprob":-0.000020458236,"bytes":[72,101,108,108,111],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":" World"},"logprobs":{"content":[{"token":" World","logprob":-0.00055303273,"bytes":[32,87,111,114,108,100],"top_logprobs":[]}],"refusal":null},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_d9767fc5b9","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 2000, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{}, + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: "Hello", + Logprob: -0.000020458236, + Bytes: []int64{72, 101, 108, 108, 111}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " World", + }, + Logprobs: &openai.ChatCompletionStreamChoiceLogprobs{ + Content: []openai.ChatCompletionTokenLogprob{ + { + Token: " World", + Logprob: -0.00055303273, + Bytes: []int64{32, 87, 111, 114, 108, 100}, + TopLogprobs: []openai.ChatCompletionTokenLogprobTopLogprob{}, + }, + }, + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.GPT4oMini20240718, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { wantCode := "429" wantMessage := "Requests to the Creates a completion for the chat message Operation under Azure OpenAI API " + @@ -386,6 +653,120 @@ func TestAzureCreateChatCompletionStreamRateLimitError(t *testing.T) { } } +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + //nolint:lll + data := `{"id":"1","object":"completion","created":1598069254,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"2","object":"completion","created":1598069255,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + //nolint:lll + data = `{"id":"3","object":"completion","created":1598069256,"model":"gpt-3.5-turbo","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &openai.StreamOptions{ + IncludeUsage: true, + }, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{ + { + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: openai.GPT3Dot5Turbo, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []openai.ChatCompletionStreamChoice{}, + Usage: &openai.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + checks.ErrorIs(t, streamErr, io.EOF, "stream.Recv() did not return EOF when the stream is finished") + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF when the stream is finished: %v", streamErr) + } +} + // Helper funcs. func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { if r1.ID != r2.ID || r1.Object != r2.Object || r1.Created != r2.Created || r1.Model != r2.Model { @@ -399,9 +780,235 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool { return false } } + if r1.Usage != nil || r2.Usage != nil { + if r1.Usage == nil || r2.Usage == nil { + return false + } + if r1.Usage.PromptTokens != r2.Usage.PromptTokens || r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || + r1.Usage.TotalTokens != r2.Usage.TotalTokens { + return false + } + } return true } +func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + dataBytes := []byte{} + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + //nolint:lll + dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...) + dataBytes = append(dataBytes, []byte("\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxCompletionTokens: 2000, + Model: openai.O3Mini20250131, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []openai.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Role: "assistant", + }, + }, + }, + }, + { + ID: "2", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: "Hello", + }, + }, + }, + }, + { + ID: "3", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " from", + }, + }, + }, + }, + { + ID: "4", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: " O3Mini", + }, + }, + }, + }, + { + ID: "5", + Object: "chat.completion.chunk", + Created: 1729585728, + Model: openai.O3Mini20250131, + SystemFingerprint: "fp_mini", + Choices: []openai.ChatCompletionStreamChoice{ + { + Index: 0, + Delta: openai.ChatCompletionStreamChoiceDelta{}, + FinishReason: "stop", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + checks.NoError(t, streamErr, "stream.Recv() failed") + if !compareChatResponses(expectedResponse, receivedResponse) { + t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } +} + +func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO3ReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O3, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O3, got: %v", err) + } +} + +func TestCreateChatCompletionStreamO4MiniReasoningValidatorFails(t *testing.T) { + client, _, _ := setupOpenAITestServer() + + stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 100, // This will trigger the validator to fail + Model: openai.O4Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + + if stream != nil { + t.Error("Expected nil stream when validation fails") + stream.Close() + } + + if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) { + t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated for O4Mini, got: %v", err) + } +} + func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool { if c1.Index != c2.Index { return false diff --git a/chat_test.go b/chat_test.go index 520bf5ca4..236cff736 100644 --- a/chat_test.go +++ b/chat_test.go @@ -52,6 +52,405 @@ func TestChatCompletionsWrongModel(t *testing.T) { checks.ErrorIs(t, err, openai.ErrChatCompletionInvalidModel, msg) } +func TestO1ModelsChatCompletionsDeprecatedFields(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "o1-preview_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Preview, + }, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, + }, + { + name: "o1-mini_MaxTokens_deprecated", + in: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.O1Mini, + }, + expectedError: openai.ErrReasoningModelMaxTokensDeprecated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O1Preview, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O1Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.O3Mini, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.O3Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + +func TestGPT5ModelsChatCompletionsBetaLimitations(t *testing.T) { + tests := []struct { + name string + in openai.ChatCompletionRequest + expectedError error + }{ + { + name: "log_probs_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + LogProbs: true, + Model: openai.GPT5, + }, + expectedError: openai.ErrReasoningModelLimitationsLogprobs, + }, + { + name: "set_temperature_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(2), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_top_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Nano, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_n_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5ChatLatest, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + Temperature: float32(1), + TopP: float32(1), + N: 2, + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_presence_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + PresencePenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + { + name: "set_frequency_penalty_unsupported", + in: openai.ChatCompletionRequest{ + MaxCompletionTokens: 1000, + Model: openai.GPT5Mini, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + }, + { + Role: openai.ChatMessageRoleAssistant, + }, + }, + FrequencyPenalty: float32(0.1), + }, + expectedError: openai.ErrReasoningModelLimitationsOther, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.CreateChatCompletion(ctx, tt.in) + checks.HasError(t, err) + msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err) + checks.ErrorIs(t, err, tt.expectedError, msg) + }) + } +} + func TestChatRequestOmitEmpty(t *testing.T) { data, err := json.Marshal(openai.ChatCompletionRequest{ // We set model b/c it's required, so omitempty doesn't make sense @@ -97,6 +496,58 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestO1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O1Preview, + MaxCompletionTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + +func TestO3ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: openai.O3Mini, + MaxCompletionTokens: 1000, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + +func TestDeepseekR1ModelChatCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleDeepseekR1ChatCompletionEndpoint) + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + Model: "deepseek-reasoner", + MaxCompletionTokens: 100, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. func TestChatCompletionsWithHeaders(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -277,6 +728,32 @@ func TestChatCompletionsFunctions(t *testing.T) { }) checks.NoError(t, err, "CreateChatCompletion with functions error") }) + t.Run("StructuredOutputs", func(t *testing.T) { + type testMessage struct { + Count int `json:"count"` + Words []string `json:"words"` + } + msg := testMessage{ + Count: 2, + Words: []string{"hello", "world"}, + } + _, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo0613, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Functions: []openai.FunctionDefinition{{ + Name: "test", + Strict: true, + Parameters: &msg, + }}, + }) + checks.NoError(t, err, "CreateChatCompletion with functions error") + }) } func TestAzureChatCompletions(t *testing.T) { @@ -394,7 +871,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) { t.Fatalf("Unexpected error") } res = strings.ReplaceAll(string(s), " ", "") - if res != `{"role":"user","content":""}` { + if res != `{"role":"user"}` { t.Fatalf("invalid message: %s", string(s)) } } @@ -482,6 +959,68 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, string(resBytes)) } +func handleDeepseekR1ChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq openai.ChatCompletionRequest + if completionReq, err = getChatCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := openai.ChatCompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + if completionReq.MaxCompletionTokens == 0 { + completionReq.MaxCompletionTokens = 1000 + } + for i := 0; i < n; i++ { + reasoningContent := "User says hello! And I need to reply" + completionStr := strings.Repeat("a", completionReq.MaxCompletionTokens-numTokens(reasoningContent)) + res.Choices = append(res.Choices, openai.ChatCompletionChoice{ + Message: openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + ReasoningContent: reasoningContent, + Content: completionStr, + }, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Messages[0].Content) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = openai.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + fmt.Fprintln(w, string(resBytes)) +} + // getChatCompletionBody Returns the body of the request to create a completion. func getChatCompletionBody(r *http.Request) (openai.ChatCompletionRequest, error) { completion := openai.ChatCompletionRequest{} @@ -527,3 +1066,142 @@ func TestFinishReason(t *testing.T) { } } } + +func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": null + }`), + }, + false, + }, + { + "", + args{ + data: []byte(`[123,456]`), + }, + true, + }, + { + "", + args{ + data: []byte(`{ + "name": "math_response", + "strict": true, + "schema": 123456 + }`), + }, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var r openai.ChatCompletionResponseFormatJSONSchema + err := r.UnmarshalJSON(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) { + type args struct { + bs []byte + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + "", + args{bs: []byte(`{ + "model": "llama3-1b", + "messages": [ + { "role": "system", "content": "You are a helpful math tutor." }, + { "role": "user", "content": "solve 8x + 31 = 2" } + ], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "math_response", + "strict": true, + "schema": { + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation","output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps","final_answer"], + "additionalProperties": false + } + } + } +}`)}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var m openai.ChatCompletionRequest + err := json.Unmarshal(tt.args.bs, &m) + if err != nil { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/client.go b/client.go index 9a1c8958d..413b8db03 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" utils "github.com/sashabaranov/go-openai/internal" @@ -83,15 +84,29 @@ func withBody(body any) requestOption { } } +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + // Assert that args.body is a map[string]any. + bodyMap, ok := args.body.(map[string]any) + if ok { + // If it's a map[string]any then only add extraBody + // fields to args.body otherwise keep only fields in request struct. + for key, value := range extraBody { + bodyMap[key] = value + } + } + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) } } -func withBetaAssistantV1() requestOption { +func withBetaAssistantVersion(version string) requestOption { return func(args *requestOptions) { - args.header.Set("OpenAI-Beta", "assistants=v1") + args.header.Set("OpenAI-Beta", fmt.Sprintf("assistants=%s", version)) } } @@ -129,14 +144,14 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { defer res.Body.Close() - if isFailureStatusCode(res) { - return c.handleErrorResp(res) - } - if v != nil { v.SetHeader(res.Header) } + if isFailureStatusCode(res) { + return c.handleErrorResp(res) + } + return decodeResponse(res.Body, v) } @@ -181,13 +196,21 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream func (c *Client) setCommonHeaders(req *http.Request) { // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure { + switch c.config.APIType { + case APITypeAzure, APITypeCloudflareAzure: + // Azure API Key authentication req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else if c.config.authToken != "" { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + case APITypeAnthropic: + // https://docs.anthropic.com/en/api/versioning + req.Header.Set("anthropic-version", c.config.APIVersion) + case APITypeOpenAI, APITypeAzureAD: + fallthrough + default: + if c.config.authToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } } + if c.config.OrgID != "" { req.Header.Set("OpenAI-Organization", c.config.OrgID) } @@ -221,42 +244,81 @@ func decodeString(body io.Reader, output *string) error { return nil } +type fullURLOptions struct { + model string +} + +type fullURLOption func(*fullURLOptions) + +func withModel(model string) fullURLOption { + return func(args *fullURLOptions) { + args.model = model + } +} + +var azureDeploymentsEndpoints = []string{ + "/completions", + "/embeddings", + "/chat/completions", + "/audio/transcriptions", + "/audio/translations", + "/audio/speech", + "/images/generations", +} + // fullURL returns full URL for request. -// args[0] is model name, if API type is Azure, model name is required to get deployment name. -func (c *Client) fullURL(suffix string, args ...any) string { - // /openai/deployments/{model}/chat/completions?api-version={api_version} +func (c *Client) fullURL(suffix string, setters ...fullURLOption) string { + baseURL := strings.TrimRight(c.config.BaseURL, "/") + args := fullURLOptions{} + for _, setter := range setters { + setter(&args) + } + if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD { - baseURL := c.config.BaseURL - baseURL = strings.TrimRight(baseURL, "/") - // if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01 - // https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP - if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) { - return fmt.Sprintf("%s/%s%s?api-version=%s", baseURL, azureAPIPrefix, suffix, c.config.APIVersion) - } - azureDeploymentName := "UNKNOWN" - if len(args) > 0 { - model, ok := args[0].(string) - if ok { - azureDeploymentName = c.config.GetAzureDeploymentByModel(model) - } - } - return fmt.Sprintf("%s/%s/%s/%s%s?api-version=%s", - baseURL, azureAPIPrefix, azureDeploymentsPrefix, - azureDeploymentName, suffix, c.config.APIVersion, - ) + baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model) } - // c.config.APIType == APITypeOpenAI || c.config.APIType == "" - return fmt.Sprintf("%s%s", c.config.BaseURL, suffix) + if c.config.APIVersion != "" { + suffix = c.suffixWithAPIVersion(suffix) + } + return fmt.Sprintf("%s%s", baseURL, suffix) +} + +func (c *Client) suffixWithAPIVersion(suffix string) string { + parsedSuffix, err := url.Parse(suffix) + if err != nil { + panic("failed to parse url suffix") + } + query := parsedSuffix.Query() + query.Add("api-version", c.config.APIVersion) + return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode()) +} + +func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) { + baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix) + if containsSubstr(azureDeploymentsEndpoints, suffix) { + azureDeploymentName := c.config.GetAzureDeploymentByModel(model) + if azureDeploymentName == "" { + azureDeploymentName = "UNKNOWN" + } + baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName) + } + return baseURL } func (c *Client) handleErrorResp(resp *http.Response) error { + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, reading response body: %w", err) + } var errRes ErrorResponse - err := json.NewDecoder(resp.Body).Decode(&errRes) + err = json.Unmarshal(body, &errRes) if err != nil || errRes.Error == nil { reqErr := &RequestError{ + HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: err, + Body: body, } if errRes.Error != nil { reqErr.Err = errRes.Error @@ -264,6 +326,7 @@ func (c *Client) handleErrorResp(resp *http.Response) error { return reqErr } + errRes.Error.HTTPStatus = resp.Status errRes.Error.HTTPStatusCode = resp.StatusCode return errRes.Error } diff --git a/client_test.go b/client_test.go index bc5133edc..321971445 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,21 @@ func TestClient(t *testing.T) { } } +func TestSetCommonHeadersAnthropic(t *testing.T) { + config := DefaultAnthropicConfig("mock-token", "") + client := NewClientWithConfig(config) + req, err := http.NewRequest("GET", "http://example.com/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + client.setCommonHeaders(req) + + if got := req.Header.Get("anthropic-version"); got != AnthropicAPIVersion { + t.Errorf("Expected anthropic-version header to be %q, got %q", AnthropicAPIVersion, got) + } +} + func TestDecodeResponse(t *testing.T) { stringInput := "" @@ -134,14 +149,17 @@ func TestHandleErrorResp(t *testing.T) { client := NewClient(mockToken) testCases := []struct { - name string - httpCode int - body io.Reader - expected string + name string + httpCode int + httpStatus string + contentType string + body io.Reader + expected string }{ { - name: "401 Invalid Authentication", - httpCode: http.StatusUnauthorized, + name: "401 Invalid Authentication", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -152,11 +170,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: You didn't provide an API key. ....", + expected: "error, status code: 401, status: , message: You didn't provide an API key. ....", }, { - name: "401 Azure Access Denied", - httpCode: http.StatusUnauthorized, + name: "401 Azure Access Denied", + httpCode: http.StatusUnauthorized, + contentType: "application/json", body: bytes.NewReader([]byte( `{ "error":{ @@ -165,11 +184,12 @@ func TestHandleErrorResp(t *testing.T) { } }`, )), - expected: "error, status code: 401, message: Access denied due to Virtual Network/Firewall rules.", + expected: "error, status code: 401, status: , message: Access denied due to Virtual Network/Firewall rules.", }, { - name: "503 Model Overloaded", - httpCode: http.StatusServiceUnavailable, + name: "503 Model Overloaded", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{ @@ -179,22 +199,58 @@ func TestHandleErrorResp(t *testing.T) { "code":null } }`)), - expected: "error, status code: 503, message: That model...", + expected: "error, status code: 503, status: , message: That model...", }, { - name: "503 no message (Unknown response)", - httpCode: http.StatusServiceUnavailable, + name: "503 no message (Unknown response)", + httpCode: http.StatusServiceUnavailable, + contentType: "application/json", body: bytes.NewReader([]byte(` { "error":{} }`)), - expected: "error, status code: 503, message: ", + expected: `error, status code: 503, status: , message: , body: + { + "error":{} + }`, + }, + { + name: "413 Request Entity Too Large", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: bytes.NewReader([]byte(` + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `)), + expected: `error, status code: 413, status: , message: invalid character '<' looking for beginning of value, body: + + 413 Request Entity Too Large + +

413 Request Entity Too Large

+
nginx
+ + `, + }, + { + name: "errorReader", + httpCode: http.StatusRequestEntityTooLarge, + contentType: "text/html", + body: &errorReader{err: errors.New("errorReader")}, + expected: "error, reading response body: errorReader", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - testCase := &http.Response{} + testCase := &http.Response{ + Header: map[string][]string{ + "Content-Type": {tc.contentType}, + }, + } testCase.StatusCode = tc.httpCode testCase.Body = io.NopCloser(tc.body) err := client.handleErrorResp(testCase) @@ -203,12 +259,6 @@ func TestHandleErrorResp(t *testing.T) { t.Errorf("Unexpected error: %v , expected: %s", err, tc.expected) t.Fail() } - - e := &APIError{} - if !errors.As(err, &e) { - t.Errorf("(%s) Expected error to be of type APIError", tc.name) - t.Fail() - } }) } } @@ -340,7 +390,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { return client.CreateMessage(ctx, "", MessageRequest{}) }}, {"ListMessage", func() (any, error) { - return client.ListMessage(ctx, "", nil, nil, nil, nil) + return client.ListMessage(ctx, "", nil, nil, nil, nil, nil) }}, {"RetrieveMessage", func() (any, error) { return client.RetrieveMessage(ctx, "", "") @@ -348,6 +398,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"ModifyMessage", func() (any, error) { return client.ModifyMessage(ctx, "", "", nil) }}, + {"DeleteMessage", func() (any, error) { + return client.DeleteMessage(ctx, "", "") + }}, {"RetrieveMessageFile", func() (any, error) { return client.RetrieveMessageFile(ctx, "", "", "") }}, @@ -396,6 +449,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { @@ -406,7 +470,7 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { } } -func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { +func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) { config := DefaultConfig(test.GetTestToken()) client := NewClientWithConfig(config) client.requestBuilder = &failingRequestBuilder{} @@ -420,3 +484,105 @@ func TestClientReturnsRequestBuilderErrorsAddtion(t *testing.T) { t.Fatalf("Did not return error when request builder failed: %v", err) } } + +func TestClient_suffixWithAPIVersion(t *testing.T) { + type fields struct { + apiVersion string + } + type args struct { + suffix string + } + tests := []struct { + name string + fields fields + args args + want string + wantPanic string + }{ + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants"}, + "/assistants?api-version=2023-05", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "/assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "", + }, + { + "", + fields{apiVersion: "2023-05"}, + args{suffix: "123:assistants?limit=5"}, + "/assistants?api-version=2023-05&limit=5", + "failed to parse url suffix", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + config: ClientConfig{APIVersion: tt.fields.apiVersion}, + } + defer func() { + if r := recover(); r != nil { + // Check if the panic message matches the expected panic message + if rStr, ok := r.(string); ok { + if rStr != tt.wantPanic { + t.Errorf("suffixWithAPIVersion() = %v, want %v", rStr, tt.wantPanic) + } + } else { + // If the panic is not a string, log it + t.Errorf("suffixWithAPIVersion() panicked with non-string value: %v", r) + } + } + }() + if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want { + t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_baseURLWithAzureDeployment(t *testing.T) { + type args struct { + baseURL string + suffix string + model string + } + tests := []struct { + name string + args args + wantNewBaseURL string + }{ + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini}, + "https://test.openai.azure.com/openai/deployments/gpt-4o-mini", + }, + { + "", + args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""}, + "https://test.openai.azure.com/openai/deployments/UNKNOWN", + }, + } + client := NewClient("") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotNewBaseURL := client.baseURLWithAzureDeployment( + tt.args.baseURL, + tt.args.suffix, + tt.args.model, + ); gotNewBaseURL != tt.wantNewBaseURL { + t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL) + } + }) + } +} diff --git a/common.go b/common.go index cbfda4e3c..d1936d656 100644 --- a/common.go +++ b/common.go @@ -4,7 +4,23 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details"` + CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details"` +} + +// CompletionTokensDetails Breakdown of tokens used in a completion. +type CompletionTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +// PromptTokensDetails Breakdown of tokens used in the prompt. +type PromptTokensDetails struct { + AudioTokens int `json:"audio_tokens"` + CachedTokens int `json:"cached_tokens"` } diff --git a/completion.go b/completion.go index ab1dbd6c5..27d69f587 100644 --- a/completion.go +++ b/completion.go @@ -2,63 +2,92 @@ package openai import ( "context" - "errors" "net/http" ) -var ( - ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll - ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll - ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll -) - // GPT3 Defines the models provided by OpenAI to use when generating // completions from OpenAI. // GPT3 Models are designed for text-based tasks. For code-specific // tasks, please refer to the Codex series of models. const ( - GPT432K0613 = "gpt-4-32k-0613" - GPT432K0314 = "gpt-4-32k-0314" - GPT432K = "gpt-4-32k" - GPT40613 = "gpt-4-0613" - GPT40314 = "gpt-4-0314" - GPT4Turbo0125 = "gpt-4-0125-preview" - GPT4Turbo1106 = "gpt-4-1106-preview" - GPT4TurboPreview = "gpt-4-turbo-preview" - GPT4VisionPreview = "gpt-4-vision-preview" - GPT4 = "gpt-4" - GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" - GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" - GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" - GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" - GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" - GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" - GPT3Dot5Turbo = "gpt-3.5-turbo" - GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + O1Mini = "o1-mini" + O1Mini20240912 = "o1-mini-2024-09-12" + O1Preview = "o1-preview" + O1Preview20240912 = "o1-preview-2024-09-12" + O1 = "o1" + O120241217 = "o1-2024-12-17" + O3 = "o3" + O320250416 = "o3-2025-04-16" + O3Mini = "o3-mini" + O3Mini20250131 = "o3-mini-2025-01-31" + O4Mini = "o4-mini" + O4Mini20250416 = "o4-mini-2025-04-16" + GPT432K0613 = "gpt-4-32k-0613" + GPT432K0314 = "gpt-4-32k-0314" + GPT432K = "gpt-4-32k" + GPT40613 = "gpt-4-0613" + GPT40314 = "gpt-4-0314" + GPT4o = "gpt-4o" + GPT4o20240513 = "gpt-4o-2024-05-13" + GPT4o20240806 = "gpt-4o-2024-08-06" + GPT4o20241120 = "gpt-4o-2024-11-20" + GPT4oLatest = "chatgpt-4o-latest" + GPT4oMini = "gpt-4o-mini" + GPT4oMini20240718 = "gpt-4o-mini-2024-07-18" + GPT4Turbo = "gpt-4-turbo" + GPT4Turbo20240409 = "gpt-4-turbo-2024-04-09" + GPT4Turbo0125 = "gpt-4-0125-preview" + GPT4Turbo1106 = "gpt-4-1106-preview" + GPT4TurboPreview = "gpt-4-turbo-preview" + GPT4VisionPreview = "gpt-4-vision-preview" + GPT4 = "gpt-4" + GPT4Dot1 = "gpt-4.1" + GPT4Dot120250414 = "gpt-4.1-2025-04-14" + GPT4Dot1Mini = "gpt-4.1-mini" + GPT4Dot1Mini20250414 = "gpt-4.1-mini-2025-04-14" + GPT4Dot1Nano = "gpt-4.1-nano" + GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14" + GPT4Dot5Preview = "gpt-4.5-preview" + GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27" + GPT5 = "gpt-5" + GPT5Mini = "gpt-5-mini" + GPT5Nano = "gpt-5-nano" + GPT5ChatLatest = "gpt-5-chat-latest" + GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125" + GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106" + GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613" + GPT3Dot5Turbo0301 = "gpt-3.5-turbo-0301" + GPT3Dot5Turbo16K = "gpt-3.5-turbo-16k" + GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" + GPT3Dot5Turbo = "gpt-3.5-turbo" + GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci002 = "text-davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextCurie001 = "text-curie-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextBabbage001 = "text-babbage-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextAda001 = "text-ada-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci001 = "text-davinci-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3Davinci002 = "davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3CurieInstructBeta = "curie-instruct-beta" GPT3Curie = "curie" GPT3Curie002 = "curie-002" - GPT3Ada = "ada" - GPT3Ada002 = "ada-002" - GPT3Babbage = "babbage" - GPT3Babbage002 = "babbage-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. @@ -72,23 +101,55 @@ const ( var disabledModelsForEndpoints = map[string]map[string]bool{ "/completions": { - GPT3Dot5Turbo: true, - GPT3Dot5Turbo0301: true, - GPT3Dot5Turbo0613: true, - GPT3Dot5Turbo1106: true, - GPT3Dot5Turbo0125: true, - GPT3Dot5Turbo16K: true, - GPT3Dot5Turbo16K0613: true, - GPT4: true, - GPT4TurboPreview: true, - GPT4VisionPreview: true, - GPT4Turbo1106: true, - GPT4Turbo0125: true, - GPT40314: true, - GPT40613: true, - GPT432K: true, - GPT432K0314: true, - GPT432K0613: true, + O1Mini: true, + O1Mini20240912: true, + O1Preview: true, + O1Preview20240912: true, + O3Mini: true, + O3Mini20250131: true, + O4Mini: true, + O4Mini20250416: true, + O3: true, + O320250416: true, + GPT3Dot5Turbo: true, + GPT3Dot5Turbo0301: true, + GPT3Dot5Turbo0613: true, + GPT3Dot5Turbo1106: true, + GPT3Dot5Turbo0125: true, + GPT3Dot5Turbo16K: true, + GPT3Dot5Turbo16K0613: true, + GPT4: true, + GPT4Dot5Preview: true, + GPT4Dot5Preview20250227: true, + GPT4o: true, + GPT4o20240513: true, + GPT4o20240806: true, + GPT4o20241120: true, + GPT4oLatest: true, + GPT4oMini: true, + GPT4oMini20240718: true, + GPT4TurboPreview: true, + GPT4VisionPreview: true, + GPT4Turbo1106: true, + GPT4Turbo0125: true, + GPT4Turbo: true, + GPT4Turbo20240409: true, + GPT40314: true, + GPT40613: true, + GPT432K: true, + GPT432K0314: true, + GPT432K0613: true, + O1: true, + GPT4Dot1: true, + GPT4Dot120250414: true, + GPT4Dot1Mini: true, + GPT4Dot1Mini20250414: true, + GPT4Dot1Nano: true, + GPT4Dot1Nano20250414: true, + GPT5: true, + GPT5Mini: true, + GPT5Nano: true, + GPT5ChatLatest: true, }, chatCompletionsSuffix: { CodexCodeDavinci002: true, @@ -116,30 +177,54 @@ func checkEndpointSupportsModel(endpoint, model string) bool { func checkPromptType(prompt any) bool { _, isString := prompt.(string) _, isStringSlice := prompt.([]string) - return isString || isStringSlice + if isString || isStringSlice { + return true + } + + // check if it is prompt is []string hidden under []any + slice, isSlice := prompt.([]any) + if !isSlice { + return false + } + + for _, item := range slice { + _, itemIsString := item.(string) + if !itemIsString { + return false + } + } + return true // all items in the slice are string, so it is []string } // CompletionRequest represents a request structure for completion API. type CompletionRequest struct { - Model string `json:"model"` - Prompt any `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` + Model string `json:"model"` + Prompt any `json:"prompt,omitempty"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` // LogitBias is must be a token id string (specified by their token ID in the tokenizer), not a word string. // incorrect: `"logit_bias":{"You": 6}`, correct: `"logit_bias":{"1639": 6}` // refs: https://platform.openai.com/docs/api-reference/completions/create#completions/create-logit_bias LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + // Store can be set to true to store the output of this completion request for use in distillations and evals. + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-store + Store bool `json:"store,omitempty"` + // Metadata to store with the completion. + Metadata map[string]string `json:"metadata,omitempty"` + LogProbs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Seed *int `json:"seed,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` + // Options for streaming response. Only set this when you set stream: true. + StreamOptions *StreamOptions `json:"stream_options,omitempty"` } // CompletionChoice represents one of possible completions. @@ -165,7 +250,7 @@ type CompletionResponse struct { Created int64 `json:"created"` Model string `json:"model"` Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + Usage *Usage `json:"usage,omitempty"` httpHeader } @@ -195,7 +280,12 @@ func (c *Client) CreateCompletion( return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } diff --git a/completion_test.go b/completion_test.go index 89950bf94..abfc3007e 100644 --- a/completion_test.go +++ b/completion_test.go @@ -33,6 +33,42 @@ func TestCompletionsWrongModel(t *testing.T) { } } +// TestCompletionsWrongModelO3 Tests the completions endpoint with O3 model which is not supported. +func TestCompletionsWrongModelO3(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O3, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O3, but returned: %v", err) + } +} + +// TestCompletionsWrongModelO4Mini Tests the completions endpoint with O4Mini model which is not supported. +func TestCompletionsWrongModelO4Mini(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O4Mini, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O4Mini, but returned: %v", err) + } +} + func TestCompletionWithStream(t *testing.T) { config := openai.DefaultConfig("whatever") client := openai.NewClientWithConfig(config) @@ -59,6 +95,38 @@ func TestCompletions(t *testing.T) { checks.NoError(t, err, "CreateCompletion error") } +// TestMultiplePromptsCompletionsWrong Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts with wrong type. +func TestMultiplePromptsCompletionsWrong(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", 9}, + } + _, err := client.CreateCompletion(context.Background(), req) + if !errors.Is(err, openai.ErrCompletionRequestPromptTypeNotSupported) { + t.Fatalf("CreateCompletion should return ErrCompletionRequestPromptTypeNotSupported, but returned: %v", err) + } +} + +// TestMultiplePromptsCompletions Tests the completions endpoint of the API using the mocked server +// where the completions requests has a list of prompts. +func TestMultiplePromptsCompletions(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := openai.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: []interface{}{"Lorem ipsum", "Lorem ipsum"}, + } + _, err := client.CreateCompletion(context.Background(), req) + checks.NoError(t, err, "CreateCompletion error") +} + // handleCompletionEndpoint Handles the completion endpoint by the test server. func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error @@ -87,24 +155,50 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if n == 0 { n = 1 } + // Handle different types of prompts: single string or list of strings + prompts := []string{} + switch v := completionReq.Prompt.(type) { + case string: + prompts = append(prompts, v) + case []interface{}: + for _, item := range v { + if str, ok := item.(string); ok { + prompts = append(prompts, str) + } + } + default: + http.Error(w, "Invalid prompt type", http.StatusBadRequest) + return + } + for i := 0; i < n; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt.(string) + completionStr + for _, prompt := range prompts { + // Generate a random string of length completionReq.MaxTokens + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = prompt + completionStr + } + + res.Choices = append(res.Choices, openai.CompletionChoice{ + Text: completionStr, + Index: len(res.Choices), + }) } - res.Choices = append(res.Choices, openai.CompletionChoice{ - Text: completionStr, - Index: i, - }) } - inputTokens := numTokens(completionReq.Prompt.(string)) * n - completionTokens := completionReq.MaxTokens * n - res.Usage = openai.Usage{ + + inputTokens := 0 + for _, prompt := range prompts { + inputTokens += numTokens(prompt) + } + inputTokens *= n + completionTokens := completionReq.MaxTokens * len(prompts) * n + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, } + + // Serialize the response and send it back resBytes, _ = json.Marshal(res) fmt.Fprintln(w, string(resBytes)) } @@ -123,3 +217,115 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { } return completion, nil } + +// TestCompletionWithO1Model Tests that O1 model is not supported for completion endpoint. +func TestCompletionWithO1Model(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: openai.O1, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for O1 model, but returned: %v", err) + } +} + +// TestCompletionWithGPT4DotModels Tests that newer GPT4 models are not supported for completion endpoint. +func TestCompletionWithGPT4DotModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4Dot1, + openai.GPT4Dot120250414, + openai.GPT4Dot1Mini, + openai.GPT4Dot1Mini20250414, + openai.GPT4Dot1Nano, + openai.GPT4Dot1Nano20250414, + openai.GPT4Dot5Preview, + openai.GPT4Dot5Preview20250227, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT4oModels Tests that GPT4o models are not supported for completion endpoint. +func TestCompletionWithGPT4oModels(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT4o, + openai.GPT4o20240513, + openai.GPT4o20240806, + openai.GPT4o20241120, + openai.GPT4oLatest, + openai.GPT4oMini, + openai.GPT4oMini20240718, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} + +// TestCompletionWithGPT5Models Tests that GPT5 models are not supported for completion endpoint. +func TestCompletionWithGPT5Models(t *testing.T) { + config := openai.DefaultConfig("whatever") + config.BaseURL = "http://localhost/v1" + client := openai.NewClientWithConfig(config) + + models := []string{ + openai.GPT5, + openai.GPT5Mini, + openai.GPT5Nano, + openai.GPT5ChatLatest, + } + + for _, model := range models { + t.Run(model, func(t *testing.T) { + _, err := client.CreateCompletion( + context.Background(), + openai.CompletionRequest{ + MaxTokens: 5, + Model: model, + }, + ) + if !errors.Is(err, openai.ErrCompletionUnsupportedModel) { + t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err) + } + }) + } +} diff --git a/config.go b/config.go index c58b71ec6..4b8cfb6fb 100644 --- a/config.go +++ b/config.go @@ -3,6 +3,7 @@ package openai import ( "net/http" "regexp" + "strings" ) const ( @@ -11,18 +12,28 @@ const ( azureAPIPrefix = "openai" azureDeploymentsPrefix = "deployments" + + AnthropicAPIVersion = "2023-06-01" ) type APIType string const ( - APITypeOpenAI APIType = "OPEN_AI" - APITypeAzure APIType = "AZURE" - APITypeAzureAD APIType = "AZURE_AD" + APITypeOpenAI APIType = "OPEN_AI" + APITypeAzure APIType = "AZURE" + APITypeAzureAD APIType = "AZURE_AD" + APITypeCloudflareAzure APIType = "CLOUDFLARE_AZURE" + APITypeAnthropic APIType = "ANTHROPIC" ) const AzureAPIKeyHeader = "api-key" +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store + +type HTTPDoer interface { + Do(req *http.Request) (*http.Response, error) +} + // ClientConfig is a configuration of a client. type ClientConfig struct { authToken string @@ -30,19 +41,21 @@ type ClientConfig struct { BaseURL string OrgID string APIType APIType - APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD + APIVersion string // required when APIType is APITypeAzure or APITypeAzureAD or APITypeAnthropic + AssistantVersion string AzureModelMapperFunc func(model string) string // replace model to azure deployment name func - HTTPClient *http.Client + HTTPClient HTTPDoer EmptyMessagesLimit uint } func DefaultConfig(authToken string) ClientConfig { return ClientConfig{ - authToken: authToken, - BaseURL: openaiAPIURLv1, - APIType: APITypeOpenAI, - OrgID: "", + authToken: authToken, + BaseURL: openaiAPIURLv1, + APIType: APITypeOpenAI, + AssistantVersion: defaultAssistantVersion, + OrgID: "", HTTPClient: &http.Client{}, @@ -58,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { APIType: APITypeAzure, APIVersion: "2023-05-15", AzureModelMapperFunc: func(model string) string { - return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + // only 3.5 models have the "." stripped in their names + if strings.Contains(model, "3.5") { + return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "") + } + return strings.ReplaceAll(model, ":", "") }, HTTPClient: &http.Client{}, @@ -67,6 +84,23 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig { } } +func DefaultAnthropicConfig(apiKey, baseURL string) ClientConfig { + if baseURL == "" { + baseURL = "https://api.anthropic.com/v1" + } + return ClientConfig{ + authToken: apiKey, + BaseURL: baseURL, + OrgID: "", + APIType: APITypeAnthropic, + APIVersion: AnthropicAPIVersion, + + HTTPClient: &http.Client{}, + + EmptyMessagesLimit: defaultEmptyMessagesLimit, + } +} + func (ClientConfig) String() string { return "" } diff --git a/config_test.go b/config_test.go index 3e528c3e9..f44b80825 100644 --- a/config_test.go +++ b/config_test.go @@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) { Model: "gpt-3.5-turbo-0301", Expect: "gpt-35-turbo-0301", }, + { + Model: "gpt-4.1", + Expect: "gpt-4.1", + }, { Model: "text-embedding-ada-002", Expect: "text-embedding-ada-002", @@ -60,3 +64,64 @@ func TestGetAzureDeploymentByModel(t *testing.T) { }) } } + +func TestDefaultAnthropicConfig(t *testing.T) { + apiKey := "test-key" + baseURL := "https://api.anthropic.com/v1" + + config := openai.DefaultAnthropicConfig(apiKey, baseURL) + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be 2023-06-01, got %v", config.APIVersion) + } + + if config.BaseURL != baseURL { + t.Errorf("Expected BaseURL to be %v, got %v", baseURL, config.BaseURL) + } + + if config.EmptyMessagesLimit != 300 { + t.Errorf("Expected EmptyMessagesLimit to be 300, got %v", config.EmptyMessagesLimit) + } +} + +func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) { + config := openai.DefaultAnthropicConfig("", "") + + if config.APIType != openai.APITypeAnthropic { + t.Errorf("Expected APIType to be %v, got %v", openai.APITypeAnthropic, config.APIType) + } + + if config.APIVersion != openai.AnthropicAPIVersion { + t.Errorf("Expected APIVersion to be %s, got %v", openai.AnthropicAPIVersion, config.APIVersion) + } + + expectedBaseURL := "https://api.anthropic.com/v1" + if config.BaseURL != expectedBaseURL { + t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL) + } +} + +func TestClientConfigString(t *testing.T) { + // String() should always return the constant value + conf := openai.DefaultConfig("dummy-token") + expected := "" + got := conf.String() + if got != expected { + t.Errorf("ClientConfig.String() = %q; want %q", got, expected) + } +} + +func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) { + // On a zero-value or DefaultConfig, AzureModelMapperFunc is nil, + // so GetAzureDeploymentByModel should just return the input model. + conf := openai.DefaultConfig("dummy-token") + model := "some-model" + got := conf.GetAzureDeploymentByModel(model) + if got != model { + t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model) + } +} diff --git a/edits.go b/edits.go index 97d026029..fe8ecd0c1 100644 --- a/edits.go +++ b/edits.go @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024. You can use CreateChatCompletion or CreateChatCompletionStream instead. */ func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/edits", withModel(fmt.Sprint(request.Model))), + withBody(request), + ) if err != nil { return } diff --git a/embeddings.go b/embeddings.go index c5633a313..8593f8b5b 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/binary" + "encoding/json" "errors" "math" "net/http" @@ -16,7 +17,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" @@ -155,11 +156,14 @@ const ( type EmbeddingRequest struct { Input any `json:"input"` Model EmbeddingModel `json:"model"` - User string `json:"user"` + User string `json:"user,omitempty"` EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { @@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct { // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. Dimensions int `json:"dimensions,omitempty"` + // The ExtraBody field allows for the inclusion of arbitrary key-value pairs + // in the request body that may not be explicitly defined in this struct. + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { @@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest { User: r.User, EncodingFormat: r.EncodingFormat, Dimensions: r.Dimensions, + ExtraBody: r.ExtraBody, } } @@ -241,7 +253,30 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq)) + + // The body map is used to dynamically construct the request payload for the embedding API. + // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields + // based on their presence, avoiding unnecessary or empty fields in the request. + extraBody := baseReq.ExtraBody + baseReq.ExtraBody = nil + + // Serialize baseReq to JSON + jsonData, err := json.Marshal(baseReq) + if err != nil { + return + } + + // Deserialize JSON to map[string]any + var body map[string]any + _ = json.Unmarshal(jsonData, &body) + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/embeddings", withModel(string(baseReq.Model))), + withBody(body), // Main request body. + withExtraBody(extraBody), // Merge ExtraBody fields. + ) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index 438978169..07f1262cb 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) { t.Fatalf("Expected embedding request to contain model field") } + // test embedding request with strings and extra_body param + embeddingReqWithExtraBody := openai.EmbeddingRequest{ + Input: []string{ + "The food was delicious and the waiter", + "Other examples of embedding request", + }, + Model: model, + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + } + marshaled, err = json.Marshal(embeddingReqWithExtraBody) + checks.NoError(t, err, "Could not marshal embedding request") + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { + t.Fatalf("Expected embedding request to contain model field") + } + // test embedding request with strings embeddingReqStrings := openai.EmbeddingRequestStrings{ Input: []string{ @@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) { t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) } - // test create embeddings with strings (simple embedding request) + // test create embeddings with strings (ExtraBody in request) + res, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + ExtraBody: map[string]any{ + "input_type": "query", + "truncate": "NONE", + }, + Dimensions: 1, + }, + ) + checks.NoError(t, err, "CreateEmbeddings error") + if !reflect.DeepEqual(res.Data, sampleEmbeddings) { + t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data) + } + + // test create embeddings with strings (ExtraBody in request and ) + _, err = client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Input: make(chan int), // Channels are not serializable + Model: "example_model", + }, + ) + checks.HasError(t, err, "CreateEmbeddings error") + + // test failed (Serialize JSON error) res, err = client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ diff --git a/error.go b/error.go index b2d01e22e..8a74bd52c 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ type APIError struct { Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` + HTTPStatus string `json:"-"` HTTPStatusCode int `json:"-"` InnerError *InnerError `json:"innererror,omitempty"` } @@ -23,10 +24,12 @@ type InnerError struct { ContentFilterResults ContentFilterResults `json:"content_filter_result,omitempty"` } -// RequestError provides informations about generic request errors. +// RequestError provides information about generic request errors. type RequestError struct { + HTTPStatus string HTTPStatusCode int Err error + Body []byte } type ErrorResponse struct { @@ -35,7 +38,7 @@ type ErrorResponse struct { func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Message) + return fmt.Sprintf("error, status code: %d, status: %s, message: %s", e.HTTPStatusCode, e.HTTPStatus, e.Message) } return e.Message @@ -101,7 +104,10 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) { } func (e *RequestError) Error() string { - return fmt.Sprintf("error, status code: %d, message: %s", e.HTTPStatusCode, e.Err) + return fmt.Sprintf( + "error, status code: %d, status: %s, message: %s, body: %s", + e.HTTPStatusCode, e.HTTPStatus, e.Err, e.Body, + ) } func (e *RequestError) Unwrap() error { diff --git a/example_test.go b/example_test.go index de67c57cd..5910ffb84 100644 --- a/example_test.go +++ b/example_test.go @@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() { } defer stream.Close() - fmt.Printf("Stream response: ") + fmt.Print("Stream response: ") for { var response openai.ChatCompletionStreamResponse response, err = stream.Recv() @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() { return } - fmt.Printf(response.Choices[0].Delta.Content) + fmt.Println(response.Choices[0].Delta.Content) } } @@ -82,7 +82,7 @@ func ExampleClient_CreateCompletion() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, @@ -99,7 +99,7 @@ func ExampleClient_CreateCompletionStream() { stream, err := client.CreateCompletionStream( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", Stream: true, diff --git a/examples/completion/main.go b/examples/completion/main.go index 22af1fd82..8c5cbd5ca 100644 --- a/examples/completion/main.go +++ b/examples/completion/main.go @@ -13,7 +13,7 @@ func main() { resp, err := client.CreateCompletion( context.Background(), openai.CompletionRequest{ - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, MaxTokens: 5, Prompt: "Lorem ipsum", }, diff --git a/files.go b/files.go index b40a44f15..edc9f2a20 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. @@ -101,6 +102,7 @@ func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File if err != nil { return } + defer fileData.Close() err = builder.CreateFormFile("file", fileData) if err != nil { diff --git a/files_api_test.go b/files_api_test.go index c92162a84..aa4fda458 100644 --- a/files_api_test.go +++ b/files_api_test.go @@ -152,6 +152,7 @@ func TestGetFileContentReturnError(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, wantErrorResp) }) diff --git a/files_test.go b/files_test.go index 3c1b99fb4..486ef892e 100644 --- a/files_test.go +++ b/files_test.go @@ -121,3 +121,20 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } +func TestCreateFileRequestBuilderFailure(t *testing.T) { + config := DefaultConfig("") + config.BaseURL = "" + client := NewClientWithConfig(config) + client.requestBuilder = &failingRequestBuilder{} + + client.createFormBuilder = func(io.Writer) utils.FormBuilder { + return &mockFormBuilder{ + mockWriteField: func(string, string) error { return nil }, + mockCreateFormFile: func(string, *os.File) error { return nil }, + mockClose: func() error { return nil }, + } + } + + _, err := client.CreateFile(context.Background(), FileRequest{FilePath: "client.go"}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateFile should return error if request builder fails") +} diff --git a/fine_tunes.go b/fine_tunes.go index ca840781c..74b47bf3f 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -115,7 +115,7 @@ func (c *Client) CreateFineTune(ctx context.Context, request FineTuneRequest) (r // This API will be officially deprecated on January 4th, 2024. // OpenAI recommends to migrate to the new fine tuning API implemented in fine_tuning_job.go. func (c *Client) CancelFineTune(ctx context.Context, fineTuneID string) (response FineTune, err error) { - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/fine-tunes/"+fineTuneID+"/cancel")) //nolint:lll //this method is deprecated if err != nil { return } diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 9dcb49de1..5a9f54a92 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -26,7 +26,9 @@ type FineTuningJob struct { } type Hyperparameters struct { - Epochs any `json:"n_epochs,omitempty"` + Epochs any `json:"n_epochs,omitempty"` + LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"` + BatchSize any `json:"batch_size,omitempty"` } type FineTuningJobRequest struct { diff --git a/fine_tuning_job_test.go b/fine_tuning_job_test.go index d2fbcd4c7..5f63ef24c 100644 --- a/fine_tuning_job_test.go +++ b/fine_tuning_job_test.go @@ -33,7 +33,9 @@ func TestFineTuningJob(t *testing.T) { ValidationFile: "", TrainingFile: "file-abc123", Hyperparameters: openai.Hyperparameters{ - Epochs: "auto", + Epochs: "auto", + LearningRateMultiplier: "auto", + BatchSize: "auto", }, TrainedTokens: 5768, }) diff --git a/image.go b/image.go index 665de1a74..84b9daf02 100644 --- a/image.go +++ b/image.go @@ -3,8 +3,8 @@ package openai import ( "bytes" "context" + "io" "net/http" - "os" "strconv" ) @@ -13,51 +13,101 @@ const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" CreateImageSize1024x1024 = "1024x1024" + // dall-e-3 supported only. CreateImageSize1792x1024 = "1792x1024" CreateImageSize1024x1792 = "1024x1792" + + // gpt-image-1 supported only. + CreateImageSize1536x1024 = "1536x1024" // Landscape + CreateImageSize1024x1536 = "1024x1536" // Portrait ) const ( - CreateImageResponseFormatURL = "url" + // dall-e-2 and dall-e-3 only. CreateImageResponseFormatB64JSON = "b64_json" + CreateImageResponseFormatURL = "url" ) const ( - CreateImageModelDallE2 = "dall-e-2" - CreateImageModelDallE3 = "dall-e-3" + CreateImageModelDallE2 = "dall-e-2" + CreateImageModelDallE3 = "dall-e-3" + CreateImageModelGptImage1 = "gpt-image-1" ) const ( CreateImageQualityHD = "hd" CreateImageQualityStandard = "standard" + + // gpt-image-1 only. + CreateImageQualityHigh = "high" + CreateImageQualityMedium = "medium" + CreateImageQualityLow = "low" ) const ( + // dall-e-3 only. CreateImageStyleVivid = "vivid" CreateImageStyleNatural = "natural" ) +const ( + // gpt-image-1 only. + CreateImageBackgroundTransparent = "transparent" + CreateImageBackgroundOpaque = "opaque" +) + +const ( + // gpt-image-1 only. + CreateImageModerationLow = "low" +) + +const ( + // gpt-image-1 only. + CreateImageOutputFormatPNG = "png" + CreateImageOutputFormatJPEG = "jpeg" + CreateImageOutputFormatWEBP = "webp" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Quality string `json:"quality,omitempty"` - Size string `json:"size,omitempty"` - Style string `json:"style,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Quality string `json:"quality,omitempty"` + Size string `json:"size,omitempty"` + Style string `json:"style,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` + Background string `json:"background,omitempty"` + Moderation string `json:"moderation,omitempty"` + OutputCompression int `json:"output_compression,omitempty"` + OutputFormat string `json:"output_format,omitempty"` } // ImageResponse represents a response structure for image API. type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + Usage ImageResponseUsage `json:"usage,omitempty"` httpHeader } +// ImageResponseInputTokensDetails represents the token breakdown for input tokens. +type ImageResponseInputTokensDetails struct { + TextTokens int `json:"text_tokens,omitempty"` + ImageTokens int `json:"image_tokens,omitempty"` +} + +// ImageResponseUsage represents the token usage information for image API. +type ImageResponseUsage struct { + TotalTokens int `json:"total_tokens,omitempty"` + InputTokens int `json:"input_tokens,omitempty"` + OutputTokens int `json:"output_tokens,omitempty"` + InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"` +} + // ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { URL string `json:"url,omitempty"` @@ -68,7 +118,12 @@ type ImageResponseDataInner struct { // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { urlSuffix := "/images/generations" - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return } @@ -77,15 +132,42 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons return } +// WrapReader wraps an io.Reader with filename and Content-type. +func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Name() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} + // ImageEditRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageEditRequest struct { - Image *os.File `json:"image,omitempty"` - Mask *os.File `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` + Image io.Reader `json:"image,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -93,15 +175,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename verification can be postponed + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } // mask, it is optional if request.Mask != nil { - err = builder.CreateFormFile("mask", request.Mask) + // filename verification can be postponed + err = builder.CreateFormFileReader("mask", request.Mask, "") if err != nil { return } @@ -132,8 +215,13 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/edits", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/edits", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } @@ -143,12 +231,14 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) } // ImageVariRequest represents the request structure for the image API. +// Use WrapReader to wrap an io.Reader with filename and Content-type. type ImageVariRequest struct { - Image *os.File `json:"image,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` + Image io.Reader `json:"image,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + User string `json:"user,omitempty"` } // CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API. @@ -157,8 +247,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image - err = builder.CreateFormFile("image", request.Image) + // image, filename verification can be postponed + err = builder.CreateFormFileReader("image", request.Image, "") if err != nil { return } @@ -183,8 +273,13 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/images/variations", request.Model), - withBody(body), withContentType(builder.FormDataContentType())) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/images/variations", withModel(request.Model)), + withBody(body), + withContentType(builder.FormDataContentType()), + ) if err != nil { return } diff --git a/image_api_test.go b/image_api_test.go index 2eb46f2b4..f6057b77d 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "path/filepath" "testing" "time" @@ -36,7 +37,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { var err error var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -86,24 +87,17 @@ func TestImageEdit(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } + defer origin.Close() - mask, err := os.Create("mask.png") + mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png")) if err != nil { - t.Error("open mask file error") - return + t.Fatalf("open mask file error: %v", err) } - - defer func() { - mask.Close() - origin.Close() - os.Remove("mask.png") - os.Remove("image.png") - }() + defer mask.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -121,16 +115,11 @@ func TestImageEditWithoutMask(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ Image: origin, @@ -146,7 +135,7 @@ func TestImageEditWithoutMask(t *testing.T) { func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -178,16 +167,11 @@ func TestImageVariation(t *testing.T) { defer teardown() server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint) - origin, err := os.Create("image.png") + origin, err := os.Create(filepath.Join(t.TempDir(), "image.png")) if err != nil { - t.Error("open origin file error") - return + t.Fatalf("open origin file error: %v", err) } - - defer func() { - origin.Close() - os.Remove("image.png") - }() + defer origin.Close() _, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{ Image: origin, @@ -202,7 +186,7 @@ func TestImageVariation(t *testing.T) { func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) { var resBytes []byte - // imagess only accepts POST requests + // images only accepts POST requests if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } diff --git a/image_test.go b/image_test.go index 9332dd5cd..c2c8f42dc 100644 --- a/image_test.go +++ b/image_test.go @@ -4,6 +4,7 @@ import ( utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test/checks" + "bytes" "context" "fmt" "io" @@ -39,120 +40,284 @@ func (fb *mockFormBuilder) FormDataContentType() string { } func TestImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() - - req := ImageEditRequest{ - Mask: &os.File{}, - } - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return mockFailedErr - } - _, err := client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error { - if name == "mask" { - return mockFailedErr - } - return nil + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return nil + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageEditRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "mask", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error { + if name == "mask" { + return mockFailedErr + } + return nil + } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "prompt", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "prompt" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}, + }, } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr - } - return nil + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateEditImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails") + }) } - failForField = "prompt" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "n" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "size" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") - - failForField = "response_format" - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, + } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr - } - _, err = client.CreateEditImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") + _, err := client.CreateEditImage(ctx, ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateEditImage should return error if request builder fails") + }) } func TestVariImageFormBuilderFailures(t *testing.T) { - config := DefaultConfig("") - config.BaseURL = "" - client := NewClientWithConfig(config) - - mockBuilder := &mockFormBuilder{} - client.createFormBuilder = func(io.Writer) utils.FormBuilder { - return mockBuilder - } ctx := context.Background() + mockFailedErr := fmt.Errorf("mock form builder fail") - req := ImageVariRequest{} + newClient := func(fb *mockFormBuilder) *Client { + cfg := DefaultConfig("") + cfg.BaseURL = "" + c := NewClientWithConfig(cfg) + c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb } + return c + } - mockFailedErr := fmt.Errorf("mock form builder fail") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return mockFailedErr + tests := []struct { + name string + setup func(*mockFormBuilder) + req ImageVariRequest + }{ + { + name: "image", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "n", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "n" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "size", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "size" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "response_format", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(field string, _ string) error { + if field == "response_format" { + return mockFailedErr + } + return nil + } + fb.mockClose = func() error { return nil } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, + { + name: "close", + setup: func(fb *mockFormBuilder) { + fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil } + fb.mockWriteField = func(string, string) error { return nil } + fb.mockClose = func() error { return mockFailedErr } + }, + req: ImageVariRequest{Image: bytes.NewBuffer(nil)}, + }, } - _, err := client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") - mockBuilder.mockCreateFormFile = func(string, *os.File) error { - return nil + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + fb := &mockFormBuilder{} + tc.setup(fb) + client := newClient(fb) + _, err := client.CreateVariImage(ctx, tc.req) + checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + }) } - var failForField string - mockBuilder.mockWriteField = func(fieldname, _ string) error { - if fieldname == failForField { - return mockFailedErr + t.Run("new request", func(t *testing.T) { + fb := &mockFormBuilder{ + mockCreateFormFileReader: func(string, io.Reader, string) error { return nil }, + mockWriteField: func(string, string) error { return nil }, + mockClose: func() error { return nil }, } - return nil - } + client := newClient(fb) + client.requestBuilder = &failingRequestBuilder{} + + _, err := client.CreateVariImage(ctx, ImageVariRequest{Image: bytes.NewBuffer(nil)}) + checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateVariImage should return error if request builder fails") + }) +} + +type testNamedReader struct{ io.Reader } - failForField = "n" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") +func (testNamedReader) Name() string { return "named.txt" } - failForField = "size" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") +func TestWrapReader(t *testing.T) { + r := bytes.NewBufferString("data") + wrapped := WrapReader(r, "file.png", "image/png") + f, ok := wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped reader missing Name or ContentType") + } + if f.Name() != "file.png" { + t.Fatalf("expected name file.png, got %s", f.Name()) + } + if f.ContentType() != "image/png" { + t.Fatalf("expected content type image/png, got %s", f.ContentType()) + } - failForField = "response_format" - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails") + // test name from underlying reader + nr := testNamedReader{Reader: bytes.NewBufferString("d")} + wrapped = WrapReader(nr, "", "text/plain") + f, ok = wrapped.(interface { + Name() string + ContentType() string + }) + if !ok { + t.Fatal("wrapped named reader missing Name or ContentType") + } + if f.Name() != "named.txt" { + t.Fatalf("expected name named.txt, got %s", f.Name()) + } + if f.ContentType() != "text/plain" { + t.Fatalf("expected content type text/plain, got %s", f.ContentType()) + } - failForField = "" - mockBuilder.mockClose = func() error { - return mockFailedErr + // no name provided + wrapped = WrapReader(bytes.NewBuffer(nil), "", "") + f2, ok := wrapped.(interface{ Name() string }) + if !ok { + t.Fatal("wrapped anonymous reader missing Name") + } + if f2.Name() != "" { + t.Fatalf("expected empty name, got %s", f2.Name()) } - _, err = client.CreateVariImage(ctx, req) - checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails") } diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go index d48f28177..f6c226c5e 100644 --- a/internal/error_accumulator_test.go +++ b/internal/error_accumulator_test.go @@ -1,41 +1,39 @@ package openai_test import ( - "bytes" - "errors" "testing" - utils "github.com/sashabaranov/go-openai/internal" + openai "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" + "github.com/sashabaranov/go-openai/internal/test/checks" ) -func TestErrorAccumulatorBytes(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &bytes.Buffer{}, +func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } + checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}"))) + checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}"))) - errBytes := accumulator.Bytes() - if len(errBytes) != 0 { - t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes)) + expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}" + if string(ea.Bytes()) != expected { + t.Fatalf("Expected %q, got %q", expected, ea.Bytes()) } +} - err := accumulator.Write([]byte("{}")) - if err != nil { - t.Fatalf("%+v", err) +func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) { + ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator) + if !ok { + t.Fatal("type assertion to *DefaultErrorAccumulator failed") } - - errBytes = accumulator.Bytes() - if len(errBytes) == 0 { - t.Fatalf("Did not return error bytes when has error: %s", string(errBytes)) + if len(ea.Bytes()) != 0 { + t.Fatal("Buffer should be empty initially") } } -func TestErrorByteWriteErrors(t *testing.T) { - accumulator := &utils.DefaultErrorAccumulator{ - Buffer: &test.FailingErrorBuffer{}, - } - err := accumulator.Write([]byte("{")) - if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) { - t.Fatalf("Did not return error when write failed: %v", err) - } +func TestDefaultErrorAccumulator_WriteError(t *testing.T) { + ea := &openai.DefaultErrorAccumulator{Buffer: &test.FailingErrorBuffer{}} + err := ea.Write([]byte("fail")) + checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Write should propagate buffer errors") } diff --git a/internal/form_builder.go b/internal/form_builder.go index 2224fad45..a17e820c0 100644 --- a/internal/form_builder.go +++ b/internal/form_builder.go @@ -4,8 +4,10 @@ import ( "fmt" "io" "mime/multipart" + "net/textproto" "os" - "path" + "path/filepath" + "strings" ) type FormBuilder interface { @@ -30,8 +32,50 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er return fb.createFormFile(fieldname, file, file.Name()) } +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +// CreateFormFileReader creates a form field with a file reader. +// The filename in Content-Disposition is required. func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error { - return fb.createFormFile(fieldname, r, path.Base(filename)) + if filename == "" { + if f, ok := r.(interface{ Name() string }); ok { + filename = f.Name() + } + } + var contentType string + if f, ok := r.(interface{ ContentType() string }); ok { + contentType = f.ContentType() + } + + h := make(textproto.MIMEHeader) + h.Set( + "Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + escapeQuotes(fieldname), + escapeQuotes(filepath.Base(filename)), + ), + ) + // content type is optional, but it can be set + if contentType != "" { + h.Set("Content-Type", contentType) + } + + fieldWriter, err := fb.writer.CreatePart(h) + if err != nil { + return err + } + + _, err = io.Copy(fieldWriter, r) + if err != nil { + return err + } + + return nil } func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error { @@ -53,6 +97,9 @@ func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, file } func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error { + if fieldname == "" { + return fmt.Errorf("fieldname cannot be empty") + } return fb.writer.WriteField(fieldname, value) } diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go index d3faf9982..53ef11d23 100644 --- a/internal/form_builder_test.go +++ b/internal/form_builder_test.go @@ -1,15 +1,58 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" + "errors" + "io" + "github.com/sashabaranov/go-openai/internal/test/checks" "bytes" - "errors" "os" + "strings" "testing" ) +type mockFormBuilder struct { + mockCreateFormFile func(string, *os.File) error + mockWriteField func(string, string) error + mockClose func() error +} + +func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error { + return m.mockCreateFormFile(fieldname, file) +} + +func (m *mockFormBuilder) WriteField(fieldname, value string) error { + return m.mockWriteField(fieldname, value) +} + +func (m *mockFormBuilder) Close() error { + return m.mockClose() +} + +func (m *mockFormBuilder) FormDataContentType() string { + return "" +} + +func TestCloseMethod(t *testing.T) { + t.Run("NormalClose", func(t *testing.T) { + body := &bytes.Buffer{} + builder := NewFormBuilder(body) + checks.NoError(t, builder.Close(), "正常关闭应成功") + }) + + t.Run("ErrorPropagation", func(t *testing.T) { + errorMock := errors.New("mock close error") + mockBuilder := &mockFormBuilder{ + mockClose: func() error { + return errorMock + }, + } + err := mockBuilder.Close() + checks.ErrorIs(t, err, errorMock, "应传递关闭错误") + }) +} + type failingWriter struct { } @@ -20,15 +63,11 @@ func (*failingWriter) Write([]byte) (int, error) { } func TestFormBuilderWithFailingWriter(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } defer file.Close() - defer os.Remove(file.Name()) builder := NewFormBuilder(&failingWriter{}) err = builder.CreateFormFile("file", file) @@ -36,15 +75,11 @@ func TestFormBuilderWithFailingWriter(t *testing.T) { } func TestFormBuilderWithClosedFile(t *testing.T) { - dir, cleanup := test.CreateTestDirectory(t) - defer cleanup() - - file, err := os.CreateTemp(dir, "") + file, err := os.CreateTemp(t.TempDir(), "") if err != nil { - t.Errorf("Error creating tmp file: %v", err) + t.Fatalf("Error creating tmp file: %v", err) } file.Close() - defer os.Remove(file.Name()) body := &bytes.Buffer{} builder := NewFormBuilder(body) @@ -52,3 +87,104 @@ func TestFormBuilderWithClosedFile(t *testing.T) { checks.HasError(t, err, "formbuilder should return error if file is closed") checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed") } + +type failingReader struct { +} + +var errMockFailingReaderError = errors.New("mock reader failed") + +func (*failingReader) Read([]byte) (int, error) { + return 0, errMockFailingReaderError +} + +type readerWithNameAndContentType struct { + io.Reader +} + +func (*readerWithNameAndContentType) Name() string { + return "" +} + +func (*readerWithNameAndContentType) ContentType() string { + return "image/png" +} + +func TestFormBuilderWithReader(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "") + if err != nil { + t.Fatalf("Error creating tmp file: %v", err) + } + defer file.Close() + builder := NewFormBuilder(&failingWriter{}) + err = builder.CreateFormFileReader("file", file, file.Name()) + checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails") + + builder = NewFormBuilder(&bytes.Buffer{}) + reader := &failingReader{} + err = builder.CreateFormFileReader("file", reader, "") + checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails") + + successReader := &bytes.Buffer{} + err = builder.CreateFormFileReader("file", successReader, "") + checks.NoError(t, err, "formbuilder should not return error") + + rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}} + err = builder.CreateFormFileReader("file", rnc, "") + checks.NoError(t, err, "formbuilder should not return error") +} + +func TestFormDataContentType(t *testing.T) { + t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + contentType := builder.FormDataContentType() + if contentType == "" { + t.Errorf("expected non-empty content type, got empty string") + } + }) +} + +func TestWriteField(t *testing.T) { + t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("", "some value") + checks.HasError(t, err, "fieldname is required") + }) + + t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.WriteField("key", "value") + checks.NoError(t, err, "should write field without error") + }) +} + +func TestCreateFormFile(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "") + if err == nil { + t.Fatal("expected error for empty filename") + } + + builder = NewFormBuilder(&failingWriter{}) + err = builder.createFormFile("file", bytes.NewBufferString("data"), "name") + checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error") +} + +func TestCreateFormFileSuccess(t *testing.T) { + buf := &bytes.Buffer{} + builder := NewFormBuilder(buf) + + err := builder.createFormFile("file", bytes.NewBufferString("data"), "foo.txt") + checks.NoError(t, err, "createFormFile should succeed") + + if !strings.Contains(buf.String(), "filename=\"foo.txt\"") { + t.Fatalf("expected filename header, got %q", buf.String()) + } +} diff --git a/internal/marshaller_test.go b/internal/marshaller_test.go new file mode 100644 index 000000000..70694faed --- /dev/null +++ b/internal/marshaller_test.go @@ -0,0 +1,34 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONMarshaller_Normal(t *testing.T) { + jm := &openai.JSONMarshaller{} + data := map[string]string{"key": "value"} + + b, err := jm.Marshal(data) + checks.NoError(t, err) + if len(b) == 0 { + t.Fatal("should return non-empty bytes") + } +} + +func TestJSONMarshaller_InvalidInput(t *testing.T) { + jm := &openai.JSONMarshaller{} + _, err := jm.Marshal(make(chan int)) + checks.HasError(t, err, "should return error for unsupported type") +} + +func TestJSONMarshaller_EmptyValue(t *testing.T) { + jm := &openai.JSONMarshaller{} + b, err := jm.Marshal(nil) + checks.NoError(t, err) + if string(b) != "null" { + t.Fatalf("unexpected marshaled value: %s", string(b)) + } +} diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e26022a6b..adccb158e 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "io" "net/http" "reflect" "testing" @@ -59,3 +60,37 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { t.Errorf("Build() got = %v, want %v", got, want) } } + +func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) { + b := NewRequestBuilder() + ctx := context.Background() + method := http.MethodPost + url := "/reader" + bodyContent := "hello" + body := bytes.NewBufferString(bodyContent) + header := http.Header{"X-Test": []string{"val"}} + + req, err := b.Build(ctx, method, url, body, header) + if err != nil { + t.Fatalf("Build returned error: %v", err) + } + + gotBody, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("cannot read body: %v", err) + } + if string(gotBody) != bodyContent { + t.Fatalf("expected body %q, got %q", bodyContent, string(gotBody)) + } + if req.Header.Get("X-Test") != "val" { + t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test")) + } +} + +func TestRequestBuilderInvalidURL(t *testing.T) { + b := NewRequestBuilder() + _, err := b.Build(context.Background(), http.MethodGet, ":", nil, nil) + if err == nil { + t.Fatal("expected error for invalid URL") + } +} diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..6bd0964c6 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..0677054df --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,19 @@ +package checks_test + +import ( + "errors" + "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestChecksSuccessPaths(t *testing.T) { + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) + target := errors.New("x") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..fb1f4bf06 --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,24 @@ +//nolint:testpackage // need access to unexported fields and types for testing +package test + +import ( + "errors" + "testing" +) + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 0e63ae82f..dc5fa6646 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -19,16 +19,6 @@ func CreateTestFile(t *testing.T, path string) { file.Close() } -// CreateTestDirectory creates a temporary folder which will be deleted when cleanup is called. -func CreateTestDirectory(t *testing.T) (path string, cleanup func()) { - t.Helper() - - path, err := os.MkdirTemp(os.TempDir(), "") - checks.NoError(t, err) - - return path, func() { os.RemoveAll(path) } -} - // TokenRoundTripper is a struct that implements the RoundTripper // interface, specifically to handle the authentication token by adding a token // to the request header. We need this because the API requires that each diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..aa177679b --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,54 @@ +package test_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + internaltest.CreateTestFile(t, path) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -23,6 +23,18 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. +func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { // to make the registered paths friendlier to a regex match in the route handler // in OpenAITestServer diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..f8ce731d1 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,80 @@ +package test_test + +import ( + "io" + "net/http" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestGetTestToken(t *testing.T) { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := internaltest.NewTestServer() + if ts == nil { + t.Fatalf("server not properly initialized") + } + if ts.HandlerCount() != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if !ts.HasHandler("/foo/*") { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} diff --git a/internal/unmarshaler_test.go b/internal/unmarshaler_test.go new file mode 100644 index 000000000..d63eac779 --- /dev/null +++ b/internal/unmarshaler_test.go @@ -0,0 +1,37 @@ +package openai_test + +import ( + "testing" + + openai "github.com/sashabaranov/go-openai/internal" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestJSONUnmarshaler_Normal(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{"key":"value"}`) + var v map[string]string + + err := jm.Unmarshal(data, &v) + checks.NoError(t, err) + if v["key"] != "value" { + t.Fatal("unmarshal result mismatch") + } +} + +func TestJSONUnmarshaler_InvalidJSON(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + data := []byte(`{invalid}`) + var v map[string]interface{} + + err := jm.Unmarshal(data, &v) + checks.HasError(t, err, "should return error for invalid JSON") +} + +func TestJSONUnmarshaler_EmptyInput(t *testing.T) { + jm := &openai.JSONUnmarshaler{} + var v interface{} + + err := jm.Unmarshal(nil, &v) + checks.HasError(t, err, "should return error for nil input") +} diff --git a/jsonschema/containsref_test.go b/jsonschema/containsref_test.go new file mode 100644 index 000000000..dc1842775 --- /dev/null +++ b/jsonschema/containsref_test.go @@ -0,0 +1,48 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// SelfRef struct used to produce a self-referential schema. +type SelfRef struct { + Friends []SelfRef `json:"friends"` +} + +// Address struct referenced by Person without self-reference. +type Address struct { + Street string `json:"street"` +} + +type Person struct { + Address Address `json:"address"` +} + +// TestGenerateSchemaForType_SelfRef ensures that self-referential types are not +// flattened during schema generation. +func TestGenerateSchemaForType_SelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(SelfRef{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["SelfRef"]; !ok { + t.Fatal("expected defs to contain SelfRef for self reference") + } +} + +// TestGenerateSchemaForType_NoSelfRef ensures that non-self-referential types +// are flattened and do not reappear in $defs. +func TestGenerateSchemaForType_NoSelfRef(t *testing.T) { + schema, err := jsonschema.GenerateSchemaForType(Person{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, ok := schema.Defs["Person"]; ok { + t.Fatal("unexpected Person definition in defs") + } + if _, ok := schema.Defs["Address"]; !ok { + t.Fatal("expected Address definition in defs") + } +} diff --git a/jsonschema/json.go b/jsonschema/json.go index cb941eb75..75e3b5173 100644 --- a/jsonschema/json.go +++ b/jsonschema/json.go @@ -4,7 +4,13 @@ // and/or pass in the schema in []byte format. package jsonschema -import "encoding/json" +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" +) type DataType string @@ -29,14 +35,27 @@ type Definition struct { // one element, where each element is unique. You will probably only use this with strings. Enum []string `json:"enum,omitempty"` // Properties describes the properties of an object, if the schema type is Object. - Properties map[string]Definition `json:"properties"` + Properties map[string]Definition `json:"properties,omitempty"` // Required specifies which properties are required, if the schema type is Object. Required []string `json:"required,omitempty"` // Items specifies which data type an array contains, if the schema type is Array. Items *Definition `json:"items,omitempty"` + // AdditionalProperties is used to control the handling of properties in an object + // that are not explicitly defined in the properties section of the schema. example: + // additionalProperties: true + // additionalProperties: false + // additionalProperties: jsonschema.Definition{Type: jsonschema.String} + AdditionalProperties any `json:"additionalProperties,omitempty"` + // Whether the schema is nullable or not. + Nullable bool `json:"nullable,omitempty"` + + // Ref Reference to a definition in $defs or external schema. + Ref string `json:"$ref,omitempty"` + // Defs A map of reusable schema definitions. + Defs map[string]Definition `json:"$defs,omitempty"` } -func (d Definition) MarshalJSON() ([]byte, error) { +func (d *Definition) MarshalJSON() ([]byte, error) { if d.Properties == nil { d.Properties = make(map[string]Definition) } @@ -44,6 +63,173 @@ func (d Definition) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Alias }{ - Alias: (Alias)(d), + Alias: (Alias)(*d), }) } + +func (d *Definition) Unmarshal(content string, v any) error { + return VerifySchemaAndUnmarshal(*d, []byte(content), v) +} + +func GenerateSchemaForType(v any) (*Definition, error) { + var defs = make(map[string]Definition) + def, err := reflectSchema(reflect.TypeOf(v), defs) + if err != nil { + return nil, err + } + // If the schema has a root $ref, resolve it by: + // 1. Extracting the key from the $ref. + // 2. Detaching the referenced definition from $defs. + // 3. Checking for self-references in the detached definition. + // - If a self-reference is found, restore the original $defs structure. + // 4. Flattening the referenced definition into the root schema. + // 5. Clearing the $ref field in the root schema. + if def.Ref != "" { + origRef := def.Ref + key := strings.TrimPrefix(origRef, "#/$defs/") + if root, ok := defs[key]; ok { + delete(defs, key) + root.Defs = defs + if containsRef(root, origRef) { + root.Defs = nil + defs[key] = root + } + *def = root + } + def.Ref = "" + } + def.Defs = defs + return def, nil +} + +func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) { + var d Definition + switch t.Kind() { + case reflect.String: + d.Type = String + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + d.Type = Integer + case reflect.Float32, reflect.Float64: + d.Type = Number + case reflect.Bool: + d.Type = Boolean + case reflect.Slice, reflect.Array: + d.Type = Array + items, err := reflectSchema(t.Elem(), defs) + if err != nil { + return nil, err + } + d.Items = items + case reflect.Struct: + if t.Name() != "" { + if _, ok := defs[t.Name()]; !ok { + defs[t.Name()] = Definition{} + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + defs[t.Name()] = *object + } + return &Definition{Ref: "#/$defs/" + t.Name()}, nil + } + d.Type = Object + d.AdditionalProperties = false + object, err := reflectSchemaObject(t, defs) + if err != nil { + return nil, err + } + d = *object + case reflect.Ptr: + definition, err := reflectSchema(t.Elem(), defs) + if err != nil { + return nil, err + } + d = *definition + case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128, + reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.UnsafePointer: + return nil, fmt.Errorf("unsupported type: %s", t.Kind().String()) + default: + } + return &d, nil +} + +func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) { + var d = Definition{ + Type: Object, + AdditionalProperties: false, + } + properties := make(map[string]Definition) + var requiredFields []string + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if !field.IsExported() { + continue + } + jsonTag := field.Tag.Get("json") + var required = true + switch { + case jsonTag == "-": + continue + case jsonTag == "": + jsonTag = field.Name + case strings.HasSuffix(jsonTag, ",omitempty"): + jsonTag = strings.TrimSuffix(jsonTag, ",omitempty") + required = false + } + + item, err := reflectSchema(field.Type, defs) + if err != nil { + return nil, err + } + description := field.Tag.Get("description") + if description != "" { + item.Description = description + } + enum := field.Tag.Get("enum") + if enum != "" { + item.Enum = strings.Split(enum, ",") + } + + if n := field.Tag.Get("nullable"); n != "" { + nullable, _ := strconv.ParseBool(n) + item.Nullable = nullable + } + + properties[jsonTag] = *item + + if s := field.Tag.Get("required"); s != "" { + required, _ = strconv.ParseBool(s) + } + if required { + requiredFields = append(requiredFields, jsonTag) + } + } + d.Required = requiredFields + d.Properties = properties + return &d, nil +} + +func containsRef(def Definition, targetRef string) bool { + if def.Ref == targetRef { + return true + } + + for _, d := range def.Defs { + if containsRef(d, targetRef) { + return true + } + } + + for _, prop := range def.Properties { + if containsRef(prop, targetRef) { + return true + } + } + + if def.Items != nil && containsRef(*def.Items, targetRef) { + return true + } + return false +} diff --git a/jsonschema/json_additional_test.go b/jsonschema/json_additional_test.go new file mode 100644 index 000000000..70cf37490 --- /dev/null +++ b/jsonschema/json_additional_test.go @@ -0,0 +1,73 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// Test Definition.Unmarshal, including success path, validation error, +// JSON syntax error and type mismatch during unmarshalling. +func TestDefinitionUnmarshal(t *testing.T) { + schema := jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + }, + } + + var dst struct { + Name string `json:"name"` + } + if err := schema.Unmarshal(`{"name":"foo"}`, &dst); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dst.Name != "foo" { + t.Errorf("expected name to be foo, got %q", dst.Name) + } + + if err := schema.Unmarshal(`{`, &dst); err == nil { + t.Error("expected error for malformed json") + } + + if err := schema.Unmarshal(`{"name":1}`, &dst); err == nil { + t.Error("expected validation error") + } + + numSchema := jsonschema.Definition{Type: jsonschema.Number} + var s string + if err := numSchema.Unmarshal(`123`, &s); err == nil { + t.Error("expected unmarshal type error") + } +} + +// Ensure GenerateSchemaForType returns an error when encountering unsupported types. +func TestGenerateSchemaForTypeUnsupported(t *testing.T) { + type Bad struct { + Ch chan int `json:"ch"` + } + if _, err := jsonschema.GenerateSchemaForType(Bad{}); err == nil { + t.Fatal("expected error for unsupported type") + } +} + +// Validate should fail when provided data does not match the expected container types. +func TestValidateInvalidContainers(t *testing.T) { + objSchema := jsonschema.Definition{Type: jsonschema.Object} + if jsonschema.Validate(objSchema, 1) { + t.Error("expected object validation to fail for non-map input") + } + + arrSchema := jsonschema.Definition{Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}} + if jsonschema.Validate(arrSchema, 1) { + t.Error("expected array validation to fail for non-slice input") + } +} + +// Validate should return false when $ref cannot be resolved. +func TestValidateRefNotFound(t *testing.T) { + refSchema := jsonschema.Definition{Ref: "#/$defs/Missing"} + if jsonschema.Validate(refSchema, "data", jsonschema.WithDefs(map[string]jsonschema.Definition{})) { + t.Error("expected validation to fail when reference is missing") + } +} diff --git a/jsonschema/json_errors_test.go b/jsonschema/json_errors_test.go new file mode 100644 index 000000000..3b864fc21 --- /dev/null +++ b/jsonschema/json_errors_test.go @@ -0,0 +1,27 @@ +package jsonschema_test + +import ( + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +// TestGenerateSchemaForType_ErrorPaths verifies error handling for unsupported types. +func TestGenerateSchemaForType_ErrorPaths(t *testing.T) { + type anon struct{ Ch chan int } + tests := []struct { + name string + v any + }{ + {"slice", []chan int{}}, + {"anon struct", anon{}}, + {"pointer", (*chan int)(nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, err := jsonschema.GenerateSchemaForType(tt.v); err == nil { + t.Errorf("expected error for %s", tt.name) + } + }) + } +} diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go index 744706082..34f5d88eb 100644 --- a/jsonschema/json_test.go +++ b/jsonschema/json_test.go @@ -17,7 +17,7 @@ func TestDefinition_MarshalJSON(t *testing.T) { { name: "Test with empty Definition", def: jsonschema.Definition{}, - want: `{"properties":{}}`, + want: `{}`, }, { name: "Test with Definition properties set", @@ -31,15 +31,14 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"string", - "description":"A string type", - "properties":{ - "name":{ - "type":"string", - "properties":{} - } - } -}`, + "type":"string", + "description":"A string type", + "properties":{ + "name":{ + "type":"string" + } + } + }`, }, { name: "Test with nested Definition properties", @@ -60,23 +59,21 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"object", - "properties":{ - "user":{ - "type":"object", - "properties":{ - "name":{ - "type":"string", - "properties":{} - }, - "age":{ - "type":"integer", - "properties":{} - } - } - } - } -}`, + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + } + } + } + } + }`, }, { name: "Test with complex nested Definition", @@ -108,36 +105,32 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"object", - "properties":{ - "user":{ - "type":"object", - "properties":{ - "name":{ - "type":"string", - "properties":{} - }, - "age":{ - "type":"integer", - "properties":{} - }, - "address":{ - "type":"object", - "properties":{ - "city":{ - "type":"string", - "properties":{} - }, - "country":{ - "type":"string", - "properties":{} - } - } - } - } - } - } -}`, + "type":"object", + "properties":{ + "user":{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "address":{ + "type":"object", + "properties":{ + "city":{ + "type":"string" + }, + "country":{ + "type":"string" + } + } + } + } + } + } + }`, }, { name: "Test with Array type Definition", @@ -153,20 +146,16 @@ func TestDefinition_MarshalJSON(t *testing.T) { }, }, want: `{ - "type":"array", - "items":{ - "type":"string", - "properties":{ - - } - }, - "properties":{ - "name":{ - "type":"string", - "properties":{} - } - } -}`, + "type":"array", + "items":{ + "type":"string" + }, + "properties":{ + "name":{ + "type":"string" + } + } + }`, }, } @@ -193,6 +182,476 @@ func TestDefinition_MarshalJSON(t *testing.T) { } } +type User struct { + ID int `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Orders []*Order `json:"orders,omitempty"` +} + +type Order struct { + ID int `json:"id,omitempty"` + Amount float64 `json:"amount,omitempty"` + Buyer *User `json:"buyer,omitempty"` +} + +func TestStructToSchema(t *testing.T) { + type Tweet struct { + Text string `json:"text"` + } + + type Person struct { + Name string `json:"name,omitempty"` + Age int `json:"age,omitempty"` + Friends []Person `json:"friends,omitempty"` + Tweets []Tweet `json:"tweets,omitempty"` + } + + type MyStructuredResponse struct { + PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"` + CamelCase string `json:"camel_case" required:"true" description:"CamelCase"` + KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"` + SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"` + } + + tests := []struct { + name string + in any + want string + }{ + { + name: "Test with empty struct", + in: struct{}{}, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with struct containing many fields", + in: struct { + Name string `json:"name"` + Age int `json:"age"` + Active bool `json:"active"` + Height float64 `json:"height"` + Cities []struct { + Name string `json:"name"` + State string `json:"state"` + } `json:"cities"` + }{ + Name: "John Doe", + Age: 30, + Cities: []struct { + Name string `json:"name"` + State string `json:"state"` + }{ + {Name: "New York", State: "NY"}, + {Name: "Los Angeles", State: "CA"}, + }, + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "age":{ + "type":"integer" + }, + "active":{ + "type":"boolean" + }, + "height":{ + "type":"number" + }, + "cities":{ + "type":"array", + "items":{ + "additionalProperties":false, + "type":"object", + "properties":{ + "name":{ + "type":"string" + }, + "state":{ + "type":"string" + } + }, + "required":["name","state"] + } + } + }, + "required":["name","age","active","height","cities"], + "additionalProperties":false + }`, + }, + { + name: "Test with description tag", + in: struct { + Name string `json:"name" description:"The name of the person"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string", + "description":"The name of the person" + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + { + name: "Test with required tag", + in: struct { + Name string `json:"name" required:"false"` + }{ + Name: "John Doe", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, + { + name: "Test with enum tag", + in: struct { + Color string `json:"color" enum:"red,green,blue"` + }{ + Color: "red", + }, + want: `{ + "type":"object", + "properties":{ + "color":{ + "type":"string", + "enum":["red","green","blue"] + } + }, + "required":["color"], + "additionalProperties":false + }`, + }, + { + name: "Test with nullable tag", + in: struct { + Name *string `json:"name" nullable:"true"` + }{ + Name: nil, + }, + want: `{ + + "type":"object", + "properties":{ + "name":{ + "type":"string", + "nullable":true + } + }, + "required":["name"], + "additionalProperties":false + }`, + }, + { + name: "Test with exclude mark", + in: struct { + Name string `json:"-"` + }{ + Name: "Name", + }, + want: `{ + "type":"object", + "additionalProperties":false + }`, + }, + { + name: "Test with no json tag", + in: struct { + Name string + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "Name":{ + "type":"string" + } + }, + "required":["Name"], + "additionalProperties":false + }`, + }, + { + name: "Test with omitempty tag", + in: struct { + Name string `json:"name,omitempty"` + }{ + Name: "", + }, + want: `{ + "type":"object", + "properties":{ + "name":{ + "type":"string" + } + }, + "additionalProperties":false + }`, + }, + { + name: "Test with $ref and $defs", + in: struct { + Person Person `json:"person"` + Tweets []Tweet `json:"tweets"` + }{}, + want: `{ + "type" : "object", + "properties" : { + "person" : { + "$ref" : "#/$defs/Person" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "required" : [ "person", "tweets" ], + "additionalProperties" : false, + "$defs" : { + "Person" : { + "type" : "object", + "properties" : { + "age" : { + "type" : "integer" + }, + "friends" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Person" + } + }, + "name" : { + "type" : "string" + }, + "tweets" : { + "type" : "array", + "items" : { + "$ref" : "#/$defs/Tweet" + } + } + }, + "additionalProperties" : false + }, + "Tweet" : { + "type" : "object", + "properties" : { + "text" : { + "type" : "string" + } + }, + "required" : [ "text" ], + "additionalProperties" : false + } + } +}`, + }, + { + name: "Test Person", + in: Person{}, + want: `{ + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false, + "$defs": { + "Person": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/Person" + } + }, + "name": { + "type": "string" + }, + "tweets": { + "type": "array", + "items": { + "$ref": "#/$defs/Tweet" + } + } + }, + "additionalProperties": false + }, + "Tweet": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + }, + "required": [ + "text" + ], + "additionalProperties": false + } + } +}`, + }, + { + name: "Test MyStructuredResponse", + in: MyStructuredResponse{}, + want: `{ + "type": "object", + "properties": { + "camel_case": { + "type": "string", + "description": "CamelCase" + }, + "kebab_case": { + "type": "string", + "description": "KebabCase" + }, + "pascal_case": { + "type": "string", + "description": "PascalCase" + }, + "snake_case": { + "type": "string", + "description": "SnakeCase" + } + }, + "required": [ + "pascal_case", + "camel_case", + "kebab_case", + "snake_case" + ], + "additionalProperties": false +}`, + }, + { + name: "Test User", + in: User{}, + want: `{ + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false, + "$defs": { + "Order": { + "type": "object", + "properties": { + "amount": { + "type": "number" + }, + "buyer": { + "$ref": "#/$defs/User" + }, + "id": { + "type": "integer" + } + }, + "additionalProperties": false + }, + "User": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string" + }, + "orders": { + "type": "array", + "items": { + "$ref": "#/$defs/Order" + } + } + }, + "additionalProperties": false + } + } +}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantBytes := []byte(tt.want) + + schema, err := jsonschema.GenerateSchemaForType(tt.in) + if err != nil { + t.Errorf("Failed to generate schema: error = %v", err) + return + } + + var want map[string]interface{} + err = json.Unmarshal(wantBytes, &want) + if err != nil { + t.Errorf("Failed to Unmarshal JSON: error = %v", err) + return + } + + got := structToMap(t, schema) + gotPtr := structToMap(t, &schema) + + if !reflect.DeepEqual(got, want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, want) + } + if !reflect.DeepEqual(gotPtr, want) { + t.Errorf("MarshalJSON() gotPtr = %v, want %v", gotPtr, want) + } + }) + } +} + func structToMap(t *testing.T, v any) map[string]any { t.Helper() gotBytes, err := json.Marshal(v) diff --git a/jsonschema/validate.go b/jsonschema/validate.go new file mode 100644 index 000000000..1bd2f809c --- /dev/null +++ b/jsonschema/validate.go @@ -0,0 +1,140 @@ +package jsonschema + +import ( + "encoding/json" + "errors" +) + +func CollectDefs(def Definition) map[string]Definition { + result := make(map[string]Definition) + collectDefsRecursive(def, result, "#") + return result +} + +func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) { + for k, v := range def.Defs { + path := prefix + "/$defs/" + k + result[path] = v + collectDefsRecursive(v, result, path) + } + for k, sub := range def.Properties { + collectDefsRecursive(sub, result, prefix+"/properties/"+k) + } + if def.Items != nil { + collectDefsRecursive(*def.Items, result, prefix) + } +} + +func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error { + var data any + err := json.Unmarshal(content, &data) + if err != nil { + return err + } + if !Validate(schema, data, WithDefs(CollectDefs(schema))) { + return errors.New("data validation failed against the provided schema") + } + return json.Unmarshal(content, &v) +} + +type validateArgs struct { + Defs map[string]Definition +} + +type ValidateOption func(*validateArgs) + +func WithDefs(defs map[string]Definition) ValidateOption { + return func(option *validateArgs) { + option.Defs = defs + } +} + +func Validate(schema Definition, data any, opts ...ValidateOption) bool { + args := validateArgs{} + for _, opt := range opts { + opt(&args) + } + if len(opts) == 0 { + args.Defs = CollectDefs(schema) + } + switch schema.Type { + case Object: + return validateObject(schema, data, args.Defs) + case Array: + return validateArray(schema, data, args.Defs) + case String: + v, ok := data.(string) + if ok && len(schema.Enum) > 0 { + return contains(schema.Enum, v) + } + return ok + case Number: // float64 and int + _, ok := data.(float64) + if !ok { + _, ok = data.(int) + } + return ok + case Boolean: + _, ok := data.(bool) + return ok + case Integer: + // Golang unmarshals all numbers as float64, so we need to check if the float64 is an integer + if num, ok := data.(float64); ok { + return num == float64(int64(num)) + } + _, ok := data.(int) + return ok + case Null: + return data == nil + default: + if schema.Ref != "" && args.Defs != nil { + if v, ok := args.Defs[schema.Ref]; ok { + return Validate(v, data, WithDefs(args.Defs)) + } + } + return false + } +} + +func validateObject(schema Definition, data any, defs map[string]Definition) bool { + dataMap, ok := data.(map[string]any) + if !ok { + return false + } + for _, field := range schema.Required { + if _, exists := dataMap[field]; !exists { + return false + } + } + for key, valueSchema := range schema.Properties { + value, exists := dataMap[key] + if exists && !Validate(valueSchema, value, WithDefs(defs)) { + return false + } else if !exists && contains(schema.Required, key) { + return false + } + } + return true +} + +func validateArray(schema Definition, data any, defs map[string]Definition) bool { + dataArray, ok := data.([]any) + if !ok { + return false + } + for _, item := range dataArray { + if !Validate(*schema.Items, item, WithDefs(defs)) { + return false + } + } + return true +} + +func contains[S ~[]E, E comparable](s S, v E) bool { + for i := range s { + if v == s[i] { + return true + } + } + return false +} diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go new file mode 100644 index 000000000..aefdf4069 --- /dev/null +++ b/jsonschema/validate_test.go @@ -0,0 +1,347 @@ +package jsonschema_test + +import ( + "reflect" + "testing" + + "github.com/sashabaranov/go-openai/jsonschema" +) + +func Test_Validate(t *testing.T) { + type args struct { + data any + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want bool + }{ + // string integer number boolean + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.String}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.String}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Integer}}, true}, + {"", args{data: 123.4, schema: jsonschema.Definition{Type: jsonschema.Integer}}, false}, + {"", args{data: "ABC", schema: jsonschema.Definition{Type: jsonschema.Number}}, false}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Number}}, true}, + {"", args{data: false, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, true}, + {"", args{data: 123, schema: jsonschema.Definition{Type: jsonschema.Boolean}}, false}, + {"", args{data: nil, schema: jsonschema.Definition{Type: jsonschema.Null}}, true}, + {"", args{data: 0, schema: jsonschema.Definition{Type: jsonschema.Null}}, false}, + // array + {"", args{data: []any{"a", "b", "c"}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, true}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}, + }, false}, + {"", args{data: []any{1, 2, 3}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, true}, + {"", args{data: []any{1, 2, 3.4}, schema: jsonschema.Definition{ + Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Integer}}, + }, false}, + // object + {"", args{data: map[string]any{ + "string": "abc", + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, true}, + {"", args{data: map[string]any{ + "integer": 123, + "number": 123.4, + "boolean": false, + "array": []any{1, 2, 3}, + }, schema: jsonschema.Definition{Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + "number": {Type: jsonschema.Number}, + "boolean": {Type: jsonschema.Boolean}, + "array": {Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.Number}}, + }, + Required: []string{"string"}, + }}, false}, + { + "test schema with ref and defs", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "male", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, true}, + { + "test enum invalid value", args{data: map[string]any{ + "person": map[string]any{ + "name": "John", + "gender": "other", + "age": 28, + "profile": map[string]any{ + "full_name": "John Doe", + }, + }, + }, schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := jsonschema.Validate(tt.args.schema, tt.args.data); got != tt.want { + t.Errorf("Validate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnmarshal(t *testing.T) { + type args struct { + schema jsonschema.Definition + content []byte + v any + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + }, + content: []byte(`{"string":"abc","number":123.4}`), + v: &struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, false}, + {"", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "number": {Type: jsonschema.Number}, + }, + Required: []string{"string", "number"}, + }, + content: []byte(`{"string":"abc"}`), + v: struct { + String string `json:"string"` + Number float64 `json:"number"` + }{}, + }, true}, + {"validate integer", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, false}, + {"validate integer failed", args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "string": {Type: jsonschema.String}, + "integer": {Type: jsonschema.Integer}, + }, + Required: []string{"string", "integer"}, + }, + content: []byte(`{"string":"abc","integer":123.4}`), + v: &struct { + String string `json:"string"` + Integer int `json:"integer"` + }{}, + }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v) + if (err != nil) != tt.wantErr { + t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestCollectDefs(t *testing.T) { + type args struct { + schema jsonschema.Definition + } + tests := []struct { + name string + args args + want map[string]jsonschema.Definition + }{ + { + "test collect defs", + args{ + schema: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "person": {Ref: "#/$defs/Person"}, + }, + Required: []string{"person"}, + Defs: map[string]jsonschema.Definition{ + "Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + }, + map[string]jsonschema.Definition{ + "#/$defs/Person": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "name": {Type: jsonschema.String}, + "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}}, + "age": {Type: jsonschema.Integer}, + "profile": {Ref: "#/$defs/Person/$defs/Profile"}, + "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}}, + }, + Required: []string{"name", "gender", "age", "profile"}, + Defs: map[string]jsonschema.Definition{ + "Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + }, + }, + "#/$defs/Person/$defs/Profile": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "full_name": {Type: jsonschema.String}, + }, + }, + "#/$defs/Tweet": { + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "text": {Type: jsonschema.String}, + "person": {Ref: "#/$defs/Person"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := jsonschema.CollectDefs(tt.args.schema) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CollectDefs() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/messages.go b/messages.go index 6fd0adbc9..3852d2e37 100644 --- a/messages.go +++ b/messages.go @@ -41,6 +41,7 @@ type MessageContent struct { Type string `json:"type"` Text *MessageText `json:"text,omitempty"` ImageFile *ImageFile `json:"image_file,omitempty"` + ImageURL *ImageURL `json:"image_url,omitempty"` } type MessageText struct { Value string `json:"value"` @@ -51,11 +52,17 @@ type ImageFile struct { FileID string `json:"file_id"` } +type ImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + type MessageRequest struct { - Role string `json:"role"` - Content string `json:"content"` - FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility - Metadata map[string]any `json:"metadata,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + FileIds []string `json:"file_ids,omitempty"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` } type MessageFile struct { @@ -73,10 +80,19 @@ type MessageFilesList struct { httpHeader } +type MessageDeletionStatus struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + // CreateMessage creates a new message. func (c *Client) CreateMessage(ctx context.Context, threadID string, request MessageRequest) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s", threadID, messagesSuffix) - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -91,6 +107,7 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, order *string, after *string, before *string, + runID *string, ) (messages MessagesList, err error) { urlValues := url.Values{} if limit != nil { @@ -105,13 +122,18 @@ func (c *Client) ListMessage(ctx context.Context, threadID string, if before != nil { urlValues.Add("before", *before) } + if runID != nil { + urlValues.Add("run_id", *runID) + } + encodedValues := "" if len(urlValues) > 0 { encodedValues = "?" + urlValues.Encode() } urlSuffix := fmt.Sprintf("/threads/%s/%s%s", threadID, messagesSuffix, encodedValues) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -126,7 +148,8 @@ func (c *Client) RetrieveMessage( threadID, messageID string, ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -143,7 +166,7 @@ func (c *Client) ModifyMessage( ) (msg Message, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), - withBody(map[string]any{"metadata": metadata}), withBetaAssistantV1()) + withBody(map[string]any{"metadata": metadata}), withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -158,7 +181,8 @@ func (c *Client) RetrieveMessageFile( threadID, messageID, fileID string, ) (file MessageFile, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files/%s", threadID, messagesSuffix, messageID, fileID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -173,7 +197,8 @@ func (c *Client) ListMessageFiles( threadID, messageID string, ) (files MessageFilesList, err error) { urlSuffix := fmt.Sprintf("/threads/%s/%s/%s/files", threadID, messagesSuffix, messageID) - req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), withBetaAssistantV1()) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -181,3 +206,19 @@ func (c *Client) ListMessageFiles( err = c.sendRequest(req, &files) return } + +// DeleteMessage deletes a message.. +func (c *Client) DeleteMessage( + ctx context.Context, + threadID, messageID string, +) (status MessageDeletionStatus, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/%s/%s", threadID, messagesSuffix, messageID) + req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + if err != nil { + return + } + + err = c.sendRequest(req, &status) + return +} diff --git a/messages_test.go b/messages_test.go index a18be20bd..b25755f98 100644 --- a/messages_test.go +++ b/messages_test.go @@ -8,20 +8,17 @@ import ( "testing" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) var emptyStr = "" -// TestMessages Tests the messages endpoint of the API using the mocked server. -func TestMessages(t *testing.T) { +func setupServerForTestMessage(t *testing.T, server *test.ServerTest) { threadID := "thread_abc123" messageID := "msg_abc123" fileID := "file_abc123" - client, server, teardown := setupOpenAITestServer() - defer teardown() - server.RegisterHandler( "/v1/threads/"+threadID+"/messages/"+messageID+"/files/"+fileID, func(w http.ResponseWriter, r *http.Request) { @@ -115,6 +112,13 @@ func TestMessages(t *testing.T) { Metadata: nil, }) fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + resBytes, _ := json.Marshal(openai.MessageDeletionStatus{ + ID: messageID, + Object: "thread.message.deleted", + Deleted: true, + }) + fmt.Fprintln(w, string(resBytes)) default: t.Fatalf("unsupported messages http method: %s", r.Method) } @@ -176,7 +180,18 @@ func TestMessages(t *testing.T) { } }, ) +} +// TestMessages Tests the messages endpoint of the API using the mocked server. +func TestMessages(t *testing.T) { + threadID := "thread_abc123" + messageID := "msg_abc123" + fileID := "file_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + setupServerForTestMessage(t, server) ctx := context.Background() // static assertion of return type @@ -193,7 +208,7 @@ func TestMessages(t *testing.T) { } var msgs openai.MessagesList - msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil) + msgs, err = client.ListMessage(ctx, threadID, nil, nil, nil, nil, nil) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") @@ -204,7 +219,8 @@ func TestMessages(t *testing.T) { order := "desc" after := "obj_foo" before := "obj_bar" - msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before) + runID := "run_abc123" + msgs, err = client.ListMessage(ctx, threadID, &limit, &order, &after, &before, &runID) checks.NoError(t, err, "ListMessages error") if len(msgs.Messages) != 1 { t.Fatalf("unexpected length of fetched messages") @@ -225,6 +241,17 @@ func TestMessages(t *testing.T) { t.Fatalf("expected message metadata to get modified") } + msgDel, err := client.DeleteMessage(ctx, threadID, messageID) + checks.NoError(t, err, "DeleteMessage error") + if msgDel.ID != messageID { + t.Fatalf("unexpected message id: '%s'", msg.ID) + } + if !msgDel.Deleted { + t.Fatalf("expected deleted is true") + } + _, err = client.DeleteMessage(ctx, threadID, "not_exist_id") + checks.HasError(t, err, "DeleteMessage error") + // message files var msgFile openai.MessageFile msgFile, err = client.RetrieveMessageFile(ctx, threadID, messageID, fileID) diff --git a/models_test.go b/models_test.go index 24a28ed23..7fd010c34 100644 --- a/models_test.go +++ b/models_test.go @@ -47,6 +47,24 @@ func TestGetModel(t *testing.T) { checks.NoError(t, err, "GetModel error") } +// TestGetModelO3 Tests the retrieve O3 model endpoint of the API using the mocked server. +func TestGetModelO3(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o3", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o3") + checks.NoError(t, err, "GetModel error for O3") +} + +// TestGetModelO4Mini Tests the retrieve O4Mini model endpoint of the API using the mocked server. +func TestGetModelO4Mini(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/models/o4-mini", handleGetModelEndpoint) + _, err := client.GetModel(context.Background(), "o4-mini") + checks.NoError(t, err, "GetModel error for O4Mini") +} + func TestAzureGetModel(t *testing.T) { client, server, teardown := setupAzureTestServer() defer teardown() diff --git a/moderation.go b/moderation.go index ae285ef83..a0e09c0ee 100644 --- a/moderation.go +++ b/moderation.go @@ -14,8 +14,10 @@ import ( // If you use text-moderation-stable, we will provide advanced notice before updating the model. // Accuracy of text-moderation-stable may be slightly lower than for text-moderation-latest. const ( - ModerationTextStable = "text-moderation-stable" - ModerationTextLatest = "text-moderation-latest" + ModerationOmniLatest = "omni-moderation-latest" + ModerationOmni20240926 = "omni-moderation-2024-09-26" + ModerationTextStable = "text-moderation-stable" + ModerationTextLatest = "text-moderation-latest" // Deprecated: use ModerationTextStable and ModerationTextLatest instead. ModerationText001 = "text-moderation-001" ) @@ -25,8 +27,10 @@ var ( ) var validModerationModel = map[string]struct{}{ - ModerationTextStable: {}, - ModerationTextLatest: {}, + ModerationOmniLatest: {}, + ModerationOmni20240926: {}, + ModerationTextStable: {}, + ModerationTextLatest: {}, } // ModerationRequest represents a request structure for moderation API. @@ -88,7 +92,12 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re err = ErrModerationInvalidModel return } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/moderations", request.Model), withBody(&request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/moderations", withModel(request.Model)), + withBody(&request), + ) if err != nil { return } diff --git a/moderation_test.go b/moderation_test.go index 61171c384..a97f25bc6 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -37,6 +37,8 @@ func TestModerationsWithDifferentModelOptions(t *testing.T) { getModerationModelTestOption(openai.GPT3Dot5Turbo, openai.ErrModerationInvalidModel), getModerationModelTestOption(openai.ModerationTextStable, nil), getModerationModelTestOption(openai.ModerationTextLatest, nil), + getModerationModelTestOption(openai.ModerationOmni20240926, nil), + getModerationModelTestOption(openai.ModerationOmniLatest, nil), getModerationModelTestOption("", nil), ) client, server, teardown := setupOpenAITestServer() diff --git a/openai_test.go b/openai_test.go index 729d8880c..a55f3a858 100644 --- a/openai_test.go +++ b/openai_test.go @@ -29,9 +29,9 @@ func setupAzureTestServer() (client *openai.Client, server *test.ServerTest, tea // numTokens Returns the number of GPT-3 encoded tokens in the given text. // This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer +// https://beta.openai.com/tokenizer. // -// TODO: implement an actual tokenizer for GPT-3 and Codex (once available) +// TODO: implement an actual tokenizer for GPT-3 and Codex (once available). func numTokens(s string) int { return int(float32(len(s)) / 4) } diff --git a/reasoning_validator.go b/reasoning_validator.go new file mode 100644 index 000000000..1d26ca047 --- /dev/null +++ b/reasoning_validator.go @@ -0,0 +1,82 @@ +package openai + +import ( + "errors" + "strings" +) + +var ( + // Deprecated: use ErrReasoningModelMaxTokensDeprecated instead. + ErrO1MaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") //nolint:lll + ErrCompletionUnsupportedModel = errors.New("this model is not supported with this method, please use CreateChatCompletion client method instead") //nolint:lll + ErrCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateCompletionStream") //nolint:lll + ErrCompletionRequestPromptTypeNotSupported = errors.New("the type of CompletionRequest.Prompt only supports string and []string") //nolint:lll +) + +var ( + ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll + ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll + // Deprecated: use ErrReasoningModelLimitations* instead. + ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +var ( + //nolint:lll + ErrReasoningModelMaxTokensDeprecated = errors.New("this model is not supported MaxTokens, please use MaxCompletionTokens") + ErrReasoningModelLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll + ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll +) + +// ReasoningValidator handles validation for reasoning model requests. +type ReasoningValidator struct{} + +// NewReasoningValidator creates a new validator for reasoning models. +func NewReasoningValidator() *ReasoningValidator { + return &ReasoningValidator{} +} + +// Validate performs all validation checks for reasoning models. +func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error { + o1Series := strings.HasPrefix(request.Model, "o1") + o3Series := strings.HasPrefix(request.Model, "o3") + o4Series := strings.HasPrefix(request.Model, "o4") + gpt5Series := strings.HasPrefix(request.Model, "gpt-5") + + if !o1Series && !o3Series && !o4Series && !gpt5Series { + return nil + } + + if err := v.validateReasoningModelParams(request); err != nil { + return err + } + + return nil +} + +// validateReasoningModelParams checks reasoning model parameters. +func (v *ReasoningValidator) validateReasoningModelParams(request ChatCompletionRequest) error { + if request.MaxTokens > 0 { + return ErrReasoningModelMaxTokensDeprecated + } + if request.LogProbs { + return ErrReasoningModelLimitationsLogprobs + } + if request.Temperature > 0 && request.Temperature != 1 { + return ErrReasoningModelLimitationsOther + } + if request.TopP > 0 && request.TopP != 1 { + return ErrReasoningModelLimitationsOther + } + if request.N > 0 && request.N != 1 { + return ErrReasoningModelLimitationsOther + } + if request.PresencePenalty > 0 { + return ErrReasoningModelLimitationsOther + } + if request.FrequencyPenalty > 0 { + return ErrReasoningModelLimitationsOther + } + + return nil +} diff --git a/run.go b/run.go index 1f3cb7eb7..9c51aaf8d 100644 --- a/run.go +++ b/run.go @@ -28,6 +28,16 @@ type Run struct { Metadata map[string]any `json:"metadata"` Usage Usage `json:"usage,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + httpHeader } @@ -40,6 +50,7 @@ const ( RunStatusCancelling RunStatus = "cancelling" RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" RunStatusExpired RunStatus = "expired" RunStatusCancelled RunStatus = "cancelled" ) @@ -72,12 +83,62 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model string `json:"model,omitempty"` - Instructions string `json:"instructions,omitempty"` - AdditionalInstructions string `json:"additional_instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + + // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. + // lower values are more focused and deterministic. + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + + // The maximum number of prompt tokens that may be used over the course of the run. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. + MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` + + // The maximum number of completion tokens that may be used over the course of the run. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` + + // ThreadTruncationStrategy defines the truncation strategy to use for the thread. + TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` + // Disable the default behavior of parallel tool calls by setting it: false. + ParallelToolCalls any `json:"parallel_tool_calls,omitempty"` +} + +// ThreadTruncationStrategy defines the truncation strategy to use for the thread. +// https://platform.openai.com/docs/assistants/how-it-works/truncation-strategy. +type ThreadTruncationStrategy struct { + // default 'auto'. + Type TruncationStrategy `json:"type,omitempty"` + // this field should be set if the truncation strategy is set to LastMessages. + LastMessages *int `json:"last_messages,omitempty"` +} + +// TruncationStrategy defines the existing truncation strategies existing for thread management in an assistant. +type TruncationStrategy string + +const ( + // TruncationStrategyAuto messages in the middle of the thread will be dropped to fit the context length of the model. + TruncationStrategyAuto = TruncationStrategy("auto") + // TruncationStrategyLastMessages the thread will be truncated to the n most recent messages in the thread. + TruncationStrategyLastMessages = TruncationStrategy("last_messages") +) + +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` } type RunModifyRequest struct { @@ -182,8 +243,7 @@ func (c *Client) CreateRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -203,8 +263,7 @@ func (c *Client) RetrieveRun( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -226,8 +285,7 @@ func (c *Client) ModifyRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -266,8 +324,7 @@ func (c *Client) ListRuns( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -288,8 +345,7 @@ func (c *Client) SubmitToolOutputs( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -308,8 +364,7 @@ func (c *Client) CancelRun( ctx, http.MethodPost, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -328,8 +383,7 @@ func (c *Client) CreateThreadAndRun( http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -350,8 +404,7 @@ func (c *Client) RetrieveRunStep( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -391,8 +444,7 @@ func (c *Client) ListRunSteps( ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1(), - ) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/speech.go b/speech.go index 7e22e755c..60e7694fd 100644 --- a/speech.go +++ b/speech.go @@ -2,27 +2,31 @@ package openai import ( "context" - "errors" "net/http" ) type SpeechModel string const ( - TTSModel1 SpeechModel = "tts-1" - TTSModel1HD SpeechModel = "tts-1-hd" - TTSModelCanary SpeechModel = "canary-tts" + TTSModel1 SpeechModel = "tts-1" + TTSModel1HD SpeechModel = "tts-1-hd" + TTSModelCanary SpeechModel = "canary-tts" + TTSModelGPT4oMini SpeechModel = "gpt-4o-mini-tts" ) type SpeechVoice string const ( VoiceAlloy SpeechVoice = "alloy" + VoiceAsh SpeechVoice = "ash" + VoiceBallad SpeechVoice = "ballad" + VoiceCoral SpeechVoice = "coral" VoiceEcho SpeechVoice = "echo" VoiceFable SpeechVoice = "fable" VoiceOnyx SpeechVoice = "onyx" VoiceNova SpeechVoice = "nova" VoiceShimmer SpeechVoice = "shimmer" + VoiceVerse SpeechVoice = "verse" ) type SpeechResponseFormat string @@ -36,46 +40,20 @@ const ( SpeechResponseFormatPcm SpeechResponseFormat = "pcm" ) -var ( - ErrInvalidSpeechModel = errors.New("invalid speech model") - ErrInvalidVoice = errors.New("invalid voice") -) - type CreateSpeechRequest struct { Model SpeechModel `json:"model"` Input string `json:"input"` Voice SpeechVoice `json:"voice"` + Instructions string `json:"instructions,omitempty"` // Optional, Doesnt work with tts-1 or tts-1-hd. ResponseFormat SpeechResponseFormat `json:"response_format,omitempty"` // Optional, default to mp3 Speed float64 `json:"speed,omitempty"` // Optional, default to 1.0 } -func contains[T comparable](s []T, e T) bool { - for _, v := range s { - if v == e { - return true - } - } - return false -} - -func isValidSpeechModel(model SpeechModel) bool { - return contains([]SpeechModel{TTSModel1, TTSModel1HD, TTSModelCanary}, model) -} - -func isValidVoice(voice SpeechVoice) bool { - return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) -} - func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { - if !isValidSpeechModel(request.Model) { - err = ErrInvalidSpeechModel - return - } - if !isValidVoice(request.Voice) { - err = ErrInvalidVoice - return - } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/audio/speech", string(request.Model)), + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL("/audio/speech", withModel(string(request.Model))), withBody(request), withContentType("application/json"), ) diff --git a/speech_test.go b/speech_test.go index d9ba58b13..67a3feabc 100644 --- a/speech_test.go +++ b/speech_test.go @@ -21,10 +21,8 @@ func TestSpeechIntegration(t *testing.T) { defer teardown() server.RegisterHandler("/v1/audio/speech", func(w http.ResponseWriter, r *http.Request) { - dir, cleanup := test.CreateTestDirectory(t) - path := filepath.Join(dir, "fake.mp3") + path := filepath.Join(t.TempDir(), "fake.mp3") test.CreateTestFile(t, path) - defer cleanup() // audio endpoints only accept POST requests if r.Method != "POST" { @@ -95,21 +93,4 @@ func TestSpeechIntegration(t *testing.T) { err = os.WriteFile("test.mp3", buf, 0644) checks.NoError(t, err, "Create error") }) - t.Run("invalid model", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: "invalid_model", - Input: "Hello!", - Voice: openai.VoiceAlloy, - }) - checks.ErrorIs(t, err, openai.ErrInvalidSpeechModel, "CreateSpeech error") - }) - - t.Run("invalid voice", func(t *testing.T) { - _, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ - Model: openai.TTSModel1, - Input: "Hello!", - Voice: "invalid_voice", - }) - checks.ErrorIs(t, err, openai.ErrInvalidVoice, "CreateSpeech error") - }) } diff --git a/stream.go b/stream.go index b277f3c29..a61c7c970 100644 --- a/stream.go +++ b/stream.go @@ -3,6 +3,7 @@ package openai import ( "context" "errors" + "net/http" ) var ( @@ -33,7 +34,12 @@ func (c *Client) CreateCompletionStream( } request.Stream = true - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix, withModel(request.Model)), + withBody(request), + ) if err != nil { return nil, err } diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..6faefe0a7 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -6,13 +6,14 @@ import ( "fmt" "io" "net/http" + "regexp" utils "github.com/sashabaranov/go-openai/internal" ) var ( - headerData = []byte("data: ") - errorPrefix = []byte(`data: {"error":`) + headerData = regexp.MustCompile(`^data:\s*`) + errorPrefix = regexp.MustCompile(`^data:\s*{"error":`) ) type streamable interface { @@ -32,17 +33,28 @@ type streamReader[T streamable] struct { } func (stream *streamReader[T]) Recv() (response T, err error) { - if stream.isFinished { - err = io.EOF + rawLine, err := stream.RecvRaw() + if err != nil { return } - response, err = stream.processLines() - return + err = stream.unmarshaler.Unmarshal(rawLine, &response) + if err != nil { + return + } + return response, nil +} + +func (stream *streamReader[T]) RecvRaw() ([]byte, error) { + if stream.isFinished { + return nil, io.EOF + } + + return stream.processLines() } //nolint:gocognit -func (stream *streamReader[T]) processLines() (T, error) { +func (stream *streamReader[T]) processLines() ([]byte, error) { var ( emptyMessagesCount uint hasErrorPrefix bool @@ -53,44 +65,38 @@ func (stream *streamReader[T]) processLines() (T, error) { if readErr != nil || hasErrorPrefix { respErr := stream.unmarshalError() if respErr != nil { - return *new(T), fmt.Errorf("error, %w", respErr.Error) + return nil, fmt.Errorf("error, %w", respErr.Error) } - return *new(T), readErr + return nil, readErr } noSpaceLine := bytes.TrimSpace(rawLine) - if bytes.HasPrefix(noSpaceLine, errorPrefix) { + if errorPrefix.Match(noSpaceLine) { hasErrorPrefix = true } - if !bytes.HasPrefix(noSpaceLine, headerData) || hasErrorPrefix { + if !headerData.Match(noSpaceLine) || hasErrorPrefix { if hasErrorPrefix { - noSpaceLine = bytes.TrimPrefix(noSpaceLine, headerData) + noSpaceLine = headerData.ReplaceAll(noSpaceLine, nil) } writeErr := stream.errAccumulator.Write(noSpaceLine) if writeErr != nil { - return *new(T), writeErr + return nil, writeErr } emptyMessagesCount++ if emptyMessagesCount > stream.emptyMessagesLimit { - return *new(T), ErrTooManyEmptyStreamMessages + return nil, ErrTooManyEmptyStreamMessages } continue } - noPrefixLine := bytes.TrimPrefix(noSpaceLine, headerData) + noPrefixLine := headerData.ReplaceAll(noSpaceLine, nil) if string(noPrefixLine) == "[DONE]" { stream.isFinished = true - return *new(T), io.EOF - } - - var response T - unmarshalErr := stream.unmarshaler.Unmarshal(noPrefixLine, &response) - if unmarshalErr != nil { - return *new(T), unmarshalErr + return nil, io.EOF } - return response, nil + return noPrefixLine, nil } } diff --git a/stream_reader_test.go b/stream_reader_test.go index cd6e46eff..449a14b43 100644 --- a/stream_reader_test.go +++ b/stream_reader_test.go @@ -63,3 +63,16 @@ func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) { _, err := stream.Recv() checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Did not return error when write failed", err.Error()) } + +func TestStreamReaderRecvRaw(t *testing.T) { + stream := &streamReader[ChatCompletionStreamResponse]{ + reader: bufio.NewReader(bytes.NewReader([]byte("data: {\"key\": \"value\"}\n"))), + } + rawLine, err := stream.RecvRaw() + if err != nil { + t.Fatalf("Did not return raw line: %v", err) + } + if !bytes.Equal(rawLine, []byte("{\"key\": \"value\"}")) { + t.Fatalf("Did not return raw line: %v", string(rawLine)) + } +} diff --git a/stream_test.go b/stream_test.go index 2822a3535..9dd95bb5f 100644 --- a/stream_test.go +++ b/stream_test.go @@ -169,7 +169,7 @@ func TestCreateCompletionStreamRateLimitError(t *testing.T) { var apiErr *openai.APIError _, err := client.CreateCompletionStream(context.Background(), openai.CompletionRequest{ MaxTokens: 5, - Model: openai.GPT3Ada, + Model: openai.GPT3Babbage002, Prompt: "Hello!", Stream: true, }) diff --git a/thread.go b/thread.go index 291f3dcab..bc08e2bcb 100644 --- a/thread.go +++ b/thread.go @@ -10,34 +10,98 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` } +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string const ( - ThreadMessageRoleUser ThreadMessageRole = "user" + ThreadMessageRoleAssistant ThreadMessageRole = "assistant" + ThreadMessageRoleUser ThreadMessageRole = "user" ) type ThreadMessage struct { - Role ThreadMessageRole `json:"role"` - Content string `json:"content"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Role ThreadMessageRole `json:"role"` + Content string `json:"content"` + FileIDs []string `json:"file_ids,omitempty"` + Attachments []ThreadAttachment `json:"attachments,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ThreadAttachment struct { + FileID string `json:"file_id"` + Tools []ThreadAttachmentTool `json:"tools"` +} + +type ThreadAttachmentTool struct { + Type string `json:"type"` } type ThreadDeleteResponse struct { @@ -51,7 +115,7 @@ type ThreadDeleteResponse struct { // CreateThread creates a new thread. func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -64,7 +128,7 @@ func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (respo func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -81,7 +145,7 @@ func (c *Client) ModifyThread( ) (response Thread, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } @@ -97,7 +161,7 @@ func (c *Client) DeleteThread( ) (response ThreadDeleteResponse, err error) { urlSuffix := threadsSuffix + "/" + threadID req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), - withBetaAssistantV1()) + withBetaAssistantVersion(c.config.AssistantVersion)) if err != nil { return } diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 000000000..682bb1cf9 --- /dev/null +++ b/vector_store.go @@ -0,0 +1,348 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. +type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + FirstID *string `json:"first_id"` + LastID *string `json:"last_id"` + HasMore bool `json:"has_more"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. +func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) { + req, _ := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(vectorStoresSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStore retrieves an vector store. +func (c *Client) RetrieveVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ModifyVectorStore modifies a vector store. +func (c *Client) ModifyVectorStore( + ctx context.Context, + vectorStoreID string, + request VectorStoreRequest, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStore deletes an vector store. +func (c *Client) DeleteVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStoreDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStores Lists the currently available vector store. +func (c *Client) ListVectorStores( + ctx context.Context, + pagination Pagination, +) (response VectorStoresList, err error) { + urlValues := url.Values{} + + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. +func (c *Client) RetrieveVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CancelVectorStoreFileBatch cancel a new vector store file batch. +func (c *Client) CancelVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel") + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFilesInBatch( + ctx context.Context, + vectorStoreID string, + batchID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} diff --git a/vector_store_test.go b/vector_store_test.go new file mode 100644 index 000000000..58b9a857e --- /dev/null +++ b/vector_store_test.go @@ -0,0 +1,349 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestVectorStore Tests the vector store endpoint of the API using the mocked server. +func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + "object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +}