diff --git a/.codecov.yml b/.codecov.yml
new file mode 100644
index 000000000..81773666c
--- /dev/null
+++ b/.codecov.yml
@@ -0,0 +1,4 @@
+coverage:
+ ignore:
+ - "examples/**"
+ - "internal/test/**"
diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml
index f4cbe7c8b..2c9730656 100644
--- a/.github/workflows/pr.yml
+++ b/.github/workflows/pr.yml
@@ -16,12 +16,14 @@ jobs:
go-version: '1.24'
- name: Run vet
run: |
- go vet .
+ go vet -stdversion ./...
- name: Run golangci-lint
- uses: golangci/golangci-lint-action@v6
+ uses: golangci/golangci-lint-action@v7
with:
- version: v1.64.5
+ version: v2.1.5
- name: Run tests
run: go test -race -covermode=atomic -coverprofile=coverage.out -v ./...
- name: Upload coverage reports to Codecov
- uses: codecov/codecov-action@v4
+ uses: codecov/codecov-action@v5
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
diff --git a/.golangci.yml b/.golangci.yml
index a5988825b..6391ad76f 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,258 +1,168 @@
-## Golden config for golangci-lint v1.47.3
-#
-# This is the best config for golangci-lint based on my experience and opinion.
-# It is very strict, but not extremely strict.
-# Feel free to adopt and change it for your needs.
-
-run:
- # Timeout for analysis, e.g. 30s, 5m.
- # Default: 1m
- timeout: 3m
-
-
-# This file contains only configs which differ from defaults.
-# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
-linters-settings:
- cyclop:
- # The maximal code complexity to report.
- # Default: 10
- max-complexity: 30
- # The maximal average package complexity.
- # If it's higher than 0.0 (float) the check is enabled
- # Default: 0.0
- package-average: 10.0
-
- errcheck:
- # Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
- # Such cases aren't reported by default.
- # Default: false
- check-type-assertions: true
-
- funlen:
- # Checks the number of lines in a function.
- # If lower than 0, disable the check.
- # Default: 60
- lines: 100
- # Checks the number of statements in a function.
- # If lower than 0, disable the check.
- # Default: 40
- statements: 50
-
- gocognit:
- # Minimal code complexity to report
- # Default: 30 (but we recommend 10-20)
- min-complexity: 20
-
- gocritic:
- # Settings passed to gocritic.
- # The settings key is the name of a supported gocritic checker.
- # The list of supported checkers can be find in https://go-critic.github.io/overview.
- settings:
- captLocal:
- # Whether to restrict checker to params only.
- # Default: true
- paramsOnly: false
- underef:
- # Whether to skip (*x).method() calls where x is a pointer receiver.
- # Default: true
- skipRecvDeref: false
-
- mnd:
- # List of function patterns to exclude from analysis.
- # Values always ignored: `time.Date`
- # Default: []
- ignored-functions:
- - os.Chmod
- - os.Mkdir
- - os.MkdirAll
- - os.OpenFile
- - os.WriteFile
- - prometheus.ExponentialBuckets
- - prometheus.ExponentialBucketsRange
- - prometheus.LinearBuckets
- - strconv.FormatFloat
- - strconv.FormatInt
- - strconv.FormatUint
- - strconv.ParseFloat
- - strconv.ParseInt
- - strconv.ParseUint
-
- gomodguard:
- blocked:
- # List of blocked modules.
- # Default: []
- modules:
- - github.com/golang/protobuf:
- recommendations:
- - google.golang.org/protobuf
- reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules"
- - github.com/satori/go.uuid:
- recommendations:
- - github.com/google/uuid
- reason: "satori's package is not maintained"
- - github.com/gofrs/uuid:
- recommendations:
- - github.com/google/uuid
- reason: "see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw"
-
- govet:
- # Enable all analyzers.
- # Default: false
- enable-all: true
- # Disable analyzers by name.
- # Run `go tool vet help` to see all analyzers.
- # Default: []
- disable:
- - fieldalignment # too strict
- # Settings per analyzer.
- settings:
- shadow:
- # Whether to be strict about shadowing; can be noisy.
- # Default: false
- strict: true
-
- nakedret:
- # Make an issue if func has more lines of code than this setting, and it has naked returns.
- # Default: 30
- max-func-lines: 0
-
- nolintlint:
- # Exclude following linters from requiring an explanation.
- # Default: []
- allow-no-explanation: [ funlen, gocognit, lll ]
- # Enable to require an explanation of nonzero length after each nolint directive.
- # Default: false
- require-explanation: true
- # Enable to require nolint directives to mention the specific linter being suppressed.
- # Default: false
- require-specific: true
-
- rowserrcheck:
- # database/sql is always checked
- # Default: []
- packages:
- - github.com/jmoiron/sqlx
-
- tenv:
- # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures.
- # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked.
- # Default: false
- all: true
-
-
+version: "2"
linters:
- disable-all: true
+ default: none
enable:
- ## enabled by default
- - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases
- - gosimple # Linter for Go source code that specializes in simplifying a code
- - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- - ineffassign # Detects when assignments to existing variables are not used
- - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks
- - unused # Checks Go code for unused constants, variables, functions and types
- ## disabled by default
- # - asasalint # Check for pass []any as any in variadic func(...any)
- - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers
- - bidichk # Checks for dangerous unicode character sequences
- - bodyclose # checks whether HTTP response body is closed successfully
- - contextcheck # check the function whether use a non-inherited context
- - cyclop # checks function and package cyclomatic complexity
- - dupl # Tool for code clone detection
- - durationcheck # check for two durations multiplied together
- - errname # Checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error.
- - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
- - exhaustive # check exhaustiveness of enum switch statements
- - forbidigo # Forbids identifiers
- - funlen # Tool for detection of long functions
- # - gochecknoglobals # check that no global variables exist
- - gochecknoinits # Checks that no init functions are present in Go code
- - gocognit # Computes and checks the cognitive complexity of functions
- - goconst # Finds repeated strings that could be replaced by a constant
- - gocritic # Provides diagnostics that check for bugs, performance and style issues.
- - gocyclo # Computes and checks the cyclomatic complexity of functions
- - godot # Check if comments end in a period
- - goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod.
- - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations.
- - goprintffuncname # Checks that printf-like functions are named with f at the end
- - gosec # Inspects source code for security problems
- - lll # Reports long lines
- - makezero # Finds slice declarations with non-zero initial length
- # - nakedret # Finds naked returns in functions greater than a specified function length
- - mnd # An analyzer to detect magic numbers.
- - nestif # Reports deeply nested if statements
- - nilerr # Finds the code that returns nil even if it checks that the error is not nil.
- - nilnil # Checks that there is no simultaneous return of nil error and an invalid value.
- # - noctx # noctx finds sending http request without context.Context
- - nolintlint # Reports ill-formed or insufficient nolint directives
- # - nonamedreturns # Reports all named returns
- - nosprintfhostport # Checks for misuse of Sprintf to construct a host with port in a URL.
- - predeclared # find code that shadows one of Go's predeclared identifiers
- - promlinter # Check Prometheus metrics naming via promlint
- - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- - rowserrcheck # checks whether Err of rows is checked successfully
- - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed.
- - stylecheck # Stylecheck is a replacement for golint
- - testpackage # linter that makes you use a separate _test package
- - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes
- - unconvert # Remove unnecessary type conversions
- - unparam # Reports unused function parameters
- - usetesting # Reports uses of functions with replacement inside the testing package
- - wastedassign # wastedassign finds wasted assignment statements.
- - whitespace # Tool for detection of leading and trailing whitespace
- ## you may want to enable
- #- decorder # check declaration order and count of types, constants, variables and functions
- #- exhaustruct # Checks if all structure fields are initialized
- #- goheader # Checks is file header matches to pattern
- #- ireturn # Accept Interfaces, Return Concrete Types
- #- prealloc # [premature optimization, but can be used in some cases] Finds slice declarations that could potentially be preallocated
- #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
- #- wrapcheck # Checks that errors returned from external packages are wrapped
- ## disabled
- #- containedctx # containedctx is a linter that detects struct contained context.Context field
- #- depguard # [replaced by gomodguard] Go linter that checks if package imports are in a list of acceptable packages
- #- dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
- #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted.
- #- forcetypeassert # [replaced by errcheck] finds forced type assertions
- #- gci # Gci controls golang package import order and makes it always deterministic.
- #- godox # Tool for detection of FIXME, TODO and other comment keywords
- #- goerr113 # [too strict] Golang linter to check the errors handling expressions
- #- gofmt # [replaced by goimports] Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- #- gofumpt # [replaced by goimports, gofumports is not available yet] Gofumpt checks whether code was gofumpt-ed.
- #- grouper # An analyzer to analyze expression groups.
- #- ifshort # Checks that your code uses short syntax for if-statements whenever possible
- #- importas # Enforces consistent import aliases
- #- maintidx # maintidx measures the maintainability index of each function.
- #- misspell # [useless] Finds commonly misspelled English words in comments
- #- nlreturn # [too strict and mostly code is not more readable] nlreturn checks for a new line before return and branch statements to increase code clarity
- #- nosnakecase # Detects snake case of variable naming and function name. # TODO: maybe enable after https://github.com/sivchari/nosnakecase/issues/14
- #- paralleltest # [too many false positives] paralleltest detects missing usage of t.Parallel() method in your Go test
- #- tagliatelle # Checks the struct tags.
- #- thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers
- #- wsl # [too strict and mostly code is not more readable] Whitespace Linter - Forces you to use empty lines!
-
-
+ - asciicheck
+ - bidichk
+ - bodyclose
+ - contextcheck
+ - cyclop
+ - dupl
+ - durationcheck
+ - errcheck
+ - errname
+ - errorlint
+ - exhaustive
+ - forbidigo
+ - funlen
+ - gochecknoinits
+ - gocognit
+ - goconst
+ - gocritic
+ - gocyclo
+ - godot
+ - gomoddirectives
+ - gomodguard
+ - goprintffuncname
+ - gosec
+ - govet
+ - ineffassign
+ - lll
+ - makezero
+ - mnd
+ - nestif
+ - nilerr
+ - nilnil
+ - nolintlint
+ - nosprintfhostport
+ - predeclared
+ - promlinter
+ - revive
+ - rowserrcheck
+ - sqlclosecheck
+ - staticcheck
+ - testpackage
+ - tparallel
+ - unconvert
+ - unparam
+ - unused
+ - usetesting
+ - wastedassign
+ - whitespace
+ settings:
+ cyclop:
+ max-complexity: 30
+ package-average: 10
+ errcheck:
+ check-type-assertions: true
+ funlen:
+ lines: 100
+ statements: 50
+ gocognit:
+ min-complexity: 20
+ gocritic:
+ settings:
+ captLocal:
+ paramsOnly: false
+ underef:
+ skipRecvDeref: false
+ gomodguard:
+ blocked:
+ modules:
+ - github.com/golang/protobuf:
+ recommendations:
+ - google.golang.org/protobuf
+ reason: see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
+ - github.com/satori/go.uuid:
+ recommendations:
+ - github.com/google/uuid
+ reason: satori's package is not maintained
+ - github.com/gofrs/uuid:
+ recommendations:
+ - github.com/google/uuid
+ reason: 'see recommendation from dev-infra team: https://confluence.gtforge.com/x/gQI6Aw'
+ govet:
+ disable:
+ - fieldalignment
+ enable-all: true
+ settings:
+ shadow:
+ strict: true
+ mnd:
+ ignored-functions:
+ - os.Chmod
+ - os.Mkdir
+ - os.MkdirAll
+ - os.OpenFile
+ - os.WriteFile
+ - prometheus.ExponentialBuckets
+ - prometheus.ExponentialBucketsRange
+ - prometheus.LinearBuckets
+ - strconv.FormatFloat
+ - strconv.FormatInt
+ - strconv.FormatUint
+ - strconv.ParseFloat
+ - strconv.ParseInt
+ - strconv.ParseUint
+ nakedret:
+ max-func-lines: 0
+ nolintlint:
+ require-explanation: true
+ require-specific: true
+ allow-no-explanation:
+ - funlen
+ - gocognit
+ - lll
+ rowserrcheck:
+ packages:
+ - github.com/jmoiron/sqlx
+ exclusions:
+ generated: lax
+ presets:
+ - comments
+ - common-false-positives
+ - legacy
+ - std-error-handling
+ rules:
+ - linters:
+ - forbidigo
+ - mnd
+ - revive
+ path : ^examples/.*\.go$
+ - linters:
+ - lll
+ source: ^//\s*go:generate\s
+ - linters:
+ - godot
+ source: (noinspection|TODO)
+ - linters:
+ - gocritic
+ source: //noinspection
+ - linters:
+ - errorlint
+ source: ^\s+if _, ok := err\.\([^.]+\.InternalError\); ok {
+ - linters:
+ - bodyclose
+ - dupl
+ - funlen
+ - goconst
+ - gosec
+ - noctx
+ - wrapcheck
+ - staticcheck
+ path: _test\.go
+ paths:
+ - third_party$
+ - builtin$
+ - examples$
issues:
- # Maximum count of issues with the same text.
- # Set to 0 to disable.
- # Default: 3
max-same-issues: 50
-
- exclude-rules:
- - source: "^//\\s*go:generate\\s"
- linters: [ lll ]
- - source: "(noinspection|TODO)"
- linters: [ godot ]
- - source: "//noinspection"
- linters: [ gocritic ]
- - source: "^\\s+if _, ok := err\\.\\([^.]+\\.InternalError\\); ok {"
- linters: [ errorlint ]
- - path: "_test\\.go"
- linters:
- - bodyclose
- - dupl
- - funlen
- - goconst
- - gosec
- - noctx
- - wrapcheck
+formatters:
+ enable:
+ - goimports
+ exclusions:
+ generated: lax
+ paths:
+ - third_party$
+ - builtin$
+ - examples$
diff --git a/README.md b/README.md
index 57d1d35bf..77b85e519 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@ This library provides unofficial Go clients for [OpenAI API](https://platform.op
* ChatGPT 4o, o1
* GPT-3, GPT-4
-* DALL·E 2, DALL·E 3
+* DALL·E 2, DALL·E 3, GPT Image 1
* Whisper
## Installation
@@ -357,6 +357,66 @@ func main() {
```
+
+GPT Image 1 image generation
+
+```go
+package main
+
+import (
+ "context"
+ "encoding/base64"
+ "fmt"
+ "os"
+
+ openai "github.com/sashabaranov/go-openai"
+)
+
+func main() {
+ c := openai.NewClient("your token")
+ ctx := context.Background()
+
+ req := openai.ImageRequest{
+ Prompt: "Parrot on a skateboard performing a trick. Large bold text \"SKATE MASTER\" banner at the bottom of the image. Cartoon style, natural light, high detail, 1:1 aspect ratio.",
+ Background: openai.CreateImageBackgroundOpaque,
+ Model: openai.CreateImageModelGptImage1,
+ Size: openai.CreateImageSize1024x1024,
+ N: 1,
+ Quality: openai.CreateImageQualityLow,
+ OutputCompression: 100,
+ OutputFormat: openai.CreateImageOutputFormatJPEG,
+ // Moderation: openai.CreateImageModerationLow,
+ // User: "",
+ }
+
+ resp, err := c.CreateImage(ctx, req)
+ if err != nil {
+ fmt.Printf("Image creation Image generation with GPT Image 1error: %v\n", err)
+ return
+ }
+
+ fmt.Println("Image Base64:", resp.Data[0].B64JSON)
+
+ // Decode the base64 data
+ imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
+ if err != nil {
+ fmt.Printf("Base64 decode error: %v\n", err)
+ return
+ }
+
+ // Write image to file
+ outputPath := "generated_image.jpg"
+ err = os.WriteFile(outputPath, imgBytes, 0644)
+ if err != nil {
+ fmt.Printf("Failed to write image file: %v\n", err)
+ return
+ }
+
+ fmt.Printf("The image was saved as %s\n", outputPath)
+}
+```
+
+
Configuring proxy
diff --git a/audio_test.go b/audio_test.go
index 9f32d5468..51b3f465d 100644
--- a/audio_test.go
+++ b/audio_test.go
@@ -2,12 +2,16 @@ package openai //nolint:testpackage // testing private field
import (
"bytes"
+ "context"
+ "errors"
"fmt"
"io"
+ "net/http"
"os"
"path/filepath"
"testing"
+ utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
@@ -107,3 +111,131 @@ func TestCreateFileField(t *testing.T) {
checks.HasError(t, err, "createFileField using file should return error when open file fails")
})
}
+
+// failingFormBuilder always returns an error when creating form files.
+type failingFormBuilder struct{ err error }
+
+func (f *failingFormBuilder) CreateFormFile(_ string, _ *os.File) error {
+ return f.err
+}
+
+func (f *failingFormBuilder) CreateFormFileReader(_ string, _ io.Reader, _ string) error {
+ return f.err
+}
+
+func (f *failingFormBuilder) WriteField(_, _ string) error {
+ return nil
+}
+
+func (f *failingFormBuilder) Close() error {
+ return nil
+}
+
+func (f *failingFormBuilder) FormDataContentType() string {
+ return "multipart/form-data"
+}
+
+// failingAudioRequestBuilder simulates an error during HTTP request construction.
+type failingAudioRequestBuilder struct{ err error }
+
+func (f *failingAudioRequestBuilder) Build(
+ _ context.Context,
+ _, _ string,
+ _ any,
+ _ http.Header,
+) (*http.Request, error) {
+ return nil, f.err
+}
+
+// errorHTTPClient always returns an error when making HTTP calls.
+type errorHTTPClient struct{ err error }
+
+func (e *errorHTTPClient) Do(_ *http.Request) (*http.Response, error) {
+ return nil, e.err
+}
+
+func TestCallAudioAPIMultipartFormError(t *testing.T) {
+ client := NewClient("test-token")
+ errForm := errors.New("mock create form file failure")
+ // Override form builder to force an error during multipart form creation.
+ client.createFormBuilder = func(_ io.Writer) utils.FormBuilder {
+ return &failingFormBuilder{err: errForm}
+ }
+
+ // Provide a reader so createFileField uses the reader path (no file open).
+ req := AudioRequest{FilePath: "fake.mp3", Reader: bytes.NewBuffer([]byte("dummy")), Model: Whisper1}
+ _, err := client.callAudioAPI(context.Background(), req, "transcriptions")
+ if err == nil {
+ t.Fatal("expected error but got none")
+ }
+ if !errors.Is(err, errForm) {
+ t.Errorf("expected error %v, got %v", errForm, err)
+ }
+}
+
+func TestCallAudioAPINewRequestError(t *testing.T) {
+ client := NewClient("test-token")
+ // Create a real temp file so multipart form succeeds.
+ tmp := t.TempDir()
+ path := filepath.Join(tmp, "file.mp3")
+ if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
+ t.Fatalf("failed to write temp file: %v", err)
+ }
+
+ errBuild := errors.New("mock build failure")
+ client.requestBuilder = &failingAudioRequestBuilder{err: errBuild}
+
+ req := AudioRequest{FilePath: path, Model: Whisper1}
+ _, err := client.callAudioAPI(context.Background(), req, "translations")
+ if err == nil {
+ t.Fatal("expected error but got none")
+ }
+ if !errors.Is(err, errBuild) {
+ t.Errorf("expected error %v, got %v", errBuild, err)
+ }
+}
+
+func TestCallAudioAPISendRequestErrorJSON(t *testing.T) {
+ client := NewClient("test-token")
+ // Create a real temp file so multipart form succeeds.
+ tmp := t.TempDir()
+ path := filepath.Join(tmp, "file.mp3")
+ if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
+ t.Fatalf("failed to write temp file: %v", err)
+ }
+
+ errHTTP := errors.New("mock HTTPClient failure")
+ // Override HTTP client to simulate a network error.
+ client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
+
+ req := AudioRequest{FilePath: path, Model: Whisper1}
+ _, err := client.callAudioAPI(context.Background(), req, "transcriptions")
+ if err == nil {
+ t.Fatal("expected error but got none")
+ }
+ if !errors.Is(err, errHTTP) {
+ t.Errorf("expected error %v, got %v", errHTTP, err)
+ }
+}
+
+func TestCallAudioAPISendRequestErrorText(t *testing.T) {
+ client := NewClient("test-token")
+ tmp := t.TempDir()
+ path := filepath.Join(tmp, "file.mp3")
+ if err := os.WriteFile(path, []byte("content"), 0644); err != nil {
+ t.Fatalf("failed to write temp file: %v", err)
+ }
+
+ errHTTP := errors.New("mock HTTPClient failure")
+ client.config.HTTPClient = &errorHTTPClient{err: errHTTP}
+
+ // Use a non-JSON response format to exercise the text path.
+ req := AudioRequest{FilePath: path, Model: Whisper1, Format: AudioResponseFormatText}
+ _, err := client.callAudioAPI(context.Background(), req, "translations")
+ if err == nil {
+ t.Fatal("expected error but got none")
+ }
+ if !errors.Is(err, errHTTP) {
+ t.Errorf("expected error %v, got %v", errHTTP, err)
+ }
+}
diff --git a/chat.go b/chat.go
index 0f91d481c..0aa018715 100644
--- a/chat.go
+++ b/chat.go
@@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"net/http"
+
+ "github.com/sashabaranov/go-openai/jsonschema"
)
// Chat message role defined by the OpenAI API.
@@ -221,13 +223,49 @@ type ChatCompletionResponseFormatJSONSchema struct {
Strict bool `json:"strict"`
}
+func (r *ChatCompletionResponseFormatJSONSchema) UnmarshalJSON(data []byte) error {
+ type rawJSONSchema struct {
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ Schema json.RawMessage `json:"schema"`
+ Strict bool `json:"strict"`
+ }
+ var raw rawJSONSchema
+ if err := json.Unmarshal(data, &raw); err != nil {
+ return err
+ }
+ r.Name = raw.Name
+ r.Description = raw.Description
+ r.Strict = raw.Strict
+ if len(raw.Schema) > 0 && string(raw.Schema) != "null" {
+ var d jsonschema.Definition
+ err := json.Unmarshal(raw.Schema, &d)
+ if err != nil {
+ return err
+ }
+ r.Schema = &d
+ }
+ return nil
+}
+
+// ChatCompletionRequestExtensions contains third-party OpenAI API extensions
+// (e.g., vendor-specific implementations like vLLM).
+type ChatCompletionRequestExtensions struct {
+ // GuidedChoice is a vLLM-specific extension that restricts the model's output
+ // to one of the predefined string choices provided in this field. This feature
+ // is used to constrain the model's responses to a controlled set of options,
+ // ensuring predictable and consistent outputs in scenarios where specific
+ // choices are required.
+ GuidedChoice []string `json:"guided_choice,omitempty"`
+}
+
// ChatCompletionRequest represents a request structure for chat completion API.
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatCompletionMessage `json:"messages"`
// MaxTokens The maximum number of tokens that can be generated in the chat completion.
// This value can be used to control costs for text generated via API.
- // This value is now deprecated in favor of max_completion_tokens, and is not compatible with o1 series models.
+ // Deprecated: use MaxCompletionTokens. Not compatible with o1-series models.
// refs: https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
MaxTokens int `json:"max_tokens,omitempty"`
// MaxCompletionTokens An upper bound for the number of tokens that can be generated for a completion,
@@ -275,6 +313,26 @@ type ChatCompletionRequest struct {
Metadata map[string]string `json:"metadata,omitempty"`
// Configuration for a predicted output.
Prediction *Prediction `json:"prediction,omitempty"`
+ // ChatTemplateKwargs provides a way to add non-standard parameters to the request body.
+ // Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
+ // Such as think mode for qwen3. "chat_template_kwargs": {"enable_thinking": false}
+ // https://qwen.readthedocs.io/en/latest/deployment/vllm.html#thinking-non-thinking-modes
+ ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"`
+ // Specifies the latency tier to use for processing the request.
+ ServiceTier ServiceTier `json:"service_tier,omitempty"`
+ // Verbosity determines how many output tokens are generated. Lowering the number of
+ // tokens reduces overall latency. It can be set to "low", "medium", or "high".
+ // Note: This field is only confirmed to work with gpt-5, gpt-5-mini and gpt-5-nano.
+ // Also, it is not in the API reference of chat completion at the time of writing,
+ // though it is supported by the API.
+ Verbosity string `json:"verbosity,omitempty"`
+ // A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies.
+ // The IDs should be a string that uniquely identifies each user.
+ // We recommend hashing their username or email address, in order to avoid sending us any identifying information.
+ // https://platform.openai.com/docs/api-reference/chat/create#chat_create-safety_identifier
+ SafetyIdentifier string `json:"safety_identifier,omitempty"`
+ // Embedded struct for non-OpenAI extensions
+ ChatCompletionRequestExtensions
}
type StreamOptions struct {
@@ -358,6 +416,15 @@ const (
FinishReasonNull FinishReason = "null"
)
+type ServiceTier string
+
+const (
+ ServiceTierAuto ServiceTier = "auto"
+ ServiceTierDefault ServiceTier = "default"
+ ServiceTierFlex ServiceTier = "flex"
+ ServiceTierPriority ServiceTier = "priority"
+)
+
func (r FinishReason) MarshalJSON() ([]byte, error) {
if r == FinishReasonNull || r == "" {
return []byte("null"), nil
@@ -390,6 +457,7 @@ type ChatCompletionResponse struct {
Usage Usage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
PromptFilterResults []PromptFilterResult `json:"prompt_filter_results,omitempty"`
+ ServiceTier ServiceTier `json:"service_tier,omitempty"`
httpHeader
}
diff --git a/chat_test.go b/chat_test.go
index 514706c96..236cff736 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -331,6 +331,126 @@ func TestO3ModelsChatCompletionsBetaLimitations(t *testing.T) {
}
}
+func TestGPT5ModelsChatCompletionsBetaLimitations(t *testing.T) {
+ tests := []struct {
+ name string
+ in openai.ChatCompletionRequest
+ expectedError error
+ }{
+ {
+ name: "log_probs_unsupported",
+ in: openai.ChatCompletionRequest{
+ MaxCompletionTokens: 1000,
+ LogProbs: true,
+ Model: openai.GPT5,
+ },
+ expectedError: openai.ErrReasoningModelLimitationsLogprobs,
+ },
+ {
+ name: "set_temperature_unsupported",
+ in: openai.ChatCompletionRequest{
+ MaxCompletionTokens: 1000,
+ Model: openai.GPT5Mini,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ },
+ {
+ Role: openai.ChatMessageRoleAssistant,
+ },
+ },
+ Temperature: float32(2),
+ },
+ expectedError: openai.ErrReasoningModelLimitationsOther,
+ },
+ {
+ name: "set_top_unsupported",
+ in: openai.ChatCompletionRequest{
+ MaxCompletionTokens: 1000,
+ Model: openai.GPT5Nano,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ },
+ {
+ Role: openai.ChatMessageRoleAssistant,
+ },
+ },
+ Temperature: float32(1),
+ TopP: float32(0.1),
+ },
+ expectedError: openai.ErrReasoningModelLimitationsOther,
+ },
+ {
+ name: "set_n_unsupported",
+ in: openai.ChatCompletionRequest{
+ MaxCompletionTokens: 1000,
+ Model: openai.GPT5ChatLatest,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ },
+ {
+ Role: openai.ChatMessageRoleAssistant,
+ },
+ },
+ Temperature: float32(1),
+ TopP: float32(1),
+ N: 2,
+ },
+ expectedError: openai.ErrReasoningModelLimitationsOther,
+ },
+ {
+ name: "set_presence_penalty_unsupported",
+ in: openai.ChatCompletionRequest{
+ MaxCompletionTokens: 1000,
+ Model: openai.GPT5,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ },
+ {
+ Role: openai.ChatMessageRoleAssistant,
+ },
+ },
+ PresencePenalty: float32(0.1),
+ },
+ expectedError: openai.ErrReasoningModelLimitationsOther,
+ },
+ {
+ name: "set_frequency_penalty_unsupported",
+ in: openai.ChatCompletionRequest{
+ MaxCompletionTokens: 1000,
+ Model: openai.GPT5Mini,
+ Messages: []openai.ChatCompletionMessage{
+ {
+ Role: openai.ChatMessageRoleUser,
+ },
+ {
+ Role: openai.ChatMessageRoleAssistant,
+ },
+ },
+ FrequencyPenalty: float32(0.1),
+ },
+ expectedError: openai.ErrReasoningModelLimitationsOther,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ config := openai.DefaultConfig("whatever")
+ config.BaseURL = "http://localhost/v1"
+ client := openai.NewClientWithConfig(config)
+ ctx := context.Background()
+
+ _, err := client.CreateChatCompletion(ctx, tt.in)
+ checks.HasError(t, err)
+ msg := fmt.Sprintf("CreateChatCompletion should return wrong model error, returned: %s", err)
+ checks.ErrorIs(t, err, tt.expectedError, msg)
+ })
+ }
+}
+
func TestChatRequestOmitEmpty(t *testing.T) {
data, err := json.Marshal(openai.ChatCompletionRequest{
// We set model b/c it's required, so omitempty doesn't make sense
@@ -946,3 +1066,142 @@ func TestFinishReason(t *testing.T) {
}
}
}
+
+func TestChatCompletionResponseFormatJSONSchema_UnmarshalJSON(t *testing.T) {
+ type args struct {
+ data []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ "",
+ args{
+ data: []byte(`{
+ "name": "math_response",
+ "strict": true,
+ "schema": {
+ "type": "object",
+ "properties": {
+ "steps": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "explanation": { "type": "string" },
+ "output": { "type": "string" }
+ },
+ "required": ["explanation","output"],
+ "additionalProperties": false
+ }
+ },
+ "final_answer": { "type": "string" }
+ },
+ "required": ["steps","final_answer"],
+ "additionalProperties": false
+ }
+ }`),
+ },
+ false,
+ },
+ {
+ "",
+ args{
+ data: []byte(`{
+ "name": "math_response",
+ "strict": true,
+ "schema": null
+ }`),
+ },
+ false,
+ },
+ {
+ "",
+ args{
+ data: []byte(`[123,456]`),
+ },
+ true,
+ },
+ {
+ "",
+ args{
+ data: []byte(`{
+ "name": "math_response",
+ "strict": true,
+ "schema": 123456
+ }`),
+ },
+ true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var r openai.ChatCompletionResponseFormatJSONSchema
+ err := r.UnmarshalJSON(tt.args.data)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
+
+func TestChatCompletionRequest_UnmarshalJSON(t *testing.T) {
+ type args struct {
+ bs []byte
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ }{
+ {
+ "",
+ args{bs: []byte(`{
+ "model": "llama3-1b",
+ "messages": [
+ { "role": "system", "content": "You are a helpful math tutor." },
+ { "role": "user", "content": "solve 8x + 31 = 2" }
+ ],
+ "response_format": {
+ "type": "json_schema",
+ "json_schema": {
+ "name": "math_response",
+ "strict": true,
+ "schema": {
+ "type": "object",
+ "properties": {
+ "steps": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "explanation": { "type": "string" },
+ "output": { "type": "string" }
+ },
+ "required": ["explanation","output"],
+ "additionalProperties": false
+ }
+ },
+ "final_answer": { "type": "string" }
+ },
+ "required": ["steps","final_answer"],
+ "additionalProperties": false
+ }
+ }
+ }
+}`)},
+ false,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var m openai.ChatCompletionRequest
+ err := json.Unmarshal(tt.args.bs, &m)
+ if err != nil {
+ t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ })
+ }
+}
diff --git a/client.go b/client.go
index cef375348..413b8db03 100644
--- a/client.go
+++ b/client.go
@@ -84,6 +84,20 @@ func withBody(body any) requestOption {
}
}
+func withExtraBody(extraBody map[string]any) requestOption {
+ return func(args *requestOptions) {
+ // Assert that args.body is a map[string]any.
+ bodyMap, ok := args.body.(map[string]any)
+ if ok {
+ // If it's a map[string]any then only add extraBody
+ // fields to args.body otherwise keep only fields in request struct.
+ for key, value := range extraBody {
+ bodyMap[key] = value
+ }
+ }
+ }
+}
+
func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
diff --git a/completion.go b/completion.go
index 9c3a64dd5..27d69f587 100644
--- a/completion.go
+++ b/completion.go
@@ -21,7 +21,7 @@ const (
O3Mini = "o3-mini"
O3Mini20250131 = "o3-mini-2025-01-31"
O4Mini = "o4-mini"
- O4Mini2020416 = "o4-mini-2025-04-16"
+ O4Mini20250416 = "o4-mini-2025-04-16"
GPT432K0613 = "gpt-4-32k-0613"
GPT432K0314 = "gpt-4-32k-0314"
GPT432K = "gpt-4-32k"
@@ -49,6 +49,10 @@ const (
GPT4Dot1Nano20250414 = "gpt-4.1-nano-2025-04-14"
GPT4Dot5Preview = "gpt-4.5-preview"
GPT4Dot5Preview20250227 = "gpt-4.5-preview-2025-02-27"
+ GPT5 = "gpt-5"
+ GPT5Mini = "gpt-5-mini"
+ GPT5Nano = "gpt-5-nano"
+ GPT5ChatLatest = "gpt-5-chat-latest"
GPT3Dot5Turbo0125 = "gpt-3.5-turbo-0125"
GPT3Dot5Turbo1106 = "gpt-3.5-turbo-1106"
GPT3Dot5Turbo0613 = "gpt-3.5-turbo-0613"
@@ -104,7 +108,7 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
O3Mini: true,
O3Mini20250131: true,
O4Mini: true,
- O4Mini2020416: true,
+ O4Mini20250416: true,
O3: true,
O320250416: true,
GPT3Dot5Turbo: true,
@@ -142,6 +146,10 @@ var disabledModelsForEndpoints = map[string]map[string]bool{
GPT4Dot1Mini20250414: true,
GPT4Dot1Nano: true,
GPT4Dot1Nano20250414: true,
+ GPT5: true,
+ GPT5Mini: true,
+ GPT5Nano: true,
+ GPT5ChatLatest: true,
},
chatCompletionsSuffix: {
CodexCodeDavinci002: true,
@@ -242,7 +250,7 @@ type CompletionResponse struct {
Created int64 `json:"created"`
Model string `json:"model"`
Choices []CompletionChoice `json:"choices"`
- Usage Usage `json:"usage"`
+ Usage *Usage `json:"usage,omitempty"`
httpHeader
}
diff --git a/completion_test.go b/completion_test.go
index 27e2d150e..abfc3007e 100644
--- a/completion_test.go
+++ b/completion_test.go
@@ -192,7 +192,7 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
}
inputTokens *= n
completionTokens := completionReq.MaxTokens * len(prompts) * n
- res.Usage = openai.Usage{
+ res.Usage = &openai.Usage{
PromptTokens: inputTokens,
CompletionTokens: completionTokens,
TotalTokens: inputTokens + completionTokens,
@@ -300,3 +300,32 @@ func TestCompletionWithGPT4oModels(t *testing.T) {
})
}
}
+
+// TestCompletionWithGPT5Models Tests that GPT5 models are not supported for completion endpoint.
+func TestCompletionWithGPT5Models(t *testing.T) {
+ config := openai.DefaultConfig("whatever")
+ config.BaseURL = "http://localhost/v1"
+ client := openai.NewClientWithConfig(config)
+
+ models := []string{
+ openai.GPT5,
+ openai.GPT5Mini,
+ openai.GPT5Nano,
+ openai.GPT5ChatLatest,
+ }
+
+ for _, model := range models {
+ t.Run(model, func(t *testing.T) {
+ _, err := client.CreateCompletion(
+ context.Background(),
+ openai.CompletionRequest{
+ MaxTokens: 5,
+ Model: model,
+ },
+ )
+ if !errors.Is(err, openai.ErrCompletionUnsupportedModel) {
+ t.Fatalf("CreateCompletion should return ErrCompletionUnsupportedModel for %s model, but returned: %v", model, err)
+ }
+ })
+ }
+}
diff --git a/config.go b/config.go
index 4788ba62a..4b8cfb6fb 100644
--- a/config.go
+++ b/config.go
@@ -3,6 +3,7 @@ package openai
import (
"net/http"
"regexp"
+ "strings"
)
const (
@@ -70,7 +71,11 @@ func DefaultAzureConfig(apiKey, baseURL string) ClientConfig {
APIType: APITypeAzure,
APIVersion: "2023-05-15",
AzureModelMapperFunc: func(model string) string {
- return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
+ // only 3.5 models have the "." stripped in their names
+ if strings.Contains(model, "3.5") {
+ return regexp.MustCompile(`[.:]`).ReplaceAllString(model, "")
+ }
+ return strings.ReplaceAll(model, ":", "")
},
HTTPClient: &http.Client{},
diff --git a/config_test.go b/config_test.go
index 145c26066..f44b80825 100644
--- a/config_test.go
+++ b/config_test.go
@@ -20,6 +20,10 @@ func TestGetAzureDeploymentByModel(t *testing.T) {
Model: "gpt-3.5-turbo-0301",
Expect: "gpt-35-turbo-0301",
},
+ {
+ Model: "gpt-4.1",
+ Expect: "gpt-4.1",
+ },
{
Model: "text-embedding-ada-002",
Expect: "text-embedding-ada-002",
@@ -100,3 +104,24 @@ func TestDefaultAnthropicConfigWithEmptyValues(t *testing.T) {
t.Errorf("Expected BaseURL to be %v, got %v", expectedBaseURL, config.BaseURL)
}
}
+
+func TestClientConfigString(t *testing.T) {
+ // String() should always return the constant value
+ conf := openai.DefaultConfig("dummy-token")
+ expected := ""
+ got := conf.String()
+ if got != expected {
+ t.Errorf("ClientConfig.String() = %q; want %q", got, expected)
+ }
+}
+
+func TestGetAzureDeploymentByModel_NoMapper(t *testing.T) {
+ // On a zero-value or DefaultConfig, AzureModelMapperFunc is nil,
+ // so GetAzureDeploymentByModel should just return the input model.
+ conf := openai.DefaultConfig("dummy-token")
+ model := "some-model"
+ got := conf.GetAzureDeploymentByModel(model)
+ if got != model {
+ t.Errorf("GetAzureDeploymentByModel(%q) = %q; want %q", model, got, model)
+ }
+}
diff --git a/embeddings.go b/embeddings.go
index 4a0e682da..8593f8b5b 100644
--- a/embeddings.go
+++ b/embeddings.go
@@ -4,6 +4,7 @@ import (
"context"
"encoding/base64"
"encoding/binary"
+ "encoding/json"
"errors"
"math"
"net/http"
@@ -160,6 +161,9 @@ type EmbeddingRequest struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
+ // The ExtraBody field allows for the inclusion of arbitrary key-value pairs
+ // in the request body that may not be explicitly defined in this struct.
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
}
func (r EmbeddingRequest) Convert() EmbeddingRequest {
@@ -187,6 +191,9 @@ type EmbeddingRequestStrings struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
+ // The ExtraBody field allows for the inclusion of arbitrary key-value pairs
+ // in the request body that may not be explicitly defined in this struct.
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
}
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
@@ -196,6 +203,7 @@ func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
+ ExtraBody: r.ExtraBody,
}
}
@@ -219,6 +227,9 @@ type EmbeddingRequestTokens struct {
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
+ // The ExtraBody field allows for the inclusion of arbitrary key-value pairs
+ // in the request body that may not be explicitly defined in this struct.
+ ExtraBody map[string]any `json:"extra_body,omitempty"`
}
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
@@ -228,6 +239,7 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
User: r.User,
EncodingFormat: r.EncodingFormat,
Dimensions: r.Dimensions,
+ ExtraBody: r.ExtraBody,
}
}
@@ -241,11 +253,29 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()
+
+ // The body map is used to dynamically construct the request payload for the embedding API.
+ // Instead of relying on a fixed struct, the body map allows for flexible inclusion of fields
+ // based on their presence, avoiding unnecessary or empty fields in the request.
+ extraBody := baseReq.ExtraBody
+ baseReq.ExtraBody = nil
+
+ // Serialize baseReq to JSON
+ jsonData, err := json.Marshal(baseReq)
+ if err != nil {
+ return
+ }
+
+ // Deserialize JSON to map[string]any
+ var body map[string]any
+ _ = json.Unmarshal(jsonData, &body)
+
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
- withBody(baseReq),
+ withBody(body), // Main request body.
+ withExtraBody(extraBody), // Merge ExtraBody fields.
)
if err != nil {
return
diff --git a/embeddings_test.go b/embeddings_test.go
index 438978169..07f1262cb 100644
--- a/embeddings_test.go
+++ b/embeddings_test.go
@@ -51,6 +51,24 @@ func TestEmbedding(t *testing.T) {
t.Fatalf("Expected embedding request to contain model field")
}
+ // test embedding request with strings and extra_body param
+ embeddingReqWithExtraBody := openai.EmbeddingRequest{
+ Input: []string{
+ "The food was delicious and the waiter",
+ "Other examples of embedding request",
+ },
+ Model: model,
+ ExtraBody: map[string]any{
+ "input_type": "query",
+ "truncate": "NONE",
+ },
+ }
+ marshaled, err = json.Marshal(embeddingReqWithExtraBody)
+ checks.NoError(t, err, "Could not marshal embedding request")
+ if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
+ t.Fatalf("Expected embedding request to contain model field")
+ }
+
// test embedding request with strings
embeddingReqStrings := openai.EmbeddingRequestStrings{
Input: []string{
@@ -124,7 +142,33 @@ func TestEmbeddingEndpoint(t *testing.T) {
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
}
- // test create embeddings with strings (simple embedding request)
+ // test create embeddings with strings (ExtraBody in request)
+ res, err = client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ ExtraBody: map[string]any{
+ "input_type": "query",
+ "truncate": "NONE",
+ },
+ Dimensions: 1,
+ },
+ )
+ checks.NoError(t, err, "CreateEmbeddings error")
+ if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
+ t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
+ }
+
+ // test create embeddings with strings (ExtraBody in request and )
+ _, err = client.CreateEmbeddings(
+ context.Background(),
+ openai.EmbeddingRequest{
+ Input: make(chan int), // Channels are not serializable
+ Model: "example_model",
+ },
+ )
+ checks.HasError(t, err, "CreateEmbeddings error")
+
+ // test failed (Serialize JSON error)
res, err = client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
diff --git a/files_test.go b/files_test.go
index 3c1b99fb4..486ef892e 100644
--- a/files_test.go
+++ b/files_test.go
@@ -121,3 +121,20 @@ func TestFileUploadWithNonExistentPath(t *testing.T) {
_, err := client.CreateFile(ctx, req)
checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist")
}
+func TestCreateFileRequestBuilderFailure(t *testing.T) {
+ config := DefaultConfig("")
+ config.BaseURL = ""
+ client := NewClientWithConfig(config)
+ client.requestBuilder = &failingRequestBuilder{}
+
+ client.createFormBuilder = func(io.Writer) utils.FormBuilder {
+ return &mockFormBuilder{
+ mockWriteField: func(string, string) error { return nil },
+ mockCreateFormFile: func(string, *os.File) error { return nil },
+ mockClose: func() error { return nil },
+ }
+ }
+
+ _, err := client.CreateFile(context.Background(), FileRequest{FilePath: "client.go"})
+ checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateFile should return error if request builder fails")
+}
diff --git a/image.go b/image.go
index 577d7db95..84b9daf02 100644
--- a/image.go
+++ b/image.go
@@ -3,8 +3,8 @@ package openai
import (
"bytes"
"context"
+ "io"
"net/http"
- "os"
"strconv"
)
@@ -13,51 +13,101 @@ const (
CreateImageSize256x256 = "256x256"
CreateImageSize512x512 = "512x512"
CreateImageSize1024x1024 = "1024x1024"
+
// dall-e-3 supported only.
CreateImageSize1792x1024 = "1792x1024"
CreateImageSize1024x1792 = "1024x1792"
+
+ // gpt-image-1 supported only.
+ CreateImageSize1536x1024 = "1536x1024" // Landscape
+ CreateImageSize1024x1536 = "1024x1536" // Portrait
)
const (
- CreateImageResponseFormatURL = "url"
+ // dall-e-2 and dall-e-3 only.
CreateImageResponseFormatB64JSON = "b64_json"
+ CreateImageResponseFormatURL = "url"
)
const (
- CreateImageModelDallE2 = "dall-e-2"
- CreateImageModelDallE3 = "dall-e-3"
+ CreateImageModelDallE2 = "dall-e-2"
+ CreateImageModelDallE3 = "dall-e-3"
+ CreateImageModelGptImage1 = "gpt-image-1"
)
const (
CreateImageQualityHD = "hd"
CreateImageQualityStandard = "standard"
+
+ // gpt-image-1 only.
+ CreateImageQualityHigh = "high"
+ CreateImageQualityMedium = "medium"
+ CreateImageQualityLow = "low"
)
const (
+ // dall-e-3 only.
CreateImageStyleVivid = "vivid"
CreateImageStyleNatural = "natural"
)
+const (
+ // gpt-image-1 only.
+ CreateImageBackgroundTransparent = "transparent"
+ CreateImageBackgroundOpaque = "opaque"
+)
+
+const (
+ // gpt-image-1 only.
+ CreateImageModerationLow = "low"
+)
+
+const (
+ // gpt-image-1 only.
+ CreateImageOutputFormatPNG = "png"
+ CreateImageOutputFormatJPEG = "jpeg"
+ CreateImageOutputFormatWEBP = "webp"
+)
+
// ImageRequest represents the request structure for the image API.
type ImageRequest struct {
- Prompt string `json:"prompt,omitempty"`
- Model string `json:"model,omitempty"`
- N int `json:"n,omitempty"`
- Quality string `json:"quality,omitempty"`
- Size string `json:"size,omitempty"`
- Style string `json:"style,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
- User string `json:"user,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Model string `json:"model,omitempty"`
+ N int `json:"n,omitempty"`
+ Quality string `json:"quality,omitempty"`
+ Size string `json:"size,omitempty"`
+ Style string `json:"style,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ User string `json:"user,omitempty"`
+ Background string `json:"background,omitempty"`
+ Moderation string `json:"moderation,omitempty"`
+ OutputCompression int `json:"output_compression,omitempty"`
+ OutputFormat string `json:"output_format,omitempty"`
}
// ImageResponse represents a response structure for image API.
type ImageResponse struct {
Created int64 `json:"created,omitempty"`
Data []ImageResponseDataInner `json:"data,omitempty"`
+ Usage ImageResponseUsage `json:"usage,omitempty"`
httpHeader
}
+// ImageResponseInputTokensDetails represents the token breakdown for input tokens.
+type ImageResponseInputTokensDetails struct {
+ TextTokens int `json:"text_tokens,omitempty"`
+ ImageTokens int `json:"image_tokens,omitempty"`
+}
+
+// ImageResponseUsage represents the token usage information for image API.
+type ImageResponseUsage struct {
+ TotalTokens int `json:"total_tokens,omitempty"`
+ InputTokens int `json:"input_tokens,omitempty"`
+ OutputTokens int `json:"output_tokens,omitempty"`
+ InputTokensDetails ImageResponseInputTokensDetails `json:"input_tokens_details,omitempty"`
+}
+
// ImageResponseDataInner represents a response data structure for image API.
type ImageResponseDataInner struct {
URL string `json:"url,omitempty"`
@@ -82,15 +132,42 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons
return
}
+// WrapReader wraps an io.Reader with filename and Content-type.
+func WrapReader(rdr io.Reader, filename string, contentType string) io.Reader {
+ return file{rdr, filename, contentType}
+}
+
+type file struct {
+ io.Reader
+ name string
+ contentType string
+}
+
+func (f file) Name() string {
+ if f.name != "" {
+ return f.name
+ } else if named, ok := f.Reader.(interface{ Name() string }); ok {
+ return named.Name()
+ }
+ return ""
+}
+
+func (f file) ContentType() string {
+ return f.contentType
+}
+
// ImageEditRequest represents the request structure for the image API.
+// Use WrapReader to wrap an io.Reader with filename and Content-type.
type ImageEditRequest struct {
- Image *os.File `json:"image,omitempty"`
- Mask *os.File `json:"mask,omitempty"`
- Prompt string `json:"prompt,omitempty"`
- Model string `json:"model,omitempty"`
- N int `json:"n,omitempty"`
- Size string `json:"size,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
+ Image io.Reader `json:"image,omitempty"`
+ Mask io.Reader `json:"mask,omitempty"`
+ Prompt string `json:"prompt,omitempty"`
+ Model string `json:"model,omitempty"`
+ N int `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ Quality string `json:"quality,omitempty"`
+ User string `json:"user,omitempty"`
}
// CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API.
@@ -98,15 +175,16 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
body := &bytes.Buffer{}
builder := c.createFormBuilder(body)
- // image
- err = builder.CreateFormFile("image", request.Image)
+ // image, filename verification can be postponed
+ err = builder.CreateFormFileReader("image", request.Image, "")
if err != nil {
return
}
// mask, it is optional
if request.Mask != nil {
- err = builder.CreateFormFile("mask", request.Mask)
+ // filename verification can be postponed
+ err = builder.CreateFormFileReader("mask", request.Mask, "")
if err != nil {
return
}
@@ -153,12 +231,14 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
}
// ImageVariRequest represents the request structure for the image API.
+// Use WrapReader to wrap an io.Reader with filename and Content-type.
type ImageVariRequest struct {
- Image *os.File `json:"image,omitempty"`
- Model string `json:"model,omitempty"`
- N int `json:"n,omitempty"`
- Size string `json:"size,omitempty"`
- ResponseFormat string `json:"response_format,omitempty"`
+ Image io.Reader `json:"image,omitempty"`
+ Model string `json:"model,omitempty"`
+ N int `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ User string `json:"user,omitempty"`
}
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
@@ -167,8 +247,8 @@ func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest)
body := &bytes.Buffer{}
builder := c.createFormBuilder(body)
- // image
- err = builder.CreateFormFile("image", request.Image)
+ // image, filename verification can be postponed
+ err = builder.CreateFormFileReader("image", request.Image, "")
if err != nil {
return
}
diff --git a/image_test.go b/image_test.go
index 9332dd5cd..c2c8f42dc 100644
--- a/image_test.go
+++ b/image_test.go
@@ -4,6 +4,7 @@ import (
utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test/checks"
+ "bytes"
"context"
"fmt"
"io"
@@ -39,120 +40,284 @@ func (fb *mockFormBuilder) FormDataContentType() string {
}
func TestImageFormBuilderFailures(t *testing.T) {
- config := DefaultConfig("")
- config.BaseURL = ""
- client := NewClientWithConfig(config)
-
- mockBuilder := &mockFormBuilder{}
- client.createFormBuilder = func(io.Writer) utils.FormBuilder {
- return mockBuilder
- }
ctx := context.Background()
-
- req := ImageEditRequest{
- Mask: &os.File{},
- }
-
mockFailedErr := fmt.Errorf("mock form builder fail")
- mockBuilder.mockCreateFormFile = func(string, *os.File) error {
- return mockFailedErr
- }
- _, err := client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
- mockBuilder.mockCreateFormFile = func(name string, _ *os.File) error {
- if name == "mask" {
- return mockFailedErr
- }
- return nil
+ newClient := func(fb *mockFormBuilder) *Client {
+ cfg := DefaultConfig("")
+ cfg.BaseURL = ""
+ c := NewClientWithConfig(cfg)
+ c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb }
+ return c
}
- _, err = client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
- mockBuilder.mockCreateFormFile = func(string, *os.File) error {
- return nil
+ tests := []struct {
+ name string
+ setup func(*mockFormBuilder)
+ req ImageEditRequest
+ }{
+ {
+ name: "image",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr }
+ fb.mockWriteField = func(string, string) error { return nil }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "mask",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(name string, _ io.Reader, _ string) error {
+ if name == "mask" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockWriteField = func(string, string) error { return nil }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "prompt",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field, _ string) error {
+ if field == "prompt" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "n",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field, _ string) error {
+ if field == "n" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "size",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field, _ string) error {
+ if field == "size" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "response_format",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field, _ string) error {
+ if field == "response_format" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "close",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(string, string) error { return nil }
+ fb.mockClose = func() error { return mockFailedErr }
+ },
+ req: ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)},
+ },
}
- var failForField string
- mockBuilder.mockWriteField = func(fieldname, _ string) error {
- if fieldname == failForField {
- return mockFailedErr
- }
- return nil
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ fb := &mockFormBuilder{}
+ tc.setup(fb)
+ client := newClient(fb)
+ _, err := client.CreateEditImage(ctx, tc.req)
+ checks.ErrorIs(t, err, mockFailedErr, "CreateEditImage should return error if form builder fails")
+ })
}
- failForField = "prompt"
- _, err = client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
-
- failForField = "n"
- _, err = client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
-
- failForField = "size"
- _, err = client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
-
- failForField = "response_format"
- _, err = client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
+ t.Run("new request", func(t *testing.T) {
+ fb := &mockFormBuilder{
+ mockCreateFormFileReader: func(string, io.Reader, string) error { return nil },
+ mockWriteField: func(string, string) error { return nil },
+ mockClose: func() error { return nil },
+ }
+ client := newClient(fb)
+ client.requestBuilder = &failingRequestBuilder{}
- failForField = ""
- mockBuilder.mockClose = func() error {
- return mockFailedErr
- }
- _, err = client.CreateEditImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
+ _, err := client.CreateEditImage(ctx, ImageEditRequest{Image: bytes.NewBuffer(nil), Mask: bytes.NewBuffer(nil)})
+ checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateEditImage should return error if request builder fails")
+ })
}
func TestVariImageFormBuilderFailures(t *testing.T) {
- config := DefaultConfig("")
- config.BaseURL = ""
- client := NewClientWithConfig(config)
-
- mockBuilder := &mockFormBuilder{}
- client.createFormBuilder = func(io.Writer) utils.FormBuilder {
- return mockBuilder
- }
ctx := context.Background()
+ mockFailedErr := fmt.Errorf("mock form builder fail")
- req := ImageVariRequest{}
+ newClient := func(fb *mockFormBuilder) *Client {
+ cfg := DefaultConfig("")
+ cfg.BaseURL = ""
+ c := NewClientWithConfig(cfg)
+ c.createFormBuilder = func(io.Writer) utils.FormBuilder { return fb }
+ return c
+ }
- mockFailedErr := fmt.Errorf("mock form builder fail")
- mockBuilder.mockCreateFormFile = func(string, *os.File) error {
- return mockFailedErr
+ tests := []struct {
+ name string
+ setup func(*mockFormBuilder)
+ req ImageVariRequest
+ }{
+ {
+ name: "image",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return mockFailedErr }
+ fb.mockWriteField = func(string, string) error { return nil }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageVariRequest{Image: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "n",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field string, _ string) error {
+ if field == "n" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageVariRequest{Image: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "size",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field string, _ string) error {
+ if field == "size" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageVariRequest{Image: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "response_format",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(field string, _ string) error {
+ if field == "response_format" {
+ return mockFailedErr
+ }
+ return nil
+ }
+ fb.mockClose = func() error { return nil }
+ },
+ req: ImageVariRequest{Image: bytes.NewBuffer(nil)},
+ },
+ {
+ name: "close",
+ setup: func(fb *mockFormBuilder) {
+ fb.mockCreateFormFileReader = func(string, io.Reader, string) error { return nil }
+ fb.mockWriteField = func(string, string) error { return nil }
+ fb.mockClose = func() error { return mockFailedErr }
+ },
+ req: ImageVariRequest{Image: bytes.NewBuffer(nil)},
+ },
}
- _, err := client.CreateVariImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
- mockBuilder.mockCreateFormFile = func(string, *os.File) error {
- return nil
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ fb := &mockFormBuilder{}
+ tc.setup(fb)
+ client := newClient(fb)
+ _, err := client.CreateVariImage(ctx, tc.req)
+ checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
+ })
}
- var failForField string
- mockBuilder.mockWriteField = func(fieldname, _ string) error {
- if fieldname == failForField {
- return mockFailedErr
+ t.Run("new request", func(t *testing.T) {
+ fb := &mockFormBuilder{
+ mockCreateFormFileReader: func(string, io.Reader, string) error { return nil },
+ mockWriteField: func(string, string) error { return nil },
+ mockClose: func() error { return nil },
}
- return nil
- }
+ client := newClient(fb)
+ client.requestBuilder = &failingRequestBuilder{}
+
+ _, err := client.CreateVariImage(ctx, ImageVariRequest{Image: bytes.NewBuffer(nil)})
+ checks.ErrorIs(t, err, errTestRequestBuilderFailed, "CreateVariImage should return error if request builder fails")
+ })
+}
+
+type testNamedReader struct{ io.Reader }
- failForField = "n"
- _, err = client.CreateVariImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
+func (testNamedReader) Name() string { return "named.txt" }
- failForField = "size"
- _, err = client.CreateVariImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
+func TestWrapReader(t *testing.T) {
+ r := bytes.NewBufferString("data")
+ wrapped := WrapReader(r, "file.png", "image/png")
+ f, ok := wrapped.(interface {
+ Name() string
+ ContentType() string
+ })
+ if !ok {
+ t.Fatal("wrapped reader missing Name or ContentType")
+ }
+ if f.Name() != "file.png" {
+ t.Fatalf("expected name file.png, got %s", f.Name())
+ }
+ if f.ContentType() != "image/png" {
+ t.Fatalf("expected content type image/png, got %s", f.ContentType())
+ }
- failForField = "response_format"
- _, err = client.CreateVariImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateVariImage should return error if form builder fails")
+ // test name from underlying reader
+ nr := testNamedReader{Reader: bytes.NewBufferString("d")}
+ wrapped = WrapReader(nr, "", "text/plain")
+ f, ok = wrapped.(interface {
+ Name() string
+ ContentType() string
+ })
+ if !ok {
+ t.Fatal("wrapped named reader missing Name or ContentType")
+ }
+ if f.Name() != "named.txt" {
+ t.Fatalf("expected name named.txt, got %s", f.Name())
+ }
+ if f.ContentType() != "text/plain" {
+ t.Fatalf("expected content type text/plain, got %s", f.ContentType())
+ }
- failForField = ""
- mockBuilder.mockClose = func() error {
- return mockFailedErr
+ // no name provided
+ wrapped = WrapReader(bytes.NewBuffer(nil), "", "")
+ f2, ok := wrapped.(interface{ Name() string })
+ if !ok {
+ t.Fatal("wrapped anonymous reader missing Name")
+ }
+ if f2.Name() != "" {
+ t.Fatalf("expected empty name, got %s", f2.Name())
}
- _, err = client.CreateVariImage(ctx, req)
- checks.ErrorIs(t, err, mockFailedErr, "CreateImage should return error if form builder fails")
}
diff --git a/internal/error_accumulator_test.go b/internal/error_accumulator_test.go
index d48f28177..f6c226c5e 100644
--- a/internal/error_accumulator_test.go
+++ b/internal/error_accumulator_test.go
@@ -1,41 +1,39 @@
package openai_test
import (
- "bytes"
- "errors"
"testing"
- utils "github.com/sashabaranov/go-openai/internal"
+ openai "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
+ "github.com/sashabaranov/go-openai/internal/test/checks"
)
-func TestErrorAccumulatorBytes(t *testing.T) {
- accumulator := &utils.DefaultErrorAccumulator{
- Buffer: &bytes.Buffer{},
+func TestDefaultErrorAccumulator_WriteMultiple(t *testing.T) {
+ ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator)
+ if !ok {
+ t.Fatal("type assertion to *DefaultErrorAccumulator failed")
}
+ checks.NoError(t, ea.Write([]byte("{\"error\": \"test1\"}")))
+ checks.NoError(t, ea.Write([]byte("{\"error\": \"test2\"}")))
- errBytes := accumulator.Bytes()
- if len(errBytes) != 0 {
- t.Fatalf("Did not return nil with empty bytes: %s", string(errBytes))
+ expected := "{\"error\": \"test1\"}{\"error\": \"test2\"}"
+ if string(ea.Bytes()) != expected {
+ t.Fatalf("Expected %q, got %q", expected, ea.Bytes())
}
+}
- err := accumulator.Write([]byte("{}"))
- if err != nil {
- t.Fatalf("%+v", err)
+func TestDefaultErrorAccumulator_EmptyBuffer(t *testing.T) {
+ ea, ok := openai.NewErrorAccumulator().(*openai.DefaultErrorAccumulator)
+ if !ok {
+ t.Fatal("type assertion to *DefaultErrorAccumulator failed")
}
-
- errBytes = accumulator.Bytes()
- if len(errBytes) == 0 {
- t.Fatalf("Did not return error bytes when has error: %s", string(errBytes))
+ if len(ea.Bytes()) != 0 {
+ t.Fatal("Buffer should be empty initially")
}
}
-func TestErrorByteWriteErrors(t *testing.T) {
- accumulator := &utils.DefaultErrorAccumulator{
- Buffer: &test.FailingErrorBuffer{},
- }
- err := accumulator.Write([]byte("{"))
- if !errors.Is(err, test.ErrTestErrorAccumulatorWriteFailed) {
- t.Fatalf("Did not return error when write failed: %v", err)
- }
+func TestDefaultErrorAccumulator_WriteError(t *testing.T) {
+ ea := &openai.DefaultErrorAccumulator{Buffer: &test.FailingErrorBuffer{}}
+ err := ea.Write([]byte("fail"))
+ checks.ErrorIs(t, err, test.ErrTestErrorAccumulatorWriteFailed, "Write should propagate buffer errors")
}
diff --git a/internal/form_builder.go b/internal/form_builder.go
index 2224fad45..a17e820c0 100644
--- a/internal/form_builder.go
+++ b/internal/form_builder.go
@@ -4,8 +4,10 @@ import (
"fmt"
"io"
"mime/multipart"
+ "net/textproto"
"os"
- "path"
+ "path/filepath"
+ "strings"
)
type FormBuilder interface {
@@ -30,8 +32,50 @@ func (fb *DefaultFormBuilder) CreateFormFile(fieldname string, file *os.File) er
return fb.createFormFile(fieldname, file, file.Name())
}
+var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")
+
+func escapeQuotes(s string) string {
+ return quoteEscaper.Replace(s)
+}
+
+// CreateFormFileReader creates a form field with a file reader.
+// The filename in Content-Disposition is required.
func (fb *DefaultFormBuilder) CreateFormFileReader(fieldname string, r io.Reader, filename string) error {
- return fb.createFormFile(fieldname, r, path.Base(filename))
+ if filename == "" {
+ if f, ok := r.(interface{ Name() string }); ok {
+ filename = f.Name()
+ }
+ }
+ var contentType string
+ if f, ok := r.(interface{ ContentType() string }); ok {
+ contentType = f.ContentType()
+ }
+
+ h := make(textproto.MIMEHeader)
+ h.Set(
+ "Content-Disposition",
+ fmt.Sprintf(
+ `form-data; name="%s"; filename="%s"`,
+ escapeQuotes(fieldname),
+ escapeQuotes(filepath.Base(filename)),
+ ),
+ )
+ // content type is optional, but it can be set
+ if contentType != "" {
+ h.Set("Content-Type", contentType)
+ }
+
+ fieldWriter, err := fb.writer.CreatePart(h)
+ if err != nil {
+ return err
+ }
+
+ _, err = io.Copy(fieldWriter, r)
+ if err != nil {
+ return err
+ }
+
+ return nil
}
func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, filename string) error {
@@ -53,6 +97,9 @@ func (fb *DefaultFormBuilder) createFormFile(fieldname string, r io.Reader, file
}
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
+ if fieldname == "" {
+ return fmt.Errorf("fieldname cannot be empty")
+ }
return fb.writer.WriteField(fieldname, value)
}
diff --git a/internal/form_builder_test.go b/internal/form_builder_test.go
index 8df989e3b..53ef11d23 100644
--- a/internal/form_builder_test.go
+++ b/internal/form_builder_test.go
@@ -1,14 +1,58 @@
package openai //nolint:testpackage // testing private field
import (
+ "errors"
+ "io"
+
"github.com/sashabaranov/go-openai/internal/test/checks"
"bytes"
- "errors"
"os"
+ "strings"
"testing"
)
+type mockFormBuilder struct {
+ mockCreateFormFile func(string, *os.File) error
+ mockWriteField func(string, string) error
+ mockClose func() error
+}
+
+func (m *mockFormBuilder) CreateFormFile(fieldname string, file *os.File) error {
+ return m.mockCreateFormFile(fieldname, file)
+}
+
+func (m *mockFormBuilder) WriteField(fieldname, value string) error {
+ return m.mockWriteField(fieldname, value)
+}
+
+func (m *mockFormBuilder) Close() error {
+ return m.mockClose()
+}
+
+func (m *mockFormBuilder) FormDataContentType() string {
+ return ""
+}
+
+func TestCloseMethod(t *testing.T) {
+ t.Run("NormalClose", func(t *testing.T) {
+ body := &bytes.Buffer{}
+ builder := NewFormBuilder(body)
+ checks.NoError(t, builder.Close(), "正常关闭应成功")
+ })
+
+ t.Run("ErrorPropagation", func(t *testing.T) {
+ errorMock := errors.New("mock close error")
+ mockBuilder := &mockFormBuilder{
+ mockClose: func() error {
+ return errorMock
+ },
+ }
+ err := mockBuilder.Close()
+ checks.ErrorIs(t, err, errorMock, "应传递关闭错误")
+ })
+}
+
type failingWriter struct {
}
@@ -43,3 +87,104 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
checks.HasError(t, err, "formbuilder should return error if file is closed")
checks.ErrorIs(t, err, os.ErrClosed, "formbuilder should return error if file is closed")
}
+
+type failingReader struct {
+}
+
+var errMockFailingReaderError = errors.New("mock reader failed")
+
+func (*failingReader) Read([]byte) (int, error) {
+ return 0, errMockFailingReaderError
+}
+
+type readerWithNameAndContentType struct {
+ io.Reader
+}
+
+func (*readerWithNameAndContentType) Name() string {
+ return ""
+}
+
+func (*readerWithNameAndContentType) ContentType() string {
+ return "image/png"
+}
+
+func TestFormBuilderWithReader(t *testing.T) {
+ file, err := os.CreateTemp(t.TempDir(), "")
+ if err != nil {
+ t.Fatalf("Error creating tmp file: %v", err)
+ }
+ defer file.Close()
+ builder := NewFormBuilder(&failingWriter{})
+ err = builder.CreateFormFileReader("file", file, file.Name())
+ checks.ErrorIs(t, err, errMockFailingWriterError, "formbuilder should return error if writer fails")
+
+ builder = NewFormBuilder(&bytes.Buffer{})
+ reader := &failingReader{}
+ err = builder.CreateFormFileReader("file", reader, "")
+ checks.ErrorIs(t, err, errMockFailingReaderError, "formbuilder should return error if copy reader fails")
+
+ successReader := &bytes.Buffer{}
+ err = builder.CreateFormFileReader("file", successReader, "")
+ checks.NoError(t, err, "formbuilder should not return error")
+
+ rnc := &readerWithNameAndContentType{Reader: &bytes.Buffer{}}
+ err = builder.CreateFormFileReader("file", rnc, "")
+ checks.NoError(t, err, "formbuilder should not return error")
+}
+
+func TestFormDataContentType(t *testing.T) {
+ t.Run("ReturnsUnderlyingWriterContentType", func(t *testing.T) {
+ buf := &bytes.Buffer{}
+ builder := NewFormBuilder(buf)
+
+ contentType := builder.FormDataContentType()
+ if contentType == "" {
+ t.Errorf("expected non-empty content type, got empty string")
+ }
+ })
+}
+
+func TestWriteField(t *testing.T) {
+ t.Run("EmptyFieldNameShouldReturnError", func(t *testing.T) {
+ buf := &bytes.Buffer{}
+ builder := NewFormBuilder(buf)
+
+ err := builder.WriteField("", "some value")
+ checks.HasError(t, err, "fieldname is required")
+ })
+
+ t.Run("ValidFieldNameShouldSucceed", func(t *testing.T) {
+ buf := &bytes.Buffer{}
+ builder := NewFormBuilder(buf)
+
+ err := builder.WriteField("key", "value")
+ checks.NoError(t, err, "should write field without error")
+ })
+}
+
+func TestCreateFormFile(t *testing.T) {
+ buf := &bytes.Buffer{}
+ builder := NewFormBuilder(buf)
+
+ err := builder.createFormFile("file", bytes.NewBufferString("data"), "")
+ if err == nil {
+ t.Fatal("expected error for empty filename")
+ }
+
+ builder = NewFormBuilder(&failingWriter{})
+ err = builder.createFormFile("file", bytes.NewBufferString("data"), "name")
+ checks.ErrorIs(t, err, errMockFailingWriterError, "should propagate writer error")
+}
+
+func TestCreateFormFileSuccess(t *testing.T) {
+ buf := &bytes.Buffer{}
+ builder := NewFormBuilder(buf)
+
+ err := builder.createFormFile("file", bytes.NewBufferString("data"), "foo.txt")
+ checks.NoError(t, err, "createFormFile should succeed")
+
+ if !strings.Contains(buf.String(), "filename=\"foo.txt\"") {
+ t.Fatalf("expected filename header, got %q", buf.String())
+ }
+}
diff --git a/internal/marshaller_test.go b/internal/marshaller_test.go
new file mode 100644
index 000000000..70694faed
--- /dev/null
+++ b/internal/marshaller_test.go
@@ -0,0 +1,34 @@
+package openai_test
+
+import (
+ "testing"
+
+ openai "github.com/sashabaranov/go-openai/internal"
+ "github.com/sashabaranov/go-openai/internal/test/checks"
+)
+
+func TestJSONMarshaller_Normal(t *testing.T) {
+ jm := &openai.JSONMarshaller{}
+ data := map[string]string{"key": "value"}
+
+ b, err := jm.Marshal(data)
+ checks.NoError(t, err)
+ if len(b) == 0 {
+ t.Fatal("should return non-empty bytes")
+ }
+}
+
+func TestJSONMarshaller_InvalidInput(t *testing.T) {
+ jm := &openai.JSONMarshaller{}
+ _, err := jm.Marshal(make(chan int))
+ checks.HasError(t, err, "should return error for unsupported type")
+}
+
+func TestJSONMarshaller_EmptyValue(t *testing.T) {
+ jm := &openai.JSONMarshaller{}
+ b, err := jm.Marshal(nil)
+ checks.NoError(t, err)
+ if string(b) != "null" {
+ t.Fatalf("unexpected marshaled value: %s", string(b))
+ }
+}
diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go
index e26022a6b..adccb158e 100644
--- a/internal/request_builder_test.go
+++ b/internal/request_builder_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"errors"
+ "io"
"net/http"
"reflect"
"testing"
@@ -59,3 +60,37 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
t.Errorf("Build() got = %v, want %v", got, want)
}
}
+
+func TestRequestBuilderWithReaderBodyAndHeader(t *testing.T) {
+ b := NewRequestBuilder()
+ ctx := context.Background()
+ method := http.MethodPost
+ url := "/reader"
+ bodyContent := "hello"
+ body := bytes.NewBufferString(bodyContent)
+ header := http.Header{"X-Test": []string{"val"}}
+
+ req, err := b.Build(ctx, method, url, body, header)
+ if err != nil {
+ t.Fatalf("Build returned error: %v", err)
+ }
+
+ gotBody, err := io.ReadAll(req.Body)
+ if err != nil {
+ t.Fatalf("cannot read body: %v", err)
+ }
+ if string(gotBody) != bodyContent {
+ t.Fatalf("expected body %q, got %q", bodyContent, string(gotBody))
+ }
+ if req.Header.Get("X-Test") != "val" {
+ t.Fatalf("expected header set to val, got %q", req.Header.Get("X-Test"))
+ }
+}
+
+func TestRequestBuilderInvalidURL(t *testing.T) {
+ b := NewRequestBuilder()
+ _, err := b.Build(context.Background(), http.MethodGet, ":", nil, nil)
+ if err == nil {
+ t.Fatal("expected error for invalid URL")
+ }
+}
diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go
new file mode 100644
index 000000000..0677054df
--- /dev/null
+++ b/internal/test/checks/checks_test.go
@@ -0,0 +1,19 @@
+package checks_test
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/sashabaranov/go-openai/internal/test/checks"
+)
+
+func TestChecksSuccessPaths(t *testing.T) {
+ checks.NoError(t, nil)
+ checks.NoErrorF(t, nil)
+ checks.HasError(t, errors.New("err"))
+ target := errors.New("x")
+ checks.ErrorIs(t, target, target)
+ checks.ErrorIsF(t, target, target, "msg")
+ checks.ErrorIsNot(t, errors.New("y"), target)
+ checks.ErrorIsNotf(t, errors.New("y"), target, "msg")
+}
diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go
new file mode 100644
index 000000000..fb1f4bf06
--- /dev/null
+++ b/internal/test/failer_test.go
@@ -0,0 +1,24 @@
+//nolint:testpackage // need access to unexported fields and types for testing
+package test
+
+import (
+ "errors"
+ "testing"
+)
+
+func TestFailingErrorBuffer(t *testing.T) {
+ buf := &FailingErrorBuffer{}
+ n, err := buf.Write([]byte("test"))
+ if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) {
+ t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err)
+ }
+ if n != 0 {
+ t.Fatalf("expected n=0, got %d", n)
+ }
+ if buf.Len() != 0 {
+ t.Fatalf("expected Len 0, got %d", buf.Len())
+ }
+ if len(buf.Bytes()) != 0 {
+ t.Fatalf("expected empty bytes")
+ }
+}
diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go
new file mode 100644
index 000000000..aa177679b
--- /dev/null
+++ b/internal/test/helpers_test.go
@@ -0,0 +1,54 @@
+package test_test
+
+import (
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ internaltest "github.com/sashabaranov/go-openai/internal/test"
+)
+
+func TestCreateTestFile(t *testing.T) {
+ dir := t.TempDir()
+ path := filepath.Join(dir, "file.txt")
+ internaltest.CreateTestFile(t, path)
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("failed to read created file: %v", err)
+ }
+ if string(data) != "hello" {
+ t.Fatalf("unexpected file contents: %q", string(data))
+ }
+}
+
+func TestTokenRoundTripperAddsHeader(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() {
+ t.Fatalf("authorization header not set")
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ client := srv.Client()
+ client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport}
+
+ req, err := http.NewRequest(http.MethodGet, srv.URL, nil)
+ if err != nil {
+ t.Fatalf("request error: %v", err)
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ t.Fatalf("client request error: %v", err)
+ }
+ if _, err = io.Copy(io.Discard, resp.Body); err != nil {
+ t.Fatalf("read body: %v", err)
+ }
+ resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("unexpected status: %d", resp.StatusCode)
+ }
+}
diff --git a/internal/test/server.go b/internal/test/server.go
index 127d4c16f..d32c3e4cb 100644
--- a/internal/test/server.go
+++ b/internal/test/server.go
@@ -23,6 +23,18 @@ func NewTestServer() *ServerTest {
return &ServerTest{handlers: make(map[string]handler)}
}
+// HandlerCount returns the number of registered handlers.
+func (ts *ServerTest) HandlerCount() int {
+ return len(ts.handlers)
+}
+
+// HasHandler checks if a handler was registered for the given path.
+func (ts *ServerTest) HasHandler(path string) bool {
+ path = strings.ReplaceAll(path, "*", ".*")
+ _, ok := ts.handlers[path]
+ return ok
+}
+
func (ts *ServerTest) RegisterHandler(path string, handler handler) {
// to make the registered paths friendlier to a regex match in the route handler
// in OpenAITestServer
diff --git a/internal/test/server_test.go b/internal/test/server_test.go
new file mode 100644
index 000000000..f8ce731d1
--- /dev/null
+++ b/internal/test/server_test.go
@@ -0,0 +1,80 @@
+package test_test
+
+import (
+ "io"
+ "net/http"
+ "testing"
+
+ internaltest "github.com/sashabaranov/go-openai/internal/test"
+)
+
+func TestGetTestToken(t *testing.T) {
+ if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" {
+ t.Fatalf("unexpected token")
+ }
+}
+
+func TestNewTestServer(t *testing.T) {
+ ts := internaltest.NewTestServer()
+ if ts == nil {
+ t.Fatalf("server not properly initialized")
+ }
+ if ts.HandlerCount() != 0 {
+ t.Fatalf("expected no handlers initially")
+ }
+}
+
+func TestRegisterHandlerTransformsPath(t *testing.T) {
+ ts := internaltest.NewTestServer()
+ h := func(_ http.ResponseWriter, _ *http.Request) {}
+ ts.RegisterHandler("/foo/*", h)
+ if !ts.HasHandler("/foo/*") {
+ t.Fatalf("handler not registered with transformed path")
+ }
+}
+
+func TestOpenAITestServer(t *testing.T) {
+ ts := internaltest.NewTestServer()
+ ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) {
+ if _, err := io.WriteString(w, "ok"); err != nil {
+ t.Fatalf("write: %v", err)
+ }
+ })
+ srv := ts.OpenAITestServer()
+ srv.Start()
+ defer srv.Close()
+
+ base := srv.Client().Transport
+ client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}}
+ resp, err := client.Get(srv.URL + "/v1/test/123")
+ if err != nil {
+ t.Fatalf("request error: %v", err)
+ }
+ body, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ t.Fatalf("read response body: %v", err)
+ }
+ if resp.StatusCode != http.StatusOK || string(body) != "ok" {
+ t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body))
+ }
+
+ // unregistered path
+ resp, err = client.Get(srv.URL + "/unknown")
+ if err != nil {
+ t.Fatalf("request error: %v", err)
+ }
+ if resp.StatusCode != http.StatusNotFound {
+ t.Fatalf("expected 404, got %d", resp.StatusCode)
+ }
+
+ // missing token should return unauthorized
+ clientNoToken := srv.Client()
+ resp, err = clientNoToken.Get(srv.URL + "/v1/test/123")
+ if err != nil {
+ t.Fatalf("request error: %v", err)
+ }
+ if resp.StatusCode != http.StatusUnauthorized {
+ t.Fatalf("expected 401, got %d", resp.StatusCode)
+ }
+}
diff --git a/internal/unmarshaler_test.go b/internal/unmarshaler_test.go
new file mode 100644
index 000000000..d63eac779
--- /dev/null
+++ b/internal/unmarshaler_test.go
@@ -0,0 +1,37 @@
+package openai_test
+
+import (
+ "testing"
+
+ openai "github.com/sashabaranov/go-openai/internal"
+ "github.com/sashabaranov/go-openai/internal/test/checks"
+)
+
+func TestJSONUnmarshaler_Normal(t *testing.T) {
+ jm := &openai.JSONUnmarshaler{}
+ data := []byte(`{"key":"value"}`)
+ var v map[string]string
+
+ err := jm.Unmarshal(data, &v)
+ checks.NoError(t, err)
+ if v["key"] != "value" {
+ t.Fatal("unmarshal result mismatch")
+ }
+}
+
+func TestJSONUnmarshaler_InvalidJSON(t *testing.T) {
+ jm := &openai.JSONUnmarshaler{}
+ data := []byte(`{invalid}`)
+ var v map[string]interface{}
+
+ err := jm.Unmarshal(data, &v)
+ checks.HasError(t, err, "should return error for invalid JSON")
+}
+
+func TestJSONUnmarshaler_EmptyInput(t *testing.T) {
+ jm := &openai.JSONUnmarshaler{}
+ var v interface{}
+
+ err := jm.Unmarshal(nil, &v)
+ checks.HasError(t, err, "should return error for nil input")
+}
diff --git a/jsonschema/containsref_test.go b/jsonschema/containsref_test.go
new file mode 100644
index 000000000..dc1842775
--- /dev/null
+++ b/jsonschema/containsref_test.go
@@ -0,0 +1,48 @@
+package jsonschema_test
+
+import (
+ "testing"
+
+ "github.com/sashabaranov/go-openai/jsonschema"
+)
+
+// SelfRef struct used to produce a self-referential schema.
+type SelfRef struct {
+ Friends []SelfRef `json:"friends"`
+}
+
+// Address struct referenced by Person without self-reference.
+type Address struct {
+ Street string `json:"street"`
+}
+
+type Person struct {
+ Address Address `json:"address"`
+}
+
+// TestGenerateSchemaForType_SelfRef ensures that self-referential types are not
+// flattened during schema generation.
+func TestGenerateSchemaForType_SelfRef(t *testing.T) {
+ schema, err := jsonschema.GenerateSchemaForType(SelfRef{})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if _, ok := schema.Defs["SelfRef"]; !ok {
+ t.Fatal("expected defs to contain SelfRef for self reference")
+ }
+}
+
+// TestGenerateSchemaForType_NoSelfRef ensures that non-self-referential types
+// are flattened and do not reappear in $defs.
+func TestGenerateSchemaForType_NoSelfRef(t *testing.T) {
+ schema, err := jsonschema.GenerateSchemaForType(Person{})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if _, ok := schema.Defs["Person"]; ok {
+ t.Fatal("unexpected Person definition in defs")
+ }
+ if _, ok := schema.Defs["Address"]; !ok {
+ t.Fatal("expected Address definition in defs")
+ }
+}
diff --git a/jsonschema/json.go b/jsonschema/json.go
index d458418f3..75e3b5173 100644
--- a/jsonschema/json.go
+++ b/jsonschema/json.go
@@ -48,6 +48,11 @@ type Definition struct {
AdditionalProperties any `json:"additionalProperties,omitempty"`
// Whether the schema is nullable or not.
Nullable bool `json:"nullable,omitempty"`
+
+ // Ref Reference to a definition in $defs or external schema.
+ Ref string `json:"$ref,omitempty"`
+ // Defs A map of reusable schema definitions.
+ Defs map[string]Definition `json:"$defs,omitempty"`
}
func (d *Definition) MarshalJSON() ([]byte, error) {
@@ -67,10 +72,37 @@ func (d *Definition) Unmarshal(content string, v any) error {
}
func GenerateSchemaForType(v any) (*Definition, error) {
- return reflectSchema(reflect.TypeOf(v))
+ var defs = make(map[string]Definition)
+ def, err := reflectSchema(reflect.TypeOf(v), defs)
+ if err != nil {
+ return nil, err
+ }
+ // If the schema has a root $ref, resolve it by:
+ // 1. Extracting the key from the $ref.
+ // 2. Detaching the referenced definition from $defs.
+ // 3. Checking for self-references in the detached definition.
+ // - If a self-reference is found, restore the original $defs structure.
+ // 4. Flattening the referenced definition into the root schema.
+ // 5. Clearing the $ref field in the root schema.
+ if def.Ref != "" {
+ origRef := def.Ref
+ key := strings.TrimPrefix(origRef, "#/$defs/")
+ if root, ok := defs[key]; ok {
+ delete(defs, key)
+ root.Defs = defs
+ if containsRef(root, origRef) {
+ root.Defs = nil
+ defs[key] = root
+ }
+ *def = root
+ }
+ def.Ref = ""
+ }
+ def.Defs = defs
+ return def, nil
}
-func reflectSchema(t reflect.Type) (*Definition, error) {
+func reflectSchema(t reflect.Type, defs map[string]Definition) (*Definition, error) {
var d Definition
switch t.Kind() {
case reflect.String:
@@ -84,21 +116,32 @@ func reflectSchema(t reflect.Type) (*Definition, error) {
d.Type = Boolean
case reflect.Slice, reflect.Array:
d.Type = Array
- items, err := reflectSchema(t.Elem())
+ items, err := reflectSchema(t.Elem(), defs)
if err != nil {
return nil, err
}
d.Items = items
case reflect.Struct:
+ if t.Name() != "" {
+ if _, ok := defs[t.Name()]; !ok {
+ defs[t.Name()] = Definition{}
+ object, err := reflectSchemaObject(t, defs)
+ if err != nil {
+ return nil, err
+ }
+ defs[t.Name()] = *object
+ }
+ return &Definition{Ref: "#/$defs/" + t.Name()}, nil
+ }
d.Type = Object
d.AdditionalProperties = false
- object, err := reflectSchemaObject(t)
+ object, err := reflectSchemaObject(t, defs)
if err != nil {
return nil, err
}
d = *object
case reflect.Ptr:
- definition, err := reflectSchema(t.Elem())
+ definition, err := reflectSchema(t.Elem(), defs)
if err != nil {
return nil, err
}
@@ -112,7 +155,7 @@ func reflectSchema(t reflect.Type) (*Definition, error) {
return &d, nil
}
-func reflectSchemaObject(t reflect.Type) (*Definition, error) {
+func reflectSchemaObject(t reflect.Type, defs map[string]Definition) (*Definition, error) {
var d = Definition{
Type: Object,
AdditionalProperties: false,
@@ -126,14 +169,17 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
}
jsonTag := field.Tag.Get("json")
var required = true
- if jsonTag == "" {
+ switch {
+ case jsonTag == "-":
+ continue
+ case jsonTag == "":
jsonTag = field.Name
- } else if strings.HasSuffix(jsonTag, ",omitempty") {
+ case strings.HasSuffix(jsonTag, ",omitempty"):
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
required = false
}
- item, err := reflectSchema(field.Type)
+ item, err := reflectSchema(field.Type, defs)
if err != nil {
return nil, err
}
@@ -164,3 +210,26 @@ func reflectSchemaObject(t reflect.Type) (*Definition, error) {
d.Properties = properties
return &d, nil
}
+
+func containsRef(def Definition, targetRef string) bool {
+ if def.Ref == targetRef {
+ return true
+ }
+
+ for _, d := range def.Defs {
+ if containsRef(d, targetRef) {
+ return true
+ }
+ }
+
+ for _, prop := range def.Properties {
+ if containsRef(prop, targetRef) {
+ return true
+ }
+ }
+
+ if def.Items != nil && containsRef(*def.Items, targetRef) {
+ return true
+ }
+ return false
+}
diff --git a/jsonschema/json_additional_test.go b/jsonschema/json_additional_test.go
new file mode 100644
index 000000000..70cf37490
--- /dev/null
+++ b/jsonschema/json_additional_test.go
@@ -0,0 +1,73 @@
+package jsonschema_test
+
+import (
+ "testing"
+
+ "github.com/sashabaranov/go-openai/jsonschema"
+)
+
+// Test Definition.Unmarshal, including success path, validation error,
+// JSON syntax error and type mismatch during unmarshalling.
+func TestDefinitionUnmarshal(t *testing.T) {
+ schema := jsonschema.Definition{
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "name": {Type: jsonschema.String},
+ },
+ }
+
+ var dst struct {
+ Name string `json:"name"`
+ }
+ if err := schema.Unmarshal(`{"name":"foo"}`, &dst); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if dst.Name != "foo" {
+ t.Errorf("expected name to be foo, got %q", dst.Name)
+ }
+
+ if err := schema.Unmarshal(`{`, &dst); err == nil {
+ t.Error("expected error for malformed json")
+ }
+
+ if err := schema.Unmarshal(`{"name":1}`, &dst); err == nil {
+ t.Error("expected validation error")
+ }
+
+ numSchema := jsonschema.Definition{Type: jsonschema.Number}
+ var s string
+ if err := numSchema.Unmarshal(`123`, &s); err == nil {
+ t.Error("expected unmarshal type error")
+ }
+}
+
+// Ensure GenerateSchemaForType returns an error when encountering unsupported types.
+func TestGenerateSchemaForTypeUnsupported(t *testing.T) {
+ type Bad struct {
+ Ch chan int `json:"ch"`
+ }
+ if _, err := jsonschema.GenerateSchemaForType(Bad{}); err == nil {
+ t.Fatal("expected error for unsupported type")
+ }
+}
+
+// Validate should fail when provided data does not match the expected container types.
+func TestValidateInvalidContainers(t *testing.T) {
+ objSchema := jsonschema.Definition{Type: jsonschema.Object}
+ if jsonschema.Validate(objSchema, 1) {
+ t.Error("expected object validation to fail for non-map input")
+ }
+
+ arrSchema := jsonschema.Definition{Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}}
+ if jsonschema.Validate(arrSchema, 1) {
+ t.Error("expected array validation to fail for non-slice input")
+ }
+}
+
+// Validate should return false when $ref cannot be resolved.
+func TestValidateRefNotFound(t *testing.T) {
+ refSchema := jsonschema.Definition{Ref: "#/$defs/Missing"}
+ if jsonschema.Validate(refSchema, "data", jsonschema.WithDefs(map[string]jsonschema.Definition{})) {
+ t.Error("expected validation to fail when reference is missing")
+ }
+}
diff --git a/jsonschema/json_errors_test.go b/jsonschema/json_errors_test.go
new file mode 100644
index 000000000..3b864fc21
--- /dev/null
+++ b/jsonschema/json_errors_test.go
@@ -0,0 +1,27 @@
+package jsonschema_test
+
+import (
+ "testing"
+
+ "github.com/sashabaranov/go-openai/jsonschema"
+)
+
+// TestGenerateSchemaForType_ErrorPaths verifies error handling for unsupported types.
+func TestGenerateSchemaForType_ErrorPaths(t *testing.T) {
+ type anon struct{ Ch chan int }
+ tests := []struct {
+ name string
+ v any
+ }{
+ {"slice", []chan int{}},
+ {"anon struct", anon{}},
+ {"pointer", (*chan int)(nil)},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if _, err := jsonschema.GenerateSchemaForType(tt.v); err == nil {
+ t.Errorf("expected error for %s", tt.name)
+ }
+ })
+ }
+}
diff --git a/jsonschema/json_test.go b/jsonschema/json_test.go
index 17f0aba8a..34f5d88eb 100644
--- a/jsonschema/json_test.go
+++ b/jsonschema/json_test.go
@@ -182,7 +182,37 @@ func TestDefinition_MarshalJSON(t *testing.T) {
}
}
+type User struct {
+ ID int `json:"id,omitempty"`
+ Name string `json:"name,omitempty"`
+ Orders []*Order `json:"orders,omitempty"`
+}
+
+type Order struct {
+ ID int `json:"id,omitempty"`
+ Amount float64 `json:"amount,omitempty"`
+ Buyer *User `json:"buyer,omitempty"`
+}
+
func TestStructToSchema(t *testing.T) {
+ type Tweet struct {
+ Text string `json:"text"`
+ }
+
+ type Person struct {
+ Name string `json:"name,omitempty"`
+ Age int `json:"age,omitempty"`
+ Friends []Person `json:"friends,omitempty"`
+ Tweets []Tweet `json:"tweets,omitempty"`
+ }
+
+ type MyStructuredResponse struct {
+ PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
+ CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
+ KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
+ SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
+ }
+
tests := []struct {
name string
in any
@@ -329,6 +359,267 @@ func TestStructToSchema(t *testing.T) {
"additionalProperties":false
}`,
},
+ {
+ name: "Test with exclude mark",
+ in: struct {
+ Name string `json:"-"`
+ }{
+ Name: "Name",
+ },
+ want: `{
+ "type":"object",
+ "additionalProperties":false
+ }`,
+ },
+ {
+ name: "Test with no json tag",
+ in: struct {
+ Name string
+ }{
+ Name: "",
+ },
+ want: `{
+ "type":"object",
+ "properties":{
+ "Name":{
+ "type":"string"
+ }
+ },
+ "required":["Name"],
+ "additionalProperties":false
+ }`,
+ },
+ {
+ name: "Test with omitempty tag",
+ in: struct {
+ Name string `json:"name,omitempty"`
+ }{
+ Name: "",
+ },
+ want: `{
+ "type":"object",
+ "properties":{
+ "name":{
+ "type":"string"
+ }
+ },
+ "additionalProperties":false
+ }`,
+ },
+ {
+ name: "Test with $ref and $defs",
+ in: struct {
+ Person Person `json:"person"`
+ Tweets []Tweet `json:"tweets"`
+ }{},
+ want: `{
+ "type" : "object",
+ "properties" : {
+ "person" : {
+ "$ref" : "#/$defs/Person"
+ },
+ "tweets" : {
+ "type" : "array",
+ "items" : {
+ "$ref" : "#/$defs/Tweet"
+ }
+ }
+ },
+ "required" : [ "person", "tweets" ],
+ "additionalProperties" : false,
+ "$defs" : {
+ "Person" : {
+ "type" : "object",
+ "properties" : {
+ "age" : {
+ "type" : "integer"
+ },
+ "friends" : {
+ "type" : "array",
+ "items" : {
+ "$ref" : "#/$defs/Person"
+ }
+ },
+ "name" : {
+ "type" : "string"
+ },
+ "tweets" : {
+ "type" : "array",
+ "items" : {
+ "$ref" : "#/$defs/Tweet"
+ }
+ }
+ },
+ "additionalProperties" : false
+ },
+ "Tweet" : {
+ "type" : "object",
+ "properties" : {
+ "text" : {
+ "type" : "string"
+ }
+ },
+ "required" : [ "text" ],
+ "additionalProperties" : false
+ }
+ }
+}`,
+ },
+ {
+ name: "Test Person",
+ in: Person{},
+ want: `{
+ "type": "object",
+ "properties": {
+ "age": {
+ "type": "integer"
+ },
+ "friends": {
+ "type": "array",
+ "items": {
+ "$ref": "#/$defs/Person"
+ }
+ },
+ "name": {
+ "type": "string"
+ },
+ "tweets": {
+ "type": "array",
+ "items": {
+ "$ref": "#/$defs/Tweet"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "$defs": {
+ "Person": {
+ "type": "object",
+ "properties": {
+ "age": {
+ "type": "integer"
+ },
+ "friends": {
+ "type": "array",
+ "items": {
+ "$ref": "#/$defs/Person"
+ }
+ },
+ "name": {
+ "type": "string"
+ },
+ "tweets": {
+ "type": "array",
+ "items": {
+ "$ref": "#/$defs/Tweet"
+ }
+ }
+ },
+ "additionalProperties": false
+ },
+ "Tweet": {
+ "type": "object",
+ "properties": {
+ "text": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "text"
+ ],
+ "additionalProperties": false
+ }
+ }
+}`,
+ },
+ {
+ name: "Test MyStructuredResponse",
+ in: MyStructuredResponse{},
+ want: `{
+ "type": "object",
+ "properties": {
+ "camel_case": {
+ "type": "string",
+ "description": "CamelCase"
+ },
+ "kebab_case": {
+ "type": "string",
+ "description": "KebabCase"
+ },
+ "pascal_case": {
+ "type": "string",
+ "description": "PascalCase"
+ },
+ "snake_case": {
+ "type": "string",
+ "description": "SnakeCase"
+ }
+ },
+ "required": [
+ "pascal_case",
+ "camel_case",
+ "kebab_case",
+ "snake_case"
+ ],
+ "additionalProperties": false
+}`,
+ },
+ {
+ name: "Test User",
+ in: User{},
+ want: `{
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer"
+ },
+ "name": {
+ "type": "string"
+ },
+ "orders": {
+ "type": "array",
+ "items": {
+ "$ref": "#/$defs/Order"
+ }
+ }
+ },
+ "additionalProperties": false,
+ "$defs": {
+ "Order": {
+ "type": "object",
+ "properties": {
+ "amount": {
+ "type": "number"
+ },
+ "buyer": {
+ "$ref": "#/$defs/User"
+ },
+ "id": {
+ "type": "integer"
+ }
+ },
+ "additionalProperties": false
+ },
+ "User": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "integer"
+ },
+ "name": {
+ "type": "string"
+ },
+ "orders": {
+ "type": "array",
+ "items": {
+ "$ref": "#/$defs/Order"
+ }
+ }
+ },
+ "additionalProperties": false
+ }
+ }
+}`,
+ },
}
for _, tt := range tests {
diff --git a/jsonschema/validate.go b/jsonschema/validate.go
index 49f9b8859..1bd2f809c 100644
--- a/jsonschema/validate.go
+++ b/jsonschema/validate.go
@@ -5,26 +5,68 @@ import (
"errors"
)
+func CollectDefs(def Definition) map[string]Definition {
+ result := make(map[string]Definition)
+ collectDefsRecursive(def, result, "#")
+ return result
+}
+
+func collectDefsRecursive(def Definition, result map[string]Definition, prefix string) {
+ for k, v := range def.Defs {
+ path := prefix + "/$defs/" + k
+ result[path] = v
+ collectDefsRecursive(v, result, path)
+ }
+ for k, sub := range def.Properties {
+ collectDefsRecursive(sub, result, prefix+"/properties/"+k)
+ }
+ if def.Items != nil {
+ collectDefsRecursive(*def.Items, result, prefix)
+ }
+}
+
func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error {
var data any
err := json.Unmarshal(content, &data)
if err != nil {
return err
}
- if !Validate(schema, data) {
+ if !Validate(schema, data, WithDefs(CollectDefs(schema))) {
return errors.New("data validation failed against the provided schema")
}
return json.Unmarshal(content, &v)
}
-func Validate(schema Definition, data any) bool {
+type validateArgs struct {
+ Defs map[string]Definition
+}
+
+type ValidateOption func(*validateArgs)
+
+func WithDefs(defs map[string]Definition) ValidateOption {
+ return func(option *validateArgs) {
+ option.Defs = defs
+ }
+}
+
+func Validate(schema Definition, data any, opts ...ValidateOption) bool {
+ args := validateArgs{}
+ for _, opt := range opts {
+ opt(&args)
+ }
+ if len(opts) == 0 {
+ args.Defs = CollectDefs(schema)
+ }
switch schema.Type {
case Object:
- return validateObject(schema, data)
+ return validateObject(schema, data, args.Defs)
case Array:
- return validateArray(schema, data)
+ return validateArray(schema, data, args.Defs)
case String:
- _, ok := data.(string)
+ v, ok := data.(string)
+ if ok && len(schema.Enum) > 0 {
+ return contains(schema.Enum, v)
+ }
return ok
case Number: // float64 and int
_, ok := data.(float64)
@@ -45,11 +87,16 @@ func Validate(schema Definition, data any) bool {
case Null:
return data == nil
default:
+ if schema.Ref != "" && args.Defs != nil {
+ if v, ok := args.Defs[schema.Ref]; ok {
+ return Validate(v, data, WithDefs(args.Defs))
+ }
+ }
return false
}
}
-func validateObject(schema Definition, data any) bool {
+func validateObject(schema Definition, data any, defs map[string]Definition) bool {
dataMap, ok := data.(map[string]any)
if !ok {
return false
@@ -61,7 +108,7 @@ func validateObject(schema Definition, data any) bool {
}
for key, valueSchema := range schema.Properties {
value, exists := dataMap[key]
- if exists && !Validate(valueSchema, value) {
+ if exists && !Validate(valueSchema, value, WithDefs(defs)) {
return false
} else if !exists && contains(schema.Required, key) {
return false
@@ -70,13 +117,13 @@ func validateObject(schema Definition, data any) bool {
return true
}
-func validateArray(schema Definition, data any) bool {
+func validateArray(schema Definition, data any, defs map[string]Definition) bool {
dataArray, ok := data.([]any)
if !ok {
return false
}
for _, item := range dataArray {
- if !Validate(*schema.Items, item) {
+ if !Validate(*schema.Items, item, WithDefs(defs)) {
return false
}
}
diff --git a/jsonschema/validate_test.go b/jsonschema/validate_test.go
index 6fa30ab0c..aefdf4069 100644
--- a/jsonschema/validate_test.go
+++ b/jsonschema/validate_test.go
@@ -1,6 +1,7 @@
package jsonschema_test
import (
+ "reflect"
"testing"
"github.com/sashabaranov/go-openai/jsonschema"
@@ -70,6 +71,96 @@ func Test_Validate(t *testing.T) {
},
Required: []string{"string"},
}}, false},
+ {
+ "test schema with ref and defs", args{data: map[string]any{
+ "person": map[string]any{
+ "name": "John",
+ "gender": "male",
+ "age": 28,
+ "profile": map[string]any{
+ "full_name": "John Doe",
+ },
+ },
+ }, schema: jsonschema.Definition{
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "person": {Ref: "#/$defs/Person"},
+ },
+ Required: []string{"person"},
+ Defs: map[string]jsonschema.Definition{
+ "Person": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "name": {Type: jsonschema.String},
+ "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}},
+ "age": {Type: jsonschema.Integer},
+ "profile": {Ref: "#/$defs/Person/$defs/Profile"},
+ "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}},
+ },
+ Required: []string{"name", "gender", "age", "profile"},
+ Defs: map[string]jsonschema.Definition{
+ "Profile": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "full_name": {Type: jsonschema.String},
+ },
+ },
+ },
+ },
+ "Tweet": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "text": {Type: jsonschema.String},
+ "person": {Ref: "#/$defs/Person"},
+ },
+ },
+ },
+ }}, true},
+ {
+ "test enum invalid value", args{data: map[string]any{
+ "person": map[string]any{
+ "name": "John",
+ "gender": "other",
+ "age": 28,
+ "profile": map[string]any{
+ "full_name": "John Doe",
+ },
+ },
+ }, schema: jsonschema.Definition{
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "person": {Ref: "#/$defs/Person"},
+ },
+ Required: []string{"person"},
+ Defs: map[string]jsonschema.Definition{
+ "Person": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "name": {Type: jsonschema.String},
+ "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}},
+ "age": {Type: jsonschema.Integer},
+ "profile": {Ref: "#/$defs/Person/$defs/Profile"},
+ "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}},
+ },
+ Required: []string{"name", "gender", "age", "profile"},
+ Defs: map[string]jsonschema.Definition{
+ "Profile": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "full_name": {Type: jsonschema.String},
+ },
+ },
+ },
+ },
+ "Tweet": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "text": {Type: jsonschema.String},
+ "person": {Ref: "#/$defs/Person"},
+ },
+ },
+ },
+ }}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -156,8 +247,100 @@ func TestUnmarshal(t *testing.T) {
err := jsonschema.VerifySchemaAndUnmarshal(tt.args.schema, tt.args.content, tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("Unmarshal() error = %v, wantErr %v", err, tt.wantErr)
- } else if err == nil {
- t.Logf("Unmarshal() v = %+v\n", tt.args.v)
+ }
+ })
+ }
+}
+
+func TestCollectDefs(t *testing.T) {
+ type args struct {
+ schema jsonschema.Definition
+ }
+ tests := []struct {
+ name string
+ args args
+ want map[string]jsonschema.Definition
+ }{
+ {
+ "test collect defs",
+ args{
+ schema: jsonschema.Definition{
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "person": {Ref: "#/$defs/Person"},
+ },
+ Required: []string{"person"},
+ Defs: map[string]jsonschema.Definition{
+ "Person": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "name": {Type: jsonschema.String},
+ "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}},
+ "age": {Type: jsonschema.Integer},
+ "profile": {Ref: "#/$defs/Person/$defs/Profile"},
+ "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}},
+ },
+ Required: []string{"name", "gender", "age", "profile"},
+ Defs: map[string]jsonschema.Definition{
+ "Profile": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "full_name": {Type: jsonschema.String},
+ },
+ },
+ },
+ },
+ "Tweet": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "text": {Type: jsonschema.String},
+ "person": {Ref: "#/$defs/Person"},
+ },
+ },
+ },
+ },
+ },
+ map[string]jsonschema.Definition{
+ "#/$defs/Person": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "name": {Type: jsonschema.String},
+ "gender": {Type: jsonschema.String, Enum: []string{"male", "female", "unknown"}},
+ "age": {Type: jsonschema.Integer},
+ "profile": {Ref: "#/$defs/Person/$defs/Profile"},
+ "tweets": {Type: jsonschema.Array, Items: &jsonschema.Definition{Ref: "#/$defs/Tweet"}},
+ },
+ Required: []string{"name", "gender", "age", "profile"},
+ Defs: map[string]jsonschema.Definition{
+ "Profile": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "full_name": {Type: jsonschema.String},
+ },
+ },
+ },
+ },
+ "#/$defs/Person/$defs/Profile": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "full_name": {Type: jsonschema.String},
+ },
+ },
+ "#/$defs/Tweet": {
+ Type: jsonschema.Object,
+ Properties: map[string]jsonschema.Definition{
+ "text": {Type: jsonschema.String},
+ "person": {Ref: "#/$defs/Person"},
+ },
+ },
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := jsonschema.CollectDefs(tt.args.schema)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("CollectDefs() = %v, want %v", got, tt.want)
}
})
}
diff --git a/reasoning_validator.go b/reasoning_validator.go
index 2910b1395..1d26ca047 100644
--- a/reasoning_validator.go
+++ b/reasoning_validator.go
@@ -28,21 +28,22 @@ var (
ErrReasoningModelLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)
-// ReasoningValidator handles validation for o-series model requests.
+// ReasoningValidator handles validation for reasoning model requests.
type ReasoningValidator struct{}
-// NewReasoningValidator creates a new validator for o-series models.
+// NewReasoningValidator creates a new validator for reasoning models.
func NewReasoningValidator() *ReasoningValidator {
return &ReasoningValidator{}
}
-// Validate performs all validation checks for o-series models.
+// Validate performs all validation checks for reasoning models.
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
o1Series := strings.HasPrefix(request.Model, "o1")
o3Series := strings.HasPrefix(request.Model, "o3")
o4Series := strings.HasPrefix(request.Model, "o4")
+ gpt5Series := strings.HasPrefix(request.Model, "gpt-5")
- if !o1Series && !o3Series && !o4Series {
+ if !o1Series && !o3Series && !o4Series && !gpt5Series {
return nil
}