From b2937ca71d2ba437040b551e87fb334877795c90 Mon Sep 17 00:00:00 2001 From: Giordano Ferreira Date: Sun, 19 Mar 2023 10:54:51 -0300 Subject: [PATCH 1/4] Update module after fork --- README.md | 22 +++++++++++----------- api_test.go | 2 +- audio_test.go | 4 ++-- chat_stream_test.go | 4 ++-- chat_test.go | 4 ++-- completion_test.go | 4 ++-- edits_test.go | 4 ++-- embeddings_test.go | 2 +- files_test.go | 4 ++-- fine_tunes_test.go | 4 ++-- go.mod | 2 +- image_test.go | 4 ++-- models_test.go | 4 ++-- moderation_test.go | 4 ++-- request_builder_test.go | 2 +- stream_test.go | 4 ++-- 16 files changed, 37 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index e6e352e27..57fccea84 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Go OpenAI -[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-openai) -[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-openai)](https://goreportcard.com/report/github.com/sashabaranov/go-openai) -[![codecov](https://codecov.io/gh/sashabaranov/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/sashabaranov/go-openai) +[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/giordanobsf/go-openai) +[![Go Report Card](https://goreportcard.com/badge/github.com/giordanobsf/go-openai)](https://goreportcard.com/report/github.com/giordanobsf/go-openai) +[![codecov](https://codecov.io/gh/giordanobsf/go-openai/branch/master/graph/badge.svg?token=bCbIfHLIsW)](https://codecov.io/gh/giordanobsf/go-openai) > **Note**: the repository was recently renamed from `go-gpt3` to `go-openai` @@ -14,7 +14,7 @@ This library provides Go clients for [OpenAI API](https://platform.openai.com/). Installation: ``` -go get github.com/sashabaranov/go-openai +go get github.com/giordanobsf/go-openai ``` @@ -26,7 +26,7 @@ package main import ( "context" "fmt" - openai "github.com/sashabaranov/go-openai" + openai "github.com/giordanobsf/go-openai" ) func main() { @@ -67,7 +67,7 @@ package main import ( "context" "fmt" - openai "github.com/sashabaranov/go-openai" + openai "github.com/giordanobsf/go-openai" ) func main() { @@ -100,7 +100,7 @@ import ( "context" "fmt" "io" - openai "github.com/sashabaranov/go-openai" + openai "github.com/giordanobsf/go-openai" ) func main() { @@ -149,7 +149,7 @@ import ( "context" "fmt" - openai "github.com/sashabaranov/go-openai" + openai "github.com/giordanobsf/go-openai" ) func main() { @@ -181,7 +181,7 @@ import ( "context" "encoding/base64" "fmt" - openai "github.com/sashabaranov/go-openai" + openai "github.com/giordanobsf/go-openai" "image/png" "os" ) @@ -269,7 +269,7 @@ config.HTTPClient = &http.Client{ c := openai.NewClientWithConfig(config) ``` -See also: https://pkg.go.dev/github.com/sashabaranov/go-openai#ClientConfig +See also: https://pkg.go.dev/github.com/giordanobsf/go-openai#ClientConfig
@@ -285,7 +285,7 @@ import ( "os" "strings" - "github.com/sashabaranov/go-openai" + "github.com/giordanobsf/go-openai" ) func main() { diff --git a/api_test.go b/api_test.go index a5a0d1250..6f15449bc 100644 --- a/api_test.go +++ b/api_test.go @@ -1,7 +1,7 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" + . "github.com/giordanobsf/go-openai" "context" "errors" diff --git a/audio_test.go b/audio_test.go index 2a035c9fe..ed561bbbf 100644 --- a/audio_test.go +++ b/audio_test.go @@ -11,8 +11,8 @@ import ( "path/filepath" "strings" - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "testing" diff --git a/chat_stream_test.go b/chat_stream_test.go index e3da2daf7..ed76fe3df 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/chat_test.go b/chat_test.go index 8866ff2ae..caacad69b 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/completion_test.go b/completion_test.go index daa02e383..cc0332426 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/edits_test.go b/edits_test.go index 6a16f7c2c..ddc64f873 100644 --- a/edits_test.go +++ b/edits_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/embeddings_test.go b/embeddings_test.go index 2aa48c51e..698b9a675 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -1,7 +1,7 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" + . "github.com/giordanobsf/go-openai" "bytes" "encoding/json" diff --git a/files_test.go b/files_test.go index 6a78ce104..dae894764 100644 --- a/files_test.go +++ b/files_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/fine_tunes_test.go b/fine_tunes_test.go index 1f6f96764..0f987f2ec 100644 --- a/fine_tunes_test.go +++ b/fine_tunes_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/go.mod b/go.mod index 42cc7b391..f3262ac2c 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/sashabaranov/go-openai +module github.com/giordanobsf/go-openai go 1.18 diff --git a/image_test.go b/image_test.go index b7949c896..da9b14e3e 100644 --- a/image_test.go +++ b/image_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/models_test.go b/models_test.go index 972a5fe64..ea61b154a 100644 --- a/models_test.go +++ b/models_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/moderation_test.go b/moderation_test.go index f50124534..15e2013cd 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "encoding/json" diff --git a/request_builder_test.go b/request_builder_test.go index f0f99ee5b..d1ec1564f 100644 --- a/request_builder_test.go +++ b/request_builder_test.go @@ -1,7 +1,7 @@ package openai //nolint:testpackage // testing private field import ( - "github.com/sashabaranov/go-openai/internal/test" + "github.com/giordanobsf/go-openai/internal/test" "context" "errors" diff --git a/stream_test.go b/stream_test.go index 8f89e6b85..a757ba2a2 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,8 +1,8 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test" + . "github.com/giordanobsf/go-openai" + "github.com/giordanobsf/go-openai/internal/test" "context" "errors" From a4160c2c403e58eec39cd11c606bc580c173fc38 Mon Sep 17 00:00:00 2001 From: giordanobsf Date: Sun, 19 Mar 2023 15:09:56 -0300 Subject: [PATCH 2/4] feat: read audio file from memory --- audio.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/audio.go b/audio.go index 54bd32fdb..a9090b428 100644 --- a/audio.go +++ b/audio.go @@ -19,6 +19,7 @@ const ( type AudioRequest struct { Model string FilePath string + File *multipart.File } // AudioResponse represents a response structure for audio API. @@ -56,6 +57,7 @@ func (c *Client) callAudioAPI( if err = audioMultipartForm(request, w); err != nil { return } + urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) @@ -71,27 +73,33 @@ func (c *Client) callAudioAPI( // audioMultipartForm creates a form with audio file contents and the name of the model to use for // audio processing. func audioMultipartForm(request AudioRequest, w *multipart.Writer) error { - f, err := os.Open(request.FilePath) - if err != nil { - return fmt.Errorf("opening audio file: %w", err) - } - - fw, err := w.CreateFormFile("file", f.Name()) + fw, err := w.CreateFormFile("file", request.FilePath) if err != nil { return fmt.Errorf("creating form file: %w", err) } - if _, err = io.Copy(fw, f); err != nil { - return fmt.Errorf("reading from opened audio file: %w", err) + if request.File == nil { + f, err := os.Open(request.FilePath) + if err != nil { + return fmt.Errorf("opening audio file: %w", err) + } + + if _, err = io.Copy(fw, f); err != nil { + return fmt.Errorf("reading from opened audio file: %w", err) + } + } else { + if _, err = io.Copy(fw, *request.File); err != nil { + return fmt.Errorf("reading from opened audio file: %w", err) + } } fw, err = w.CreateFormField("model") if err != nil { return fmt.Errorf("creating form field: %w", err) } - + modelName := bytes.NewReader([]byte(request.Model)) - if _, err = io.Copy(fw, modelName); err != nil { + if _, err := io.Copy(fw, modelName); err != nil { return fmt.Errorf("writing model name: %w", err) } w.Close() From 2d9bad0f830b837fbfd2d41d5ca7cebd340001c0 Mon Sep 17 00:00:00 2001 From: giordanobsf Date: Mon, 20 Mar 2023 11:51:02 -0300 Subject: [PATCH 3/4] style: remove endlines --- audio.go | 1 - 1 file changed, 1 deletion(-) diff --git a/audio.go b/audio.go index a9090b428..94a7a15f7 100644 --- a/audio.go +++ b/audio.go @@ -57,7 +57,6 @@ func (c *Client) callAudioAPI( if err = audioMultipartForm(request, w); err != nil { return } - urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix) req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), &formBody) From 86263cb667e78f8656d9eb5a084caf554eaf8109 Mon Sep 17 00:00:00 2001 From: giordanobsf Date: Mon, 20 Mar 2023 12:17:45 -0300 Subject: [PATCH 4/4] Merge sashabaranov/master to giordanobsf/master --- README.md | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++- audio.go | 45 ++++++++++++++++++++++++++++++++++++++++ audio_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 57fccea84..09abf2173 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,61 @@ func main() { Other examples: +
+ChatGPT streaming completion + +```go +package main + +import ( + "context" + "errors" + "fmt" + "io" + openai "github.com/sashabaranov/go-openai" +) + +func main() { + c := openai.NewClient("your token") + ctx := context.Background() + + req := openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + MaxTokens: 20, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Lorem ipsum", + }, + }, + Stream: true, + } + stream, err := c.CreateChatCompletionStream(ctx, req) + if err != nil { + fmt.Printf("ChatCompletionStream error: %v\n", err) + return + } + defer stream.Close() + + fmt.Printf("Stream response: ") + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + fmt.Println("\nStream finished") + return + } + + if err != nil { + fmt.Printf("\nStream error: %v\n", err) + return + } + + fmt.Printf(response.Choices[0].Delta.Content) + } +} +``` +
+
GPT-3 completion @@ -327,4 +382,4 @@ func main() { } } ``` -
\ No newline at end of file +
diff --git a/audio.go b/audio.go index 94a7a15f7..cddb1f6ca 100644 --- a/audio.go +++ b/audio.go @@ -16,10 +16,14 @@ const ( ) // AudioRequest represents a request structure for audio API. +// ResponseFormat is not supported for now. We only return JSON text, which may be sufficient. type AudioRequest struct { Model string FilePath string File *multipart.File + Prompt string // For translation, it should be in English + Temperature float32 + Language string // For translation, just do not use it. It seems "en" works, not confirmed... } // AudioResponse represents a response structure for audio API. @@ -101,6 +105,47 @@ func audioMultipartForm(request AudioRequest, w *multipart.Writer) error { if _, err := io.Copy(fw, modelName); err != nil { return fmt.Errorf("writing model name: %w", err) } + + // Create a form field for the prompt (if provided) + if request.Prompt != "" { + fw, err = w.CreateFormField("prompt") + if err != nil { + return fmt.Errorf("creating form field: %w", err) + } + + prompt := bytes.NewReader([]byte(request.Prompt)) + if _, err = io.Copy(fw, prompt); err != nil { + return fmt.Errorf("writing prompt: %w", err) + } + } + + // Create a form field for the temperature (if provided) + if request.Temperature != 0 { + fw, err = w.CreateFormField("temperature") + if err != nil { + return fmt.Errorf("creating form field: %w", err) + } + + temperature := bytes.NewReader([]byte(fmt.Sprintf("%.2f", request.Temperature))) + if _, err = io.Copy(fw, temperature); err != nil { + return fmt.Errorf("writing temperature: %w", err) + } + } + + // Create a form field for the language (if provided) + if request.Language != "" { + fw, err = w.CreateFormField("language") + if err != nil { + return fmt.Errorf("creating form field: %w", err) + } + + language := bytes.NewReader([]byte(request.Language)) + if _, err = io.Copy(fw, language); err != nil { + return fmt.Errorf("writing language: %w", err) + } + } + + // Close the multipart writer w.Close() return nil diff --git a/audio_test.go b/audio_test.go index ed561bbbf..f8936a6ac 100644 --- a/audio_test.go +++ b/audio_test.go @@ -69,6 +69,59 @@ func TestAudio(t *testing.T) { } } +func TestAudioWithOptionalArgs(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint) + server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + + testcases := []struct { + name string + createFn func(context.Context, AudioRequest) (AudioResponse, error) + }{ + { + "transcribe", + client.CreateTranscription, + }, + { + "translate", + client.CreateTranslation, + }, + } + + ctx := context.Background() + + dir, cleanup := createTestDirectory(t) + defer cleanup() + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + path := filepath.Join(dir, "fake.mp3") + createTestFile(t, path) + + req := AudioRequest{ + FilePath: path, + Model: "whisper-3", + Prompt: "用简体中文", + Temperature: 0.5, + Language: "zh", + } + _, err = tc.createFn(ctx, req) + if err != nil { + t.Fatalf("audio API error: %v", err) + } + }) + } +} + // createTestFile creates a fake file with "hello" as the content. func createTestFile(t *testing.T, path string) { file, err := os.Create(path)