From 8baadc813e9c210659f7a83f615616f80034dae8 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Mon, 29 Apr 2024 15:04:59 +0100 Subject: [PATCH 01/31] Add assistants stream --- run.go | 150 +++++++++++++++++++++++++++++++++++++++++++++++ run_test.go | 15 +++++ stream_reader.go | 2 +- 3 files changed, 166 insertions(+), 1 deletion(-) diff --git a/run.go b/run.go index 094b0a4db..687470cb8 100644 --- a/run.go +++ b/run.go @@ -124,6 +124,11 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +type RunRequestStreaming struct { + RunRequest + Stream bool `json:"stream"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } @@ -149,6 +154,11 @@ type CreateThreadAndRunRequest struct { Thread ThreadRequest `json:"thread"` } +type CreateThreadAndStreamRequest struct { + CreateThreadAndRunRequest + Stream bool `json:"stream"` +} + type RunStep struct { ID string `json:"id"` Object string `json:"object"` @@ -337,6 +347,43 @@ func (c *Client) SubmitToolOutputs( return } +type SubmitToolOutputsStreamRequest struct { + SubmitToolOutputsRequest + Stream bool `json:"stream"` +} + +func (c *Client) SubmitToolOutputsStream( + ctx context.Context, + threadID string, + runID string, + request SubmitToolOutputsRequest, +) (stream *AssistantStream, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) + r := SubmitToolOutputsStreamRequest{ + SubmitToolOutputsRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + // CancelRun cancels a run. func (c *Client) CancelRun( ctx context.Context, @@ -375,6 +422,109 @@ func (c *Client) CreateThreadAndRun( return } +type StreamMessageDelta struct { + Role string `json:"role"` + Content []MessageContent `json:"content"` + FileIDs []string `json:"file_ids"` +} + +type AssistantStreamEvent struct { + ID string `json:"id"` + Object string `json:"object"` + Delta StreamMessageDelta `json:"delta,omitempty"` + + // Run + CreatedAt int64 `json:"created_at,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + AssistantID string `json:"assistant_id,omitempty"` + Status RunStatus `json:"status,omitempty"` + RequiredAction *RunRequiredAction `json:"required_action,omitempty"` + LastError *RunLastError `json:"last_error,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + StartedAt *int64 `json:"started_at,omitempty"` + CancelledAt *int64 `json:"cancelled_at,omitempty"` + FailedAt *int64 `json:"failed_at,omitempty"` + CompletedAt *int64 `json:"completed_at,omitempty"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + FileIDS []string `json:"file_ids"` //nolint:revive // backwards-compatibility + Metadata map[string]any `json:"metadata,omitempty"` + Usage Usage `json:"usage,omitempty"` + + // ThreadMessage.Completed + Role string `json:"role,omitempty"` + Content []MessageContent `json:"content,omitempty"` + // IncompleteDetails + // IncompleteAt + + // Run steps + RunID string `json:"run_id"` + Type RunStepType `json:"type"` + StepDetails StepDetails `json:"step_details"` + ExpiredAt *int64 `json:"expired_at,omitempty"` +} + +type AssistantStream struct { + *streamReader[AssistantStreamEvent] +} + +func (c *Client) CreateThreadAndStream(ctx context.Context, request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { + urlSuffix := "/threads/runs" + sr := CreateThreadAndStreamRequest{ + CreateThreadAndRunRequest: request, + Stream: true, + } + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(sr), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + +func (c *Client) CreateRunStreaming(ctx context.Context, threadID string, request RunRequest) (stream *AssistantStream, err error) { + urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) + + r := RunRequestStreaming{ + RunRequest: request, + Stream: true, + } + + req, err := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(urlSuffix), + withBody(r), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + if err != nil { + return + } + + resp, err := sendRequestStream[AssistantStreamEvent](c, req) + if err != nil { + return + } + stream = &AssistantStream{ + streamReader: resp, + } + return +} + // RetrieveRunStep retrieves a run step. func (c *Client) RetrieveRunStep( ctx context.Context, diff --git a/run_test.go b/run_test.go index cdf99db05..f06c1564b 100644 --- a/run_test.go +++ b/run_test.go @@ -219,6 +219,21 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndRun error") + _, err = client.CreateThreadAndStream(ctx, openai.CreateThreadAndRunRequest{ + RunRequest: openai.RunRequest{ + AssistantID: assistantID, + }, + Thread: openai.ThreadRequest{ + Messages: []openai.ThreadMessage{ + { + Role: openai.ThreadMessageRoleUser, + Content: "Hello, World!", + }, + }, + }, + }) + checks.NoError(t, err, "CreateThreadAndStream error") + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) checks.NoError(t, err, "RetrieveRunStep error") diff --git a/stream_reader.go b/stream_reader.go index 4210a1948..433548794 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -16,7 +16,7 @@ var ( ) type streamable interface { - ChatCompletionStreamResponse | CompletionResponse + ChatCompletionStreamResponse | CompletionResponse | AssistantStreamEvent } type streamReader[T streamable] struct { From 1870579b6933ea9966bed7a1952eb70ebc31c456 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Wed, 8 May 2024 10:01:35 +0100 Subject: [PATCH 02/31] Lint --- run.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/run.go b/run.go index 687470cb8..0db2ec1d2 100644 --- a/run.go +++ b/run.go @@ -469,7 +469,9 @@ type AssistantStream struct { *streamReader[AssistantStreamEvent] } -func (c *Client) CreateThreadAndStream(ctx context.Context, request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { +func (c *Client) CreateThreadAndStream( + ctx context.Context, + request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { urlSuffix := "/threads/runs" sr := CreateThreadAndStreamRequest{ CreateThreadAndRunRequest: request, @@ -496,7 +498,10 @@ func (c *Client) CreateThreadAndStream(ctx context.Context, request CreateThread return } -func (c *Client) CreateRunStreaming(ctx context.Context, threadID string, request RunRequest) (stream *AssistantStream, err error) { +func (c *Client) CreateRunStreaming( + ctx context.Context, + threadID string, + request RunRequest) (stream *AssistantStream, err error) { urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) r := RunRequestStreaming{ From 80aedac6610cb8600963a0b0900da85e73630ed0 Mon Sep 17 00:00:00 2001 From: coolbaluk Date: Mon, 13 May 2024 10:05:48 +0100 Subject: [PATCH 03/31] Add basic tests --- run_test.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/run_test.go b/run_test.go index f06c1564b..f3445852e 100644 --- a/run_test.go +++ b/run_test.go @@ -234,6 +234,16 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndStream error") + _, err = client.CreateRunStreaming(ctx, threadID, openai.RunRequest{ + AssistantID: assistantID, + }) + checks.NoError(t, err, "CreateRunStreaming error") + + _, err = client.SubmitToolOutputsStream(ctx, threadID, runID, openai.SubmitToolOutputsRequest{ + ToolOutputs: nil, + }) + checks.NoError(t, err, "SubmitToolOutputsStream error") + _, err = client.RetrieveRunStep(ctx, threadID, runID, stepID) checks.NoError(t, err, "RetrieveRunStep error") From db0e71ba688df19cfd00c618797b4eaff64fcc72 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Thu, 16 May 2024 09:56:43 +0800 Subject: [PATCH 04/31] SSE EOL Reader --- sse_reader_test.go | 180 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 sse_reader_test.go diff --git a/sse_reader_test.go b/sse_reader_test.go new file mode 100644 index 000000000..b643be043 --- /dev/null +++ b/sse_reader_test.go @@ -0,0 +1,180 @@ +package openai + +import ( + "bufio" + "io" + "strings" + "testing" +) + +// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance +func NewEOLSplitterFunc() bufio.SplitFunc { + splitter := NewEOLSplitter() + return splitter.Split +} + +// EOLSplitter is the custom split function to handle CR LF, CR, and LF as end-of-line. +type EOLSplitter struct { + prevCR bool +} + +// NewEOLSplitter creates a new EOLSplitter instance. +func NewEOLSplitter() *EOLSplitter { + return &EOLSplitter{prevCR: false} +} + +// Split function to handle CR LF, CR, and LF as end-of-line. +func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Check if the previous data ended with a CR + if s.prevCR { + s.prevCR = false + if len(data) > 0 && data[0] == '\n' { + return 1, nil, nil // Skip the LF following the previous CR + } + } + + // Search for the first occurrence of CR LF, CR, or LF + for i := 0; i < len(data); i++ { + if data[i] == '\r' { + if i+1 < len(data) && data[i+1] == '\n' { + // Found CR LF + return i + 2, data[:i], nil + } + // Found CR + if !atEOF && i == len(data)-1 { + // If CR is the last byte, and not EOF, then need to check if + // the next byte is LF. + // + // save the state and request more data + s.prevCR = true + return 0, nil, nil + } + return i + 1, data[:i], nil + } + if data[i] == '\n' { + // Found LF + return i + 1, data[:i], nil + } + } + + // If at EOF, we have a final, non-terminated line. Return it. + if atEOF && len(data) > 0 { + return len(data), data, nil + } + + // Request more data. + return 0, nil, nil +} + +// CustomReader simulates a reader that splits the input across multiple reads. +type CustomReader struct { + chunks []string + index int +} + +func NewChunksReader(chunks []string) *CustomReader { + return &CustomReader{ + chunks: chunks, + } +} + +func (r *CustomReader) Read(p []byte) (n int, err error) { + if r.index >= len(r.chunks) { + return 0, io.EOF + } + n = copy(p, r.chunks[r.index]) + r.index++ + return n, nil +} + +// TestEolSplitter tests the custom EOL splitter function. +func TestEolSplitter(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + {"CRLF", "Line1\r\nLine2\r\nLine3\r\n", []string{"Line1", "Line2", "Line3"}}, + {"CR", "Line1\rLine2\rLine3\r", []string{"Line1", "Line2", "Line3"}}, + {"LF", "Line1\nLine2\nLine3\n", []string{"Line1", "Line2", "Line3"}}, + {"Mixed", "Line1\r\nLine2\rLine3\nLine4\r\nLine5", []string{"Line1", "Line2", "Line3", "Line4", "Line5"}}, + {"SingleLineNoEOL", "Line1", []string{"Line1"}}, + {"SingleLineLF", "Line1\n", []string{"Line1"}}, + {"SingleLineCR", "Line1\r", []string{"Line1"}}, + {"SingleLineCRLF", "Line1\r\n", []string{"Line1"}}, + {"DoubleNewLines", "Line1\n\nLine2", []string{"Line1", "", "Line2"}}, + {"lflf", "\n\n", []string{"", ""}}, + {"crlfcrlf", "\r\n\r\n", []string{"", ""}}, + {"crcr", "\r\r", []string{"", ""}}, + {"mixed eol: crlf cr lf", "A\r\nB\rC\nD", []string{"A", "B", "C", "D"}}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + reader := strings.NewReader(test.input) + scanner := bufio.NewScanner(reader) + scanner.Split(NewEOLSplitterFunc()) + + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(lines) != len(test.expected) { + t.Errorf("Expected %d lines, got %d", len(test.expected), len(lines)) + t.Errorf("Expected: %v, got: %v", test.expected, lines) + } + + for i := range lines { + if lines[i] != test.expected[i] { + t.Errorf("Expected line %d to be %q, got %q", i, test.expected[i], lines[i]) + } + } + }) + } +} + +// TestEolSplitterBoundaryCondition tests the boundary condition where CR LF is split across two slices. +func TestEolSplitterBoundaryCondition(t *testing.T) { + // Additional cases + cases := []struct { + input []string + expected []string + }{ + {[]string{"Line1\r", "\nLine2"}, []string{"Line1", "Line2"}}, + {[]string{"Line1\r", "\nLine2\r"}, []string{"Line1", "Line2"}}, + {[]string{"Line1\r", "\nLine2\r\n"}, []string{"Line1", "Line2"}}, + {[]string{"Line1\r", "\nLine2\r", "Line3"}, []string{"Line1", "Line2", "Line3"}}, + {[]string{"Line1\r", "\nLine2\r", "\nLine3\r\n"}, []string{"Line1", "Line2", "Line3"}}, + } + for _, c := range cases { + // Custom reader to simulate the boundary condition + reader := NewChunksReader(c.input) + scanner := bufio.NewScanner(reader) + scanner.Split(NewEOLSplitterFunc()) + + var lines []string + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + + if err := scanner.Err(); err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if len(lines) != len(c.expected) { + t.Errorf("Expected %d lines, got %d", len(c.expected), len(lines)) + continue + } + + for i := range lines { + if lines[i] != c.expected[i] { + t.Errorf("Expected line %d to be %q, got %q", i, c.expected[i], lines[i]) + } + } + } +} From 2c5ace849416d04c5896482ba308ea13b3de5a04 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Thu, 16 May 2024 10:20:58 +0800 Subject: [PATCH 05/31] sse parser --- sse_parser_test.go | 257 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) create mode 100644 sse_parser_test.go diff --git a/sse_parser_test.go b/sse_parser_test.go new file mode 100644 index 000000000..03a5b008b --- /dev/null +++ b/sse_parser_test.go @@ -0,0 +1,257 @@ +package openai + +import ( + "bufio" + "io" + "reflect" + "strconv" + "strings" + "testing" +) + +type ServerSentEvent struct { + ID string // ID of the event + Data string // Data of the event + Event string // Type of the event + Retry int // Retry time in milliseconds + Comment string // Comment +} + +type SSEScanner struct { + scanner *bufio.Scanner + event *ServerSentEvent + err error + readComment bool +} + +func NewSSEScanner(r io.Reader, readComment bool) *SSEScanner { + scanner := bufio.NewScanner(r) + + // N.B. The bufio.ScanLines handles `\r?\n``, but not `\r` itself as EOL, as + // the SSE spec requires + // + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream + // + // scanner.Split(bufio.ScanLines) + scanner.Split(NewEOLSplitterFunc()) + + return &SSEScanner{ + scanner: scanner, + readComment: readComment, + } +} + +func (s *SSEScanner) Next() bool { + s.event = nil + + var event ServerSentEvent + var dataLines []string + + var seenNonEmptyLine bool + + for s.scanner.Scan() { + line := strings.TrimSpace(s.scanner.Text()) + + if line == "" { + if seenNonEmptyLine { + break + } + + continue + } + + seenNonEmptyLine = true + + if strings.HasPrefix(line, "id: ") { + event.ID = strings.TrimPrefix(line, "id: ") + } else if strings.HasPrefix(line, "data: ") { + dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) + } else if strings.HasPrefix(line, "event: ") { + event.Event = strings.TrimPrefix(line, "event: ") + } else if strings.HasPrefix(line, "retry: ") { + retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) + if err == nil { + event.Retry = retry + } + + // ignore invalid retry values + } else if strings.HasPrefix(line, ":") { + if s.readComment { + event.Comment = strings.TrimPrefix(line, ":") + } + + // ignore comment line + } + + // ignore unknown lines + } + + s.err = s.scanner.Err() + + if !seenNonEmptyLine { + return false + } + + event.Data = strings.Join(dataLines, "\n") + s.event = &event + + return true +} + +func (s *SSEScanner) Scan() *ServerSentEvent { + return s.event +} + +func (s *SSEScanner) Err() error { + return s.err +} + +func TestSSEScanner(t *testing.T) { + tests := []struct { + raw string + want []ServerSentEvent + }{ + { + raw: `data: hello world`, + want: []ServerSentEvent{ + { + Data: "hello world", + }, + }, + }, + { + raw: `event: hello +data: hello world`, + want: []ServerSentEvent{ + { + Event: "hello", + Data: "hello world", + }, + }, + }, + { + raw: `event: hello-json +data: { +data: "msg": "hello world", +data: "id": 12345 +data: }`, + want: []ServerSentEvent{ + { + Event: "hello-json", + Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}", + }, + }, + }, + { + raw: `data: hello world + +data: hello again`, + want: []ServerSentEvent{ + { + Data: "hello world", + }, + { + Data: "hello again", + }, + }, + }, + { + raw: `retry: 10000 + data: hello world`, + want: []ServerSentEvent{ + { + Retry: 10000, + Data: "hello world", + }, + }, + }, + { + raw: `retry: 10000 + +retry: 20000`, + want: []ServerSentEvent{ + { + Retry: 10000, + }, + { + Retry: 20000, + }, + }, + }, + { + raw: `: comment 1 +: comment 2 +id: message-id +retry: 20000 +event: hello-event +data: hello`, + want: []ServerSentEvent{ + { + ID: "message-id", + Retry: 20000, + Event: "hello-event", + Data: "hello", + }, + }, + }, + { + raw: `: comment 1 +id: message 1 +data: hello 1 +retry: 10000 +event: hello-event 1 + +: comment 2 +data: hello 2 +id: message 2 +retry: 20000 +event: hello-event 2 +`, + want: []ServerSentEvent{ + { + ID: "message 1", + Retry: 10000, + Event: "hello-event 1", + Data: "hello 1", + }, + { + ID: "message 2", + Retry: 20000, + Event: "hello-event 2", + Data: "hello 2", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.raw, func(t *testing.T) { + rawWithCRLF := strings.ReplaceAll(tt.raw, "\n", "\r\n") + runSSEScanTest(t, rawWithCRLF, tt.want) + + // Test with "\r" EOL + rawWithCR := strings.ReplaceAll(tt.raw, "\n", "\r") + runSSEScanTest(t, rawWithCR, tt.want) + + // Test with "\n" EOL (original) + runSSEScanTest(t, tt.raw, tt.want) + }) + } +} + +func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) { + sseScanner := NewSSEScanner(strings.NewReader(raw), false) + + var got []ServerSentEvent + for sseScanner.Next() { + got = append(got, *sseScanner.Scan()) + } + + if err := sseScanner.Err(); err != nil { + t.Errorf("SSEScanner error: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("SSEScanner() = %v, want %v", got, want) + } +} From 7701aef3e100e569ebf46b614bdf45f67eafa962 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Thu, 16 May 2024 22:26:35 +0800 Subject: [PATCH 06/31] partial work --- run.go | 31 ++++++- run_stream_test.go | 200 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 4 deletions(-) create mode 100644 run_stream_test.go diff --git a/run.go b/run.go index 0db2ec1d2..b57f80d5c 100644 --- a/run.go +++ b/run.go @@ -3,8 +3,10 @@ package openai import ( "context" "fmt" + "io" "net/http" "net/url" + "os" ) type Run struct { @@ -477,6 +479,7 @@ func (c *Client) CreateThreadAndStream( CreateThreadAndRunRequest: request, Stream: true, } + req, err := c.newRequest( ctx, http.MethodPost, @@ -488,14 +491,34 @@ func (c *Client) CreateThreadAndStream( return } - resp, err := sendRequestStream[AssistantStreamEvent](c, req) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + // resp, err := sendRequestStream[AssistantStreamEvent](c, req) if err != nil { return } - stream = &AssistantStream{ - streamReader: resp, + defer resp.Body.Close() + + outf, err := os.Create("thread.run.stream") + if err != nil { + return nil, err } - return + defer outf.Close() + + r := io.TeeReader(resp.Body, outf) + + _, err = io.Copy(os.Stdout, r) + + // ChatCompletionStreamChoiceDelta + + // stream = &AssistantStream{ + // streamReader: resp, + // } + return nil, err } func (c *Client) CreateRunStreaming( diff --git a/run_stream_test.go b/run_stream_test.go new file mode 100644 index 000000000..7d6f3bf78 --- /dev/null +++ b/run_stream_test.go @@ -0,0 +1,200 @@ +package openai + +import ( + "encoding/json" + "io" + "reflect" + "strings" + "testing" +) + +type StreamRawEvent struct { + Type string + Data json.RawMessage +} + +type StreamDone struct { + Data string // [DONE] +} + +// Define StreamThreadMessageDelta +type StreamThreadMessageDelta struct { + ID string `json:"id"` + Object string `json:"object"` + Delta Delta `json:"delta"` +} + +type Delta struct { + // DeltaText | DeltaImageFile + Content []DeltaContent `json:"content"` +} + +type DeltaContent struct { + Index int `json:"index"` + Type string `json:"type"` + + Text *DeltaText `json:"text"` + ImageFile *DeltaImageFile `json:"image_file"` + ImageURL *DeltaImageURL `json:"image_url"` +} + +type DeltaText struct { + Value string `json:"value"` + // Annotations []any `json:"annotations"` +} + +type DeltaImageFile struct { + FileID string `json:"file_id"` + Detail string `json:"detail"` +} + +type DeltaImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + +type StreamScannerV2 struct { + scanner *SSEScanner + next any +} + +func NewStreamScannerV2(r io.Reader) *StreamScannerV2 { + return &StreamScannerV2{ + scanner: NewSSEScanner(r, false), + } +} + +func (s *StreamScannerV2) Next() bool { + if s.scanner.Next() { + event := s.scanner.Scan() + if event != nil { + switch event.Event { + case "thread.message.delta": + var delta StreamThreadMessageDelta + if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { + s.next = delta + return true + } + case "done": + s.next = StreamDone{Data: "DONE"} + return true + default: + s.next = StreamRawEvent{Data: json.RawMessage(event.Data)} + } + } + } + return false +} + +func (s *StreamScannerV2) Event() any { + return s.next +} + +func (s *StreamScannerV2) Err() error { + return s.scanner.Err() +} + +func TestStreamScannerV2(t *testing.T) { + raw := ` +event: thread.message.delta +data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}} + +event: done +data: [DONE] +` + scanner := NewStreamScannerV2(strings.NewReader(raw)) + var events []any + + for scanner.Next() { + event := scanner.Event() + events = append(events, event) + } + + expectedValues := []any{ + StreamThreadMessageDelta{ + ID: "msg_KFiZxHhXYQo6cGFnGjRDHSee", + Object: "thread.message.delta", + Delta: Delta{ + Content: []DeltaContent{ + { + Index: 0, + Type: "text", + Text: &DeltaText{ + Value: "hello", + }, + }, + }, + }, + }, + StreamDone{Data: "DONE"}, + } + + if len(events) != len(expectedValues) { + t.Fatalf("Expected %d events but got %d", len(expectedValues), len(events)) + } + + for i, event := range events { + expectedValue := expectedValues[i] + if !reflect.DeepEqual(event, expectedValue) { + t.Errorf("Expected %v but got %v", expectedValue, event) + } + } +} + +func TestStreamThreadMessageDeltaJSON(t *testing.T) { + tests := []struct { + name string + jsonData string + expectType string + expectValue interface{} + }{ + { + name: "DeltaContent with Text", + jsonData: `{"index":0,"type":"text","text":{"value":"hello"}}`, + expectType: "text", + expectValue: &DeltaText{Value: "hello"}, + }, + { + name: "DeltaContent with ImageFile", + jsonData: `{"index":1,"type":"image_file","image_file":{"file_id":"file123","detail":"An image"}}`, + expectType: "image_file", + expectValue: &DeltaImageFile{FileID: "file123", Detail: "An image"}, + }, + { + name: "DeltaContent with ImageURL", + jsonData: `{"index":2,"type":"image_url","image_url":{"url":"https://example.com/image.jpg","detail":"low"}}`, + expectType: "image_url", + expectValue: &DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var content DeltaContent + err := json.Unmarshal([]byte(tt.jsonData), &content) + if err != nil { + t.Fatalf("Error unmarshalling JSON: %v", err) + } + + if content.Type != tt.expectType { + t.Errorf("Expected Type to be '%s', got %s", tt.expectType, content.Type) + } + + var actualValue interface{} + switch tt.expectType { + case "text": + actualValue = content.Text + case "image_file": + actualValue = content.ImageFile + case "image_url": + actualValue = content.ImageURL + default: + t.Fatalf("Unexpected type: %s", tt.expectType) + } + + if !reflect.DeepEqual(actualValue, tt.expectValue) { + t.Errorf("Expected value to be '%v', got %v", tt.expectValue, actualValue) + } + }) + } +} From bc1b4db808fa35ed9c7f3399389e5b664f5cc5e3 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Thu, 16 May 2024 22:41:35 +0800 Subject: [PATCH 07/31] implement text reader --- run_stream_test.go | 109 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 100 insertions(+), 9 deletions(-) diff --git a/run_stream_test.go b/run_stream_test.go index 7d6f3bf78..7f93179fb 100644 --- a/run_stream_test.go +++ b/run_stream_test.go @@ -53,18 +53,62 @@ type DeltaImageURL struct { Detail string `json:"detail"` } -type StreamScannerV2 struct { - scanner *SSEScanner - next any +// StreamTextReader wraps StreamerV2 to implement io.Reader. +type StreamTextReader struct { + streamer *StreamerV2 + buffer []byte +} + +// NewStreamTextReader initializes and returns a new StreamTextReader. +func NewStreamTextReader(streamer *StreamerV2) *StreamTextReader { + return &StreamTextReader{ + streamer: streamer, + } +} + +// Read implements the io.Reader interface. +func (r *StreamTextReader) Read(p []byte) (int, error) { + // If we have data in the buffer, copy it to p first. + if len(r.buffer) > 0 { + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + + for r.streamer.Next() { + event := r.streamer.Event() + switch e := event.(type) { + case StreamThreadMessageDelta: + // Check if the event contains text content. + for _, content := range e.Delta.Content { + if content.Text != nil { + r.buffer = []byte(content.Text.Value) + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + } + case StreamDone: + return 0, io.EOF + } + } + + // If we reach here, there are no more events. + return 0, io.EOF } -func NewStreamScannerV2(r io.Reader) *StreamScannerV2 { - return &StreamScannerV2{ +func NewStreamerV2(r io.Reader) *StreamerV2 { + return &StreamerV2{ scanner: NewSSEScanner(r, false), } } -func (s *StreamScannerV2) Next() bool { +type StreamerV2 struct { + scanner *SSEScanner + next any +} + +func (s *StreamerV2) Next() bool { if s.scanner.Next() { event := s.scanner.Scan() if event != nil { @@ -86,14 +130,61 @@ func (s *StreamScannerV2) Next() bool { return false } -func (s *StreamScannerV2) Event() any { +func (s *StreamerV2) Event() any { return s.next } -func (s *StreamScannerV2) Err() error { +func (s *StreamerV2) Err() error { return s.scanner.Err() } +func TestNewStreamTextReader(t *testing.T) { + raw := ` +event: thread.message.delta +data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}} + +event: thread.message.delta +data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"world"}}]}} + +event: done +data: [DONE] +` + scanner := NewStreamerV2(strings.NewReader(raw)) + reader := NewStreamTextReader(scanner) + + expected := "helloworld" + buffer := make([]byte, len(expected)) + n, err := reader.Read(buffer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != len("hello") { + t.Fatalf("expected to read %d bytes, read %d bytes", len("hello"), n) + } + if string(buffer[:n]) != "hello" { + t.Fatalf("expected %q, got %q", "hello", string(buffer[:n])) + } + + n, err = reader.Read(buffer[n:]) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if n != len("world") { + t.Fatalf("expected to read %d bytes, read %d bytes", len("world"), n) + } + if string(buffer[:len(expected)]) != expected { + t.Fatalf("expected %q, got %q", expected, string(buffer[:len(expected)])) + } + + n, err = reader.Read(buffer) + if err != io.EOF { + t.Fatalf("expected io.EOF, got %v", err) + } + if n != 0 { + t.Fatalf("expected to read 0 bytes, read %d bytes", n) + } +} + func TestStreamScannerV2(t *testing.T) { raw := ` event: thread.message.delta @@ -102,7 +193,7 @@ data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delt event: done data: [DONE] ` - scanner := NewStreamScannerV2(strings.NewReader(raw)) + scanner := NewStreamerV2(strings.NewReader(raw)) var events []any for scanner.Next() { From 58b951b9d2bad911a2f1538b8212daed3c47bf49 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Fri, 17 May 2024 10:42:46 +0800 Subject: [PATCH 08/31] add convenience wrappers around stream events --- run_stream_test.go | 67 ++++++++++++++++++++++++++++++---------------- 1 file changed, 44 insertions(+), 23 deletions(-) diff --git a/run_stream_test.go b/run_stream_test.go index 7f93179fb..6505ea508 100644 --- a/run_stream_test.go +++ b/run_stream_test.go @@ -53,21 +53,21 @@ type DeltaImageURL struct { Detail string `json:"detail"` } -// StreamTextReader wraps StreamerV2 to implement io.Reader. -type StreamTextReader struct { +// streamTextReader wraps StreamerV2 to implement io.Reader. +type streamTextReader struct { streamer *StreamerV2 buffer []byte } -// NewStreamTextReader initializes and returns a new StreamTextReader. -func NewStreamTextReader(streamer *StreamerV2) *StreamTextReader { - return &StreamTextReader{ +// newStreamTextReader initializes and returns a new StreamTextReader. +func newStreamTextReader(streamer *StreamerV2) *streamTextReader { + return &streamTextReader{ streamer: streamer, } } // Read implements the io.Reader interface. -func (r *StreamTextReader) Read(p []byte) (int, error) { +func (r *streamTextReader) Read(p []byte) (int, error) { // If we have data in the buffer, copy it to p first. if len(r.buffer) > 0 { n := copy(p, r.buffer) @@ -76,24 +76,23 @@ func (r *StreamTextReader) Read(p []byte) (int, error) { } for r.streamer.Next() { - event := r.streamer.Event() - switch e := event.(type) { - case StreamThreadMessageDelta: - // Check if the event contains text content. - for _, content := range e.Delta.Content { - if content.Text != nil { - r.buffer = []byte(content.Text.Value) - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] - return n, nil - } - } - case StreamDone: - return 0, io.EOF + // Read only text deltas + text, ok := r.streamer.MessageDeltaText() + if !ok { + continue } + + r.buffer = []byte(text) + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + + // Check for streamer error + if err := r.streamer.Err(); err != nil { + return 0, err } - // If we reach here, there are no more events. return 0, io.EOF } @@ -130,10 +129,33 @@ func (s *StreamerV2) Next() bool { return false } +// Reader returns io.Reader that reads only text deltas from the stream +func (s *StreamerV2) Reader() io.Reader { + return newStreamTextReader(s) +} + func (s *StreamerV2) Event() any { return s.next } +func (s *StreamerV2) MessageDeltaText() (string, bool) { + event, ok := s.next.(StreamThreadMessageDelta) + if !ok { + return "", false + } + + var text string + for _, content := range event.Delta.Content { + if content.Text != nil { + // Can we return the first text we find? Does OpenAI stream ever + // return multiple text contents in a delta? + text = text + content.Text.Value + } + } + + return text, true +} + func (s *StreamerV2) Err() error { return s.scanner.Err() } @@ -149,8 +171,7 @@ data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delt event: done data: [DONE] ` - scanner := NewStreamerV2(strings.NewReader(raw)) - reader := NewStreamTextReader(scanner) + reader := NewStreamerV2(strings.NewReader(raw)).Reader() expected := "helloworld" buffer := make([]byte, len(expected)) From 223840c15161ae7406ef24c28b5f362afdb03b9e Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Fri, 17 May 2024 10:54:30 +0800 Subject: [PATCH 09/31] comments --- run_stream_test.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/run_stream_test.go b/run_stream_test.go index 6505ea508..c81f32f07 100644 --- a/run_stream_test.go +++ b/run_stream_test.go @@ -53,20 +53,18 @@ type DeltaImageURL struct { Detail string `json:"detail"` } -// streamTextReader wraps StreamerV2 to implement io.Reader. +// streamTextReader is an io.Reader of the text deltas of thread.message.delta events type streamTextReader struct { streamer *StreamerV2 buffer []byte } -// newStreamTextReader initializes and returns a new StreamTextReader. -func newStreamTextReader(streamer *StreamerV2) *streamTextReader { +func newStreamTextReader(streamer *StreamerV2) io.Reader { return &streamTextReader{ streamer: streamer, } } -// Read implements the io.Reader interface. func (r *streamTextReader) Read(p []byte) (int, error) { // If we have data in the buffer, copy it to p first. if len(r.buffer) > 0 { @@ -129,7 +127,7 @@ func (s *StreamerV2) Next() bool { return false } -// Reader returns io.Reader that reads only text deltas from the stream +// Reader returns io.Reader of the text deltas of thread.message.delta events func (s *StreamerV2) Reader() io.Reader { return newStreamTextReader(s) } @@ -138,6 +136,12 @@ func (s *StreamerV2) Event() any { return s.next } +// Text returns text delta if the current event is a "thread.message.delta". Alias of MessageDeltaText. +func (s *StreamerV2) Text() (string, bool) { + return s.MessageDeltaText() +} + +// MessageDeltaText returns text delta if the current event is a "thread.message.delta" func (s *StreamerV2) MessageDeltaText() (string, bool) { event, ok := s.next.(StreamThreadMessageDelta) if !ok { From c491cf7cc712eac6a8c9d2a5c0278fefdbb24ca2 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 11:24:48 +0800 Subject: [PATCH 10/31] reorganize files --- sse.go | 164 +++++++++++++++ sse_parser_test.go | 257 ------------------------ sse_reader_test.go => sse_test.go | 220 ++++++++++++++------ stream_v2.go | 162 +++++++++++++++ run_stream_test.go => stream_v2_test.go | 156 -------------- 5 files changed, 482 insertions(+), 477 deletions(-) create mode 100644 sse.go delete mode 100644 sse_parser_test.go rename sse_reader_test.go => sse_test.go (54%) create mode 100644 stream_v2.go rename run_stream_test.go => stream_v2_test.go (55%) diff --git a/sse.go b/sse.go new file mode 100644 index 000000000..8ed15470a --- /dev/null +++ b/sse.go @@ -0,0 +1,164 @@ +package openai + +import ( + "bufio" + "io" + "strconv" + "strings" +) + +// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance +func NewEOLSplitterFunc() bufio.SplitFunc { + splitter := NewEOLSplitter() + return splitter.Split +} + +// EOLSplitter is the custom split function to handle CR LF, CR, and LF as end-of-line. +type EOLSplitter struct { + prevCR bool +} + +// NewEOLSplitter creates a new EOLSplitter instance. +func NewEOLSplitter() *EOLSplitter { + return &EOLSplitter{prevCR: false} +} + +// Split function to handle CR LF, CR, and LF as end-of-line. +func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { + // Check if the previous data ended with a CR + if s.prevCR { + s.prevCR = false + if len(data) > 0 && data[0] == '\n' { + return 1, nil, nil // Skip the LF following the previous CR + } + } + + // Search for the first occurrence of CR LF, CR, or LF + for i := 0; i < len(data); i++ { + if data[i] == '\r' { + if i+1 < len(data) && data[i+1] == '\n' { + // Found CR LF + return i + 2, data[:i], nil + } + // Found CR + if !atEOF && i == len(data)-1 { + // If CR is the last byte, and not EOF, then need to check if + // the next byte is LF. + // + // save the state and request more data + s.prevCR = true + return 0, nil, nil + } + return i + 1, data[:i], nil + } + if data[i] == '\n' { + // Found LF + return i + 1, data[:i], nil + } + } + + // If at EOF, we have a final, non-terminated line. Return it. + if atEOF && len(data) > 0 { + return len(data), data, nil + } + + // Request more data. + return 0, nil, nil +} + +type ServerSentEvent struct { + ID string // ID of the event + Data string // Data of the event + Event string // Type of the event + Retry int // Retry time in milliseconds + Comment string // Comment +} + +type SSEScanner struct { + scanner *bufio.Scanner + event *ServerSentEvent + err error + readComment bool +} + +func NewSSEScanner(r io.Reader, readComment bool) *SSEScanner { + scanner := bufio.NewScanner(r) + + // N.B. The bufio.ScanLines handles `\r?\n``, but not `\r` itself as EOL, as + // the SSE spec requires + // + // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream + // + // scanner.Split(bufio.ScanLines) + scanner.Split(NewEOLSplitterFunc()) + + return &SSEScanner{ + scanner: scanner, + readComment: readComment, + } +} + +func (s *SSEScanner) Next() bool { + s.event = nil + + var event ServerSentEvent + var dataLines []string + + var seenNonEmptyLine bool + + for s.scanner.Scan() { + line := strings.TrimSpace(s.scanner.Text()) + + if line == "" { + if seenNonEmptyLine { + break + } + + continue + } + + seenNonEmptyLine = true + + if strings.HasPrefix(line, "id: ") { + event.ID = strings.TrimPrefix(line, "id: ") + } else if strings.HasPrefix(line, "data: ") { + dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) + } else if strings.HasPrefix(line, "event: ") { + event.Event = strings.TrimPrefix(line, "event: ") + } else if strings.HasPrefix(line, "retry: ") { + retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) + if err == nil { + event.Retry = retry + } + + // ignore invalid retry values + } else if strings.HasPrefix(line, ":") { + if s.readComment { + event.Comment = strings.TrimPrefix(line, ":") + } + + // ignore comment line + } + + // ignore unknown lines + } + + s.err = s.scanner.Err() + + if !seenNonEmptyLine { + return false + } + + event.Data = strings.Join(dataLines, "\n") + s.event = &event + + return true +} + +func (s *SSEScanner) Scan() *ServerSentEvent { + return s.event +} + +func (s *SSEScanner) Err() error { + return s.err +} diff --git a/sse_parser_test.go b/sse_parser_test.go deleted file mode 100644 index 03a5b008b..000000000 --- a/sse_parser_test.go +++ /dev/null @@ -1,257 +0,0 @@ -package openai - -import ( - "bufio" - "io" - "reflect" - "strconv" - "strings" - "testing" -) - -type ServerSentEvent struct { - ID string // ID of the event - Data string // Data of the event - Event string // Type of the event - Retry int // Retry time in milliseconds - Comment string // Comment -} - -type SSEScanner struct { - scanner *bufio.Scanner - event *ServerSentEvent - err error - readComment bool -} - -func NewSSEScanner(r io.Reader, readComment bool) *SSEScanner { - scanner := bufio.NewScanner(r) - - // N.B. The bufio.ScanLines handles `\r?\n``, but not `\r` itself as EOL, as - // the SSE spec requires - // - // See: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream - // - // scanner.Split(bufio.ScanLines) - scanner.Split(NewEOLSplitterFunc()) - - return &SSEScanner{ - scanner: scanner, - readComment: readComment, - } -} - -func (s *SSEScanner) Next() bool { - s.event = nil - - var event ServerSentEvent - var dataLines []string - - var seenNonEmptyLine bool - - for s.scanner.Scan() { - line := strings.TrimSpace(s.scanner.Text()) - - if line == "" { - if seenNonEmptyLine { - break - } - - continue - } - - seenNonEmptyLine = true - - if strings.HasPrefix(line, "id: ") { - event.ID = strings.TrimPrefix(line, "id: ") - } else if strings.HasPrefix(line, "data: ") { - dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) - } else if strings.HasPrefix(line, "event: ") { - event.Event = strings.TrimPrefix(line, "event: ") - } else if strings.HasPrefix(line, "retry: ") { - retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) - if err == nil { - event.Retry = retry - } - - // ignore invalid retry values - } else if strings.HasPrefix(line, ":") { - if s.readComment { - event.Comment = strings.TrimPrefix(line, ":") - } - - // ignore comment line - } - - // ignore unknown lines - } - - s.err = s.scanner.Err() - - if !seenNonEmptyLine { - return false - } - - event.Data = strings.Join(dataLines, "\n") - s.event = &event - - return true -} - -func (s *SSEScanner) Scan() *ServerSentEvent { - return s.event -} - -func (s *SSEScanner) Err() error { - return s.err -} - -func TestSSEScanner(t *testing.T) { - tests := []struct { - raw string - want []ServerSentEvent - }{ - { - raw: `data: hello world`, - want: []ServerSentEvent{ - { - Data: "hello world", - }, - }, - }, - { - raw: `event: hello -data: hello world`, - want: []ServerSentEvent{ - { - Event: "hello", - Data: "hello world", - }, - }, - }, - { - raw: `event: hello-json -data: { -data: "msg": "hello world", -data: "id": 12345 -data: }`, - want: []ServerSentEvent{ - { - Event: "hello-json", - Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}", - }, - }, - }, - { - raw: `data: hello world - -data: hello again`, - want: []ServerSentEvent{ - { - Data: "hello world", - }, - { - Data: "hello again", - }, - }, - }, - { - raw: `retry: 10000 - data: hello world`, - want: []ServerSentEvent{ - { - Retry: 10000, - Data: "hello world", - }, - }, - }, - { - raw: `retry: 10000 - -retry: 20000`, - want: []ServerSentEvent{ - { - Retry: 10000, - }, - { - Retry: 20000, - }, - }, - }, - { - raw: `: comment 1 -: comment 2 -id: message-id -retry: 20000 -event: hello-event -data: hello`, - want: []ServerSentEvent{ - { - ID: "message-id", - Retry: 20000, - Event: "hello-event", - Data: "hello", - }, - }, - }, - { - raw: `: comment 1 -id: message 1 -data: hello 1 -retry: 10000 -event: hello-event 1 - -: comment 2 -data: hello 2 -id: message 2 -retry: 20000 -event: hello-event 2 -`, - want: []ServerSentEvent{ - { - ID: "message 1", - Retry: 10000, - Event: "hello-event 1", - Data: "hello 1", - }, - { - ID: "message 2", - Retry: 20000, - Event: "hello-event 2", - Data: "hello 2", - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.raw, func(t *testing.T) { - rawWithCRLF := strings.ReplaceAll(tt.raw, "\n", "\r\n") - runSSEScanTest(t, rawWithCRLF, tt.want) - - // Test with "\r" EOL - rawWithCR := strings.ReplaceAll(tt.raw, "\n", "\r") - runSSEScanTest(t, rawWithCR, tt.want) - - // Test with "\n" EOL (original) - runSSEScanTest(t, tt.raw, tt.want) - }) - } -} - -func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) { - sseScanner := NewSSEScanner(strings.NewReader(raw), false) - - var got []ServerSentEvent - for sseScanner.Next() { - got = append(got, *sseScanner.Scan()) - } - - if err := sseScanner.Err(); err != nil { - t.Errorf("SSEScanner error: %v", err) - } - - if !reflect.DeepEqual(got, want) { - t.Errorf("SSEScanner() = %v, want %v", got, want) - } -} diff --git a/sse_reader_test.go b/sse_test.go similarity index 54% rename from sse_reader_test.go rename to sse_test.go index b643be043..9c81fb143 100644 --- a/sse_reader_test.go +++ b/sse_test.go @@ -3,82 +3,24 @@ package openai import ( "bufio" "io" + "reflect" "strings" "testing" ) -// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance -func NewEOLSplitterFunc() bufio.SplitFunc { - splitter := NewEOLSplitter() - return splitter.Split -} - -// EOLSplitter is the custom split function to handle CR LF, CR, and LF as end-of-line. -type EOLSplitter struct { - prevCR bool -} - -// NewEOLSplitter creates a new EOLSplitter instance. -func NewEOLSplitter() *EOLSplitter { - return &EOLSplitter{prevCR: false} -} - -// Split function to handle CR LF, CR, and LF as end-of-line. -func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { - // Check if the previous data ended with a CR - if s.prevCR { - s.prevCR = false - if len(data) > 0 && data[0] == '\n' { - return 1, nil, nil // Skip the LF following the previous CR - } - } - - // Search for the first occurrence of CR LF, CR, or LF - for i := 0; i < len(data); i++ { - if data[i] == '\r' { - if i+1 < len(data) && data[i+1] == '\n' { - // Found CR LF - return i + 2, data[:i], nil - } - // Found CR - if !atEOF && i == len(data)-1 { - // If CR is the last byte, and not EOF, then need to check if - // the next byte is LF. - // - // save the state and request more data - s.prevCR = true - return 0, nil, nil - } - return i + 1, data[:i], nil - } - if data[i] == '\n' { - // Found LF - return i + 1, data[:i], nil - } - } - - // If at EOF, we have a final, non-terminated line. Return it. - if atEOF && len(data) > 0 { - return len(data), data, nil - } - - // Request more data. - return 0, nil, nil -} - -// CustomReader simulates a reader that splits the input across multiple reads. -type CustomReader struct { +// ChunksReader simulates a reader that splits the input across multiple reads. +type ChunksReader struct { chunks []string index int } -func NewChunksReader(chunks []string) *CustomReader { - return &CustomReader{ +func NewChunksReader(chunks []string) *ChunksReader { + return &ChunksReader{ chunks: chunks, } } -func (r *CustomReader) Read(p []byte) (n int, err error) { +func (r *ChunksReader) Read(p []byte) (n int, err error) { if r.index >= len(r.chunks) { return 0, io.EOF } @@ -178,3 +120,153 @@ func TestEolSplitterBoundaryCondition(t *testing.T) { } } } + +func TestSSEScanner(t *testing.T) { + tests := []struct { + raw string + want []ServerSentEvent + }{ + { + raw: `data: hello world`, + want: []ServerSentEvent{ + { + Data: "hello world", + }, + }, + }, + { + raw: `event: hello +data: hello world`, + want: []ServerSentEvent{ + { + Event: "hello", + Data: "hello world", + }, + }, + }, + { + raw: `event: hello-json +data: { +data: "msg": "hello world", +data: "id": 12345 +data: }`, + want: []ServerSentEvent{ + { + Event: "hello-json", + Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}", + }, + }, + }, + { + raw: `data: hello world + +data: hello again`, + want: []ServerSentEvent{ + { + Data: "hello world", + }, + { + Data: "hello again", + }, + }, + }, + { + raw: `retry: 10000 + data: hello world`, + want: []ServerSentEvent{ + { + Retry: 10000, + Data: "hello world", + }, + }, + }, + { + raw: `retry: 10000 + +retry: 20000`, + want: []ServerSentEvent{ + { + Retry: 10000, + }, + { + Retry: 20000, + }, + }, + }, + { + raw: `: comment 1 +: comment 2 +id: message-id +retry: 20000 +event: hello-event +data: hello`, + want: []ServerSentEvent{ + { + ID: "message-id", + Retry: 20000, + Event: "hello-event", + Data: "hello", + }, + }, + }, + { + raw: `: comment 1 +id: message 1 +data: hello 1 +retry: 10000 +event: hello-event 1 + +: comment 2 +data: hello 2 +id: message 2 +retry: 20000 +event: hello-event 2 +`, + want: []ServerSentEvent{ + { + ID: "message 1", + Retry: 10000, + Event: "hello-event 1", + Data: "hello 1", + }, + { + ID: "message 2", + Retry: 20000, + Event: "hello-event 2", + Data: "hello 2", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.raw, func(t *testing.T) { + rawWithCRLF := strings.ReplaceAll(tt.raw, "\n", "\r\n") + runSSEScanTest(t, rawWithCRLF, tt.want) + + // Test with "\r" EOL + rawWithCR := strings.ReplaceAll(tt.raw, "\n", "\r") + runSSEScanTest(t, rawWithCR, tt.want) + + // Test with "\n" EOL (original) + runSSEScanTest(t, tt.raw, tt.want) + }) + } +} + +func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) { + sseScanner := NewSSEScanner(strings.NewReader(raw), false) + + var got []ServerSentEvent + for sseScanner.Next() { + got = append(got, *sseScanner.Scan()) + } + + if err := sseScanner.Err(); err != nil { + t.Errorf("SSEScanner error: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("SSEScanner() = %v, want %v", got, want) + } +} diff --git a/stream_v2.go b/stream_v2.go new file mode 100644 index 000000000..bf90f0a5d --- /dev/null +++ b/stream_v2.go @@ -0,0 +1,162 @@ +package openai + +import ( + "encoding/json" + "io" +) + +type StreamRawEvent struct { + Type string + Data json.RawMessage +} + +type StreamDone struct { + Data string // [DONE] +} + +// Define StreamThreadMessageDelta +type StreamThreadMessageDelta struct { + ID string `json:"id"` + Object string `json:"object"` + Delta Delta `json:"delta"` +} + +type Delta struct { + // DeltaText | DeltaImageFile + Content []DeltaContent `json:"content"` +} + +type DeltaContent struct { + Index int `json:"index"` + Type string `json:"type"` + + Text *DeltaText `json:"text"` + ImageFile *DeltaImageFile `json:"image_file"` + ImageURL *DeltaImageURL `json:"image_url"` +} + +type DeltaText struct { + Value string `json:"value"` + // Annotations []any `json:"annotations"` +} + +type DeltaImageFile struct { + FileID string `json:"file_id"` + Detail string `json:"detail"` +} + +type DeltaImageURL struct { + URL string `json:"url"` + Detail string `json:"detail"` +} + +// streamTextReader is an io.Reader of the text deltas of thread.message.delta events +type streamTextReader struct { + streamer *StreamerV2 + buffer []byte +} + +func newStreamTextReader(streamer *StreamerV2) io.Reader { + return &streamTextReader{ + streamer: streamer, + } +} + +func (r *streamTextReader) Read(p []byte) (int, error) { + // If we have data in the buffer, copy it to p first. + if len(r.buffer) > 0 { + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + + for r.streamer.Next() { + // Read only text deltas + text, ok := r.streamer.MessageDeltaText() + if !ok { + continue + } + + r.buffer = []byte(text) + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + + // Check for streamer error + if err := r.streamer.Err(); err != nil { + return 0, err + } + + return 0, io.EOF +} + +func NewStreamerV2(r io.Reader) *StreamerV2 { + return &StreamerV2{ + scanner: NewSSEScanner(r, false), + } +} + +type StreamerV2 struct { + scanner *SSEScanner + next any +} + +func (s *StreamerV2) Next() bool { + if s.scanner.Next() { + event := s.scanner.Scan() + if event != nil { + switch event.Event { + case "thread.message.delta": + var delta StreamThreadMessageDelta + if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { + s.next = delta + return true + } + case "done": + s.next = StreamDone{Data: "DONE"} + return true + default: + s.next = StreamRawEvent{Data: json.RawMessage(event.Data)} + } + } + } + return false +} + +// Reader returns io.Reader of the text deltas of thread.message.delta events +func (s *StreamerV2) Reader() io.Reader { + return newStreamTextReader(s) +} + +func (s *StreamerV2) Event() any { + return s.next +} + +// Text returns text delta if the current event is a "thread.message.delta". Alias of MessageDeltaText. +func (s *StreamerV2) Text() (string, bool) { + return s.MessageDeltaText() +} + +// MessageDeltaText returns text delta if the current event is a "thread.message.delta" +func (s *StreamerV2) MessageDeltaText() (string, bool) { + event, ok := s.next.(StreamThreadMessageDelta) + if !ok { + return "", false + } + + var text string + for _, content := range event.Delta.Content { + if content.Text != nil { + // Can we return the first text we find? Does OpenAI stream ever + // return multiple text contents in a delta? + text = text + content.Text.Value + } + } + + return text, true +} + +func (s *StreamerV2) Err() error { + return s.scanner.Err() +} diff --git a/run_stream_test.go b/stream_v2_test.go similarity index 55% rename from run_stream_test.go rename to stream_v2_test.go index c81f32f07..597443b12 100644 --- a/run_stream_test.go +++ b/stream_v2_test.go @@ -8,162 +8,6 @@ import ( "testing" ) -type StreamRawEvent struct { - Type string - Data json.RawMessage -} - -type StreamDone struct { - Data string // [DONE] -} - -// Define StreamThreadMessageDelta -type StreamThreadMessageDelta struct { - ID string `json:"id"` - Object string `json:"object"` - Delta Delta `json:"delta"` -} - -type Delta struct { - // DeltaText | DeltaImageFile - Content []DeltaContent `json:"content"` -} - -type DeltaContent struct { - Index int `json:"index"` - Type string `json:"type"` - - Text *DeltaText `json:"text"` - ImageFile *DeltaImageFile `json:"image_file"` - ImageURL *DeltaImageURL `json:"image_url"` -} - -type DeltaText struct { - Value string `json:"value"` - // Annotations []any `json:"annotations"` -} - -type DeltaImageFile struct { - FileID string `json:"file_id"` - Detail string `json:"detail"` -} - -type DeltaImageURL struct { - URL string `json:"url"` - Detail string `json:"detail"` -} - -// streamTextReader is an io.Reader of the text deltas of thread.message.delta events -type streamTextReader struct { - streamer *StreamerV2 - buffer []byte -} - -func newStreamTextReader(streamer *StreamerV2) io.Reader { - return &streamTextReader{ - streamer: streamer, - } -} - -func (r *streamTextReader) Read(p []byte) (int, error) { - // If we have data in the buffer, copy it to p first. - if len(r.buffer) > 0 { - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] - return n, nil - } - - for r.streamer.Next() { - // Read only text deltas - text, ok := r.streamer.MessageDeltaText() - if !ok { - continue - } - - r.buffer = []byte(text) - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] - return n, nil - } - - // Check for streamer error - if err := r.streamer.Err(); err != nil { - return 0, err - } - - return 0, io.EOF -} - -func NewStreamerV2(r io.Reader) *StreamerV2 { - return &StreamerV2{ - scanner: NewSSEScanner(r, false), - } -} - -type StreamerV2 struct { - scanner *SSEScanner - next any -} - -func (s *StreamerV2) Next() bool { - if s.scanner.Next() { - event := s.scanner.Scan() - if event != nil { - switch event.Event { - case "thread.message.delta": - var delta StreamThreadMessageDelta - if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { - s.next = delta - return true - } - case "done": - s.next = StreamDone{Data: "DONE"} - return true - default: - s.next = StreamRawEvent{Data: json.RawMessage(event.Data)} - } - } - } - return false -} - -// Reader returns io.Reader of the text deltas of thread.message.delta events -func (s *StreamerV2) Reader() io.Reader { - return newStreamTextReader(s) -} - -func (s *StreamerV2) Event() any { - return s.next -} - -// Text returns text delta if the current event is a "thread.message.delta". Alias of MessageDeltaText. -func (s *StreamerV2) Text() (string, bool) { - return s.MessageDeltaText() -} - -// MessageDeltaText returns text delta if the current event is a "thread.message.delta" -func (s *StreamerV2) MessageDeltaText() (string, bool) { - event, ok := s.next.(StreamThreadMessageDelta) - if !ok { - return "", false - } - - var text string - for _, content := range event.Delta.Content { - if content.Text != nil { - // Can we return the first text we find? Does OpenAI stream ever - // return multiple text contents in a delta? - text = text + content.Text.Value - } - } - - return text, true -} - -func (s *StreamerV2) Err() error { - return s.scanner.Err() -} - func TestNewStreamTextReader(t *testing.T) { raw := ` event: thread.message.delta From 6a9af22441cd9785a44a469e483f98d6dbbdcd40 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 14:15:18 +0800 Subject: [PATCH 11/31] fix StreamRawEvent parsing --- stream_v2.go | 24 +++++++++++++++++++++++- stream_v2_test.go | 9 ++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/stream_v2.go b/stream_v2.go index bf90f0a5d..82b9bdb8f 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -92,19 +92,37 @@ func (r *streamTextReader) Read(p []byte) (int, error) { } func NewStreamerV2(r io.Reader) *StreamerV2 { + var rc io.ReadCloser + + if closer, ok := r.(io.ReadCloser); ok { + rc = closer + } else { + rc = io.NopCloser(r) + } + return &StreamerV2{ + r: rc, scanner: NewSSEScanner(r, false), } } type StreamerV2 struct { + // r is only used for closing the stream + r io.ReadCloser + scanner *SSEScanner next any } +// Close closes the underlying io.ReadCloser +func (s *StreamerV2) Close() error { + return s.r.Close() +} + func (s *StreamerV2) Next() bool { if s.scanner.Next() { event := s.scanner.Scan() + if event != nil { switch event.Event { case "thread.message.delta": @@ -117,7 +135,11 @@ func (s *StreamerV2) Next() bool { s.next = StreamDone{Data: "DONE"} return true default: - s.next = StreamRawEvent{Data: json.RawMessage(event.Data)} + s.next = StreamRawEvent{ + Type: event.Event, + Data: json.RawMessage(event.Data), + } + return true } } } diff --git a/stream_v2_test.go b/stream_v2_test.go index 597443b12..4d2e021e4 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -55,13 +55,16 @@ data: [DONE] } func TestStreamScannerV2(t *testing.T) { - raw := ` + raw := `event: thread.created +data: {"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}} + event: thread.message.delta data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}} event: done data: [DONE] ` + scanner := NewStreamerV2(strings.NewReader(raw)) var events []any @@ -71,6 +74,10 @@ data: [DONE] } expectedValues := []any{ + StreamRawEvent{ + Type: "thread.created", + Data: json.RawMessage(`{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`), + }, StreamThreadMessageDelta{ ID: "msg_KFiZxHhXYQo6cGFnGjRDHSee", Object: "thread.message.delta", From 4496c987d178f05c4285c7bd3379e825ea236017 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 14:15:35 +0800 Subject: [PATCH 12/31] CreateThreadAndStream to use StreamV2 --- run.go | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/run.go b/run.go index b57f80d5c..5d852e1bf 100644 --- a/run.go +++ b/run.go @@ -3,10 +3,8 @@ package openai import ( "context" "fmt" - "io" "net/http" "net/url" - "os" ) type Run struct { @@ -473,7 +471,7 @@ type AssistantStream struct { func (c *Client) CreateThreadAndStream( ctx context.Context, - request CreateThreadAndRunRequest) (stream *AssistantStream, err error) { + request CreateThreadAndRunRequest) (stream *StreamerV2, err error) { urlSuffix := "/threads/runs" sr := CreateThreadAndStreamRequest{ CreateThreadAndRunRequest: request, @@ -491,34 +489,23 @@ func (c *Client) CreateThreadAndStream( return } + // TODO: implement requestStreamV2 req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() - // resp, err := sendRequestStream[AssistantStreamEvent](c, req) if err != nil { return } - defer resp.Body.Close() - outf, err := os.Create("thread.run.stream") - if err != nil { - return nil, err + if resp.StatusCode != 200 { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } - defer outf.Close() - - r := io.TeeReader(resp.Body, outf) - - _, err = io.Copy(os.Stdout, r) - - // ChatCompletionStreamChoiceDelta - // stream = &AssistantStream{ - // streamReader: resp, - // } - return nil, err + return NewStreamerV2(resp.Body), nil } func (c *Client) CreateRunStreaming( From 8d378e1d25243b68da000d53493bbfee18f2d773 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 14:17:46 +0800 Subject: [PATCH 13/31] fuss --- stream_v2.go | 47 +++++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/stream_v2.go b/stream_v2.go index 82b9bdb8f..abe19d507 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -120,30 +120,33 @@ func (s *StreamerV2) Close() error { } func (s *StreamerV2) Next() bool { - if s.scanner.Next() { - event := s.scanner.Scan() - - if event != nil { - switch event.Event { - case "thread.message.delta": - var delta StreamThreadMessageDelta - if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { - s.next = delta - return true - } - case "done": - s.next = StreamDone{Data: "DONE"} - return true - default: - s.next = StreamRawEvent{ - Type: event.Event, - Data: json.RawMessage(event.Data), - } - return true - } + if !s.scanner.Next() { + return false + } + + event := s.scanner.Scan() + + if event == nil { + return false + } + + switch event.Event { + case "thread.message.delta": + var delta StreamThreadMessageDelta + if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { + s.next = delta + + } + case "done": + s.next = StreamDone{Data: "DONE"} + default: + s.next = StreamRawEvent{ + Type: event.Event, + Data: json.RawMessage(event.Data), } } - return false + + return true } // Reader returns io.Reader of the text deltas of thread.message.delta events From 9933435eab18a2859481369824e2d11fd7596712 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 14:21:18 +0800 Subject: [PATCH 14/31] NewSSEScanner return Next by value instead of pointer --- sse.go | 13 +++++++------ sse_test.go | 2 +- stream_v2.go | 4 ---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/sse.go b/sse.go index 8ed15470a..076650b8c 100644 --- a/sse.go +++ b/sse.go @@ -76,7 +76,7 @@ type ServerSentEvent struct { type SSEScanner struct { scanner *bufio.Scanner - event *ServerSentEvent + next ServerSentEvent err error readComment bool } @@ -99,9 +99,10 @@ func NewSSEScanner(r io.Reader, readComment bool) *SSEScanner { } func (s *SSEScanner) Next() bool { - s.event = nil - + // Zero the next event before scanning a new one var event ServerSentEvent + s.next = event + var dataLines []string var seenNonEmptyLine bool @@ -150,13 +151,13 @@ func (s *SSEScanner) Next() bool { } event.Data = strings.Join(dataLines, "\n") - s.event = &event + s.next = event return true } -func (s *SSEScanner) Scan() *ServerSentEvent { - return s.event +func (s *SSEScanner) Scan() ServerSentEvent { + return s.next } func (s *SSEScanner) Err() error { diff --git a/sse_test.go b/sse_test.go index 9c81fb143..71fcb3ce6 100644 --- a/sse_test.go +++ b/sse_test.go @@ -259,7 +259,7 @@ func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) { var got []ServerSentEvent for sseScanner.Next() { - got = append(got, *sseScanner.Scan()) + got = append(got, sseScanner.Scan()) } if err := sseScanner.Err(); err != nil { diff --git a/stream_v2.go b/stream_v2.go index abe19d507..a853559ac 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -126,10 +126,6 @@ func (s *StreamerV2) Next() bool { event := s.scanner.Scan() - if event == nil { - return false - } - switch event.Event { case "thread.message.delta": var delta StreamThreadMessageDelta From aa1d9b0a5a0816ed743720c20e80d8729d065a4a Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 14:22:22 +0800 Subject: [PATCH 15/31] simplify the DONE event --- stream_v2.go | 3 +-- stream_v2_test.go | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/stream_v2.go b/stream_v2.go index a853559ac..7bc9253ef 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -11,7 +11,6 @@ type StreamRawEvent struct { } type StreamDone struct { - Data string // [DONE] } // Define StreamThreadMessageDelta @@ -134,7 +133,7 @@ func (s *StreamerV2) Next() bool { } case "done": - s.next = StreamDone{Data: "DONE"} + s.next = StreamDone{} default: s.next = StreamRawEvent{ Type: event.Event, diff --git a/stream_v2_test.go b/stream_v2_test.go index 4d2e021e4..01d5298f9 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -93,7 +93,7 @@ data: [DONE] }, }, }, - StreamDone{Data: "DONE"}, + StreamDone{}, } if len(events) != len(expectedValues) { From 67fe23ef4faf86658d858ca5ac31d59c065e8f62 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sat, 18 May 2024 14:27:21 +0800 Subject: [PATCH 16/31] make StreamerV2 itself an io.Reader --- stream_v2.go | 75 ++++++++++++++++++++--------------------------- stream_v2_test.go | 2 +- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/stream_v2.go b/stream_v2.go index 7bc9253ef..e15f9498d 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -49,47 +49,6 @@ type DeltaImageURL struct { Detail string `json:"detail"` } -// streamTextReader is an io.Reader of the text deltas of thread.message.delta events -type streamTextReader struct { - streamer *StreamerV2 - buffer []byte -} - -func newStreamTextReader(streamer *StreamerV2) io.Reader { - return &streamTextReader{ - streamer: streamer, - } -} - -func (r *streamTextReader) Read(p []byte) (int, error) { - // If we have data in the buffer, copy it to p first. - if len(r.buffer) > 0 { - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] - return n, nil - } - - for r.streamer.Next() { - // Read only text deltas - text, ok := r.streamer.MessageDeltaText() - if !ok { - continue - } - - r.buffer = []byte(text) - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] - return n, nil - } - - // Check for streamer error - if err := r.streamer.Err(); err != nil { - return 0, err - } - - return 0, io.EOF -} - func NewStreamerV2(r io.Reader) *StreamerV2 { var rc io.ReadCloser @@ -111,6 +70,9 @@ type StreamerV2 struct { scanner *SSEScanner next any + + // buffer for implementing io.Reader + buffer []byte } // Close closes the underlying io.ReadCloser @@ -144,9 +106,34 @@ func (s *StreamerV2) Next() bool { return true } -// Reader returns io.Reader of the text deltas of thread.message.delta events -func (s *StreamerV2) Reader() io.Reader { - return newStreamTextReader(s) +// Read implements io.Reader of the text deltas of thread.message.delta events +func (r *StreamerV2) Read(p []byte) (int, error) { + // If we have data in the buffer, copy it to p first. + if len(r.buffer) > 0 { + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + + for r.Next() { + // Read only text deltas + text, ok := r.MessageDeltaText() + if !ok { + continue + } + + r.buffer = []byte(text) + n := copy(p, r.buffer) + r.buffer = r.buffer[n:] + return n, nil + } + + // Check for streamer error + if err := r.Err(); err != nil { + return 0, err + } + + return 0, io.EOF } func (s *StreamerV2) Event() any { diff --git a/stream_v2_test.go b/stream_v2_test.go index 01d5298f9..a92f793b6 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -19,7 +19,7 @@ data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delt event: done data: [DONE] ` - reader := NewStreamerV2(strings.NewReader(raw)).Reader() + reader := NewStreamerV2(strings.NewReader(raw)) expected := "helloworld" buffer := make([]byte, len(expected)) From f4e16037bcc8c5f285dcb1ec5edbd6828a93b437 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Sun, 19 May 2024 11:23:49 +0800 Subject: [PATCH 17/31] lint --- run.go | 4 ++-- sse.go | 24 ++++++++++++------------ sse_test.go | 32 +++++++++++++++++--------------- stream_v2.go | 30 ++++++++++++++---------------- stream_v2_test.go | 32 ++++++++++++++++++-------------- 5 files changed, 63 insertions(+), 59 deletions(-) diff --git a/run.go b/run.go index 5d852e1bf..0463615b7 100644 --- a/run.go +++ b/run.go @@ -495,12 +495,12 @@ func (c *Client) CreateThreadAndStream( req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body is closed in stream.Close() + resp, err := c.config.HTTPClient.Do(req) if err != nil { return } - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { resp.Body.Close() return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } diff --git a/sse.go b/sse.go index 076650b8c..fe5a5c5f3 100644 --- a/sse.go +++ b/sse.go @@ -7,7 +7,7 @@ import ( "strings" ) -// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance +// NewEOLSplitterFunc returns a bufio.SplitFunc tied to a new EOLSplitter instance. func NewEOLSplitterFunc() bufio.SplitFunc { splitter := NewEOLSplitter() return splitter.Split @@ -23,6 +23,8 @@ func NewEOLSplitter() *EOLSplitter { return &EOLSplitter{prevCR: false} } +const crlfLen = 2 + // Split function to handle CR LF, CR, and LF as end-of-line. func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, err error) { // Check if the previous data ended with a CR @@ -38,7 +40,7 @@ func (s *EOLSplitter) Split(data []byte, atEOF bool) (advance int, token []byte, if data[i] == '\r' { if i+1 < len(data) && data[i+1] == '\n' { // Found CR LF - return i + 2, data[:i], nil + return i + crlfLen, data[:i], nil } // Found CR if !atEOF && i == len(data)-1 { @@ -119,29 +121,27 @@ func (s *SSEScanner) Next() bool { } seenNonEmptyLine = true - - if strings.HasPrefix(line, "id: ") { + switch { + case strings.HasPrefix(line, "id: "): event.ID = strings.TrimPrefix(line, "id: ") - } else if strings.HasPrefix(line, "data: ") { + case strings.HasPrefix(line, "data: "): dataLines = append(dataLines, strings.TrimPrefix(line, "data: ")) - } else if strings.HasPrefix(line, "event: ") { + case strings.HasPrefix(line, "event: "): event.Event = strings.TrimPrefix(line, "event: ") - } else if strings.HasPrefix(line, "retry: ") { + case strings.HasPrefix(line, "retry: "): retry, err := strconv.Atoi(strings.TrimPrefix(line, "retry: ")) if err == nil { event.Retry = retry } - // ignore invalid retry values - } else if strings.HasPrefix(line, ":") { + case strings.HasPrefix(line, ":"): if s.readComment { event.Comment = strings.TrimPrefix(line, ":") } - // ignore comment line + default: + // ignore unknown lines } - - // ignore unknown lines } s.err = s.scanner.Err() diff --git a/sse_test.go b/sse_test.go index 71fcb3ce6..73c458d43 100644 --- a/sse_test.go +++ b/sse_test.go @@ -1,4 +1,4 @@ -package openai +package openai_test import ( "bufio" @@ -6,6 +6,8 @@ import ( "reflect" "strings" "testing" + + "github.com/sashabaranov/go-openai" ) // ChunksReader simulates a reader that splits the input across multiple reads. @@ -55,7 +57,7 @@ func TestEolSplitter(t *testing.T) { t.Run(test.name, func(t *testing.T) { reader := strings.NewReader(test.input) scanner := bufio.NewScanner(reader) - scanner.Split(NewEOLSplitterFunc()) + scanner.Split(openai.NewEOLSplitterFunc()) var lines []string for scanner.Scan() { @@ -97,7 +99,7 @@ func TestEolSplitterBoundaryCondition(t *testing.T) { // Custom reader to simulate the boundary condition reader := NewChunksReader(c.input) scanner := bufio.NewScanner(reader) - scanner.Split(NewEOLSplitterFunc()) + scanner.Split(openai.NewEOLSplitterFunc()) var lines []string for scanner.Scan() { @@ -124,11 +126,11 @@ func TestEolSplitterBoundaryCondition(t *testing.T) { func TestSSEScanner(t *testing.T) { tests := []struct { raw string - want []ServerSentEvent + want []openai.ServerSentEvent }{ { raw: `data: hello world`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Data: "hello world", }, @@ -137,7 +139,7 @@ func TestSSEScanner(t *testing.T) { { raw: `event: hello data: hello world`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Event: "hello", Data: "hello world", @@ -150,7 +152,7 @@ data: { data: "msg": "hello world", data: "id": 12345 data: }`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Event: "hello-json", Data: "{\n\"msg\": \"hello world\",\n\"id\": 12345\n}", @@ -161,7 +163,7 @@ data: }`, raw: `data: hello world data: hello again`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Data: "hello world", }, @@ -173,7 +175,7 @@ data: hello again`, { raw: `retry: 10000 data: hello world`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Retry: 10000, Data: "hello world", @@ -184,7 +186,7 @@ data: hello again`, raw: `retry: 10000 retry: 20000`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { Retry: 10000, }, @@ -200,7 +202,7 @@ id: message-id retry: 20000 event: hello-event data: hello`, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { ID: "message-id", Retry: 20000, @@ -222,7 +224,7 @@ id: message 2 retry: 20000 event: hello-event 2 `, - want: []ServerSentEvent{ + want: []openai.ServerSentEvent{ { ID: "message 1", Retry: 10000, @@ -254,10 +256,10 @@ event: hello-event 2 } } -func runSSEScanTest(t *testing.T, raw string, want []ServerSentEvent) { - sseScanner := NewSSEScanner(strings.NewReader(raw), false) +func runSSEScanTest(t *testing.T, raw string, want []openai.ServerSentEvent) { + sseScanner := openai.NewSSEScanner(strings.NewReader(raw), false) - var got []ServerSentEvent + var got []openai.ServerSentEvent for sseScanner.Next() { got = append(got, sseScanner.Scan()) } diff --git a/stream_v2.go b/stream_v2.go index e15f9498d..515902e94 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -13,7 +13,6 @@ type StreamRawEvent struct { type StreamDone struct { } -// Define StreamThreadMessageDelta type StreamThreadMessageDelta struct { ID string `json:"id"` Object string `json:"object"` @@ -75,7 +74,7 @@ type StreamerV2 struct { buffer []byte } -// Close closes the underlying io.ReadCloser +// Close closes the underlying io.ReadCloser. func (s *StreamerV2) Close() error { return s.r.Close() } @@ -92,7 +91,6 @@ func (s *StreamerV2) Next() bool { var delta StreamThreadMessageDelta if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { s.next = delta - } case "done": s.next = StreamDone{} @@ -106,30 +104,30 @@ func (s *StreamerV2) Next() bool { return true } -// Read implements io.Reader of the text deltas of thread.message.delta events -func (r *StreamerV2) Read(p []byte) (int, error) { +// Read implements io.Reader of the text deltas of thread.message.delta events. +func (s *StreamerV2) Read(p []byte) (int, error) { // If we have data in the buffer, copy it to p first. - if len(r.buffer) > 0 { - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] + if len(s.buffer) > 0 { + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] return n, nil } - for r.Next() { + for s.Next() { // Read only text deltas - text, ok := r.MessageDeltaText() + text, ok := s.MessageDeltaText() if !ok { continue } - r.buffer = []byte(text) - n := copy(p, r.buffer) - r.buffer = r.buffer[n:] + s.buffer = []byte(text) + n := copy(p, s.buffer) + s.buffer = s.buffer[n:] return n, nil } // Check for streamer error - if err := r.Err(); err != nil { + if err := s.Err(); err != nil { return 0, err } @@ -145,7 +143,7 @@ func (s *StreamerV2) Text() (string, bool) { return s.MessageDeltaText() } -// MessageDeltaText returns text delta if the current event is a "thread.message.delta" +// MessageDeltaText returns text delta if the current event is a "thread.message.delta". func (s *StreamerV2) MessageDeltaText() (string, bool) { event, ok := s.next.(StreamThreadMessageDelta) if !ok { @@ -157,7 +155,7 @@ func (s *StreamerV2) MessageDeltaText() (string, bool) { if content.Text != nil { // Can we return the first text we find? Does OpenAI stream ever // return multiple text contents in a delta? - text = text + content.Text.Value + text += content.Text.Value } } diff --git a/stream_v2_test.go b/stream_v2_test.go index a92f793b6..477edb054 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -1,11 +1,15 @@ -package openai +//nolint:lll +package openai_test import ( "encoding/json" + "errors" "io" "reflect" "strings" "testing" + + "github.com/sashabaranov/go-openai" ) func TestNewStreamTextReader(t *testing.T) { @@ -19,7 +23,7 @@ data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delt event: done data: [DONE] ` - reader := NewStreamerV2(strings.NewReader(raw)) + reader := openai.NewStreamerV2(strings.NewReader(raw)) expected := "helloworld" buffer := make([]byte, len(expected)) @@ -46,7 +50,7 @@ data: [DONE] } n, err = reader.Read(buffer) - if err != io.EOF { + if !errors.Is(err, io.EOF) { t.Fatalf("expected io.EOF, got %v", err) } if n != 0 { @@ -65,7 +69,7 @@ event: done data: [DONE] ` - scanner := NewStreamerV2(strings.NewReader(raw)) + scanner := openai.NewStreamerV2(strings.NewReader(raw)) var events []any for scanner.Next() { @@ -74,26 +78,26 @@ data: [DONE] } expectedValues := []any{ - StreamRawEvent{ + openai.StreamRawEvent{ Type: "thread.created", Data: json.RawMessage(`{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`), }, - StreamThreadMessageDelta{ + openai.StreamThreadMessageDelta{ ID: "msg_KFiZxHhXYQo6cGFnGjRDHSee", Object: "thread.message.delta", - Delta: Delta{ - Content: []DeltaContent{ + Delta: openai.Delta{ + Content: []openai.DeltaContent{ { Index: 0, Type: "text", - Text: &DeltaText{ + Text: &openai.DeltaText{ Value: "hello", }, }, }, }, }, - StreamDone{}, + openai.StreamDone{}, } if len(events) != len(expectedValues) { @@ -119,25 +123,25 @@ func TestStreamThreadMessageDeltaJSON(t *testing.T) { name: "DeltaContent with Text", jsonData: `{"index":0,"type":"text","text":{"value":"hello"}}`, expectType: "text", - expectValue: &DeltaText{Value: "hello"}, + expectValue: &openai.DeltaText{Value: "hello"}, }, { name: "DeltaContent with ImageFile", jsonData: `{"index":1,"type":"image_file","image_file":{"file_id":"file123","detail":"An image"}}`, expectType: "image_file", - expectValue: &DeltaImageFile{FileID: "file123", Detail: "An image"}, + expectValue: &openai.DeltaImageFile{FileID: "file123", Detail: "An image"}, }, { name: "DeltaContent with ImageURL", jsonData: `{"index":2,"type":"image_url","image_url":{"url":"https://example.com/image.jpg","detail":"low"}}`, expectType: "image_url", - expectValue: &DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"}, + expectValue: &openai.DeltaImageURL{URL: "https://example.com/image.jpg", Detail: "low"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var content DeltaContent + var content openai.DeltaContent err := json.Unmarshal([]byte(tt.jsonData), &content) if err != nil { t.Fatalf("Error unmarshalling JSON: %v", err) From 90f1bd022b87e0eb4a87c88c75dbb8f31c8c13aa Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Mon, 20 May 2024 15:36:23 +0800 Subject: [PATCH 18/31] change StreamerV2 to return the StreamEvent interface --- stream_v2.go | 75 ++++++++++++++++++++++++---- stream_v2_test.go | 124 +++++++++++++++++++++++++++++++--------------- 2 files changed, 150 insertions(+), 49 deletions(-) diff --git a/stream_v2.go b/stream_v2.go index 515902e94..3eca9cb0a 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -6,17 +6,20 @@ import ( ) type StreamRawEvent struct { - Type string + streamEvent Data json.RawMessage } type StreamDone struct { + streamEvent } type StreamThreadMessageDelta struct { ID string `json:"id"` Object string `json:"object"` Delta Delta `json:"delta"` + + streamEvent } type Delta struct { @@ -68,7 +71,7 @@ type StreamerV2 struct { r io.ReadCloser scanner *SSEScanner - next any + next StreamEvent // buffer for implementing io.Reader buffer []byte @@ -79,6 +82,36 @@ func (s *StreamerV2) Close() error { return s.r.Close() } +type StreamThreadCreated struct { + Thread + streamEvent +} + +type StreamThreadRunCreated struct { + Run + streamEvent +} + +type StreamEvent interface { + Event() string + JSON() json.RawMessage +} + +type streamEvent struct { + event string + data json.RawMessage +} + +// Event returns the event name +func (s *streamEvent) Event() string { + return s.event +} + +// JSON returns the raw JSON data +func (s *streamEvent) JSON() json.RawMessage { + return s.data +} + func (s *StreamerV2) Next() bool { if !s.scanner.Next() { return false @@ -86,18 +119,42 @@ func (s *StreamerV2) Next() bool { event := s.scanner.Scan() + streamEvent := streamEvent{ + event: event.Event, + data: json.RawMessage(event.Data), + } + switch event.Event { + case "thread.created": + var thread Thread + if err := json.Unmarshal([]byte(event.Data), &thread); err == nil { + s.next = &StreamThreadCreated{ + Thread: thread, + streamEvent: streamEvent, + } + } + case "thread.run.created": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunCreated{ + Run: run, + streamEvent: streamEvent, + } + } case "thread.message.delta": var delta StreamThreadMessageDelta if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { - s.next = delta + delta.streamEvent = streamEvent + s.next = &delta } case "done": - s.next = StreamDone{} + streamEvent.data = nil + s.next = &StreamDone{ + streamEvent: streamEvent, + } default: - s.next = StreamRawEvent{ - Type: event.Event, - Data: json.RawMessage(event.Data), + s.next = &StreamRawEvent{ + streamEvent: streamEvent, } } @@ -134,7 +191,7 @@ func (s *StreamerV2) Read(p []byte) (int, error) { return 0, io.EOF } -func (s *StreamerV2) Event() any { +func (s *StreamerV2) Event() StreamEvent { return s.next } @@ -145,7 +202,7 @@ func (s *StreamerV2) Text() (string, bool) { // MessageDeltaText returns text delta if the current event is a "thread.message.delta". func (s *StreamerV2) MessageDeltaText() (string, bool) { - event, ok := s.next.(StreamThreadMessageDelta) + event, ok := s.next.(*StreamThreadMessageDelta) if !ok { return "", false } diff --git a/stream_v2_test.go b/stream_v2_test.go index 477edb054..63b6a074e 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -2,8 +2,10 @@ package openai_test import ( + "bytes" "encoding/json" "errors" + "fmt" "io" "reflect" "strings" @@ -58,56 +60,98 @@ data: [DONE] } } -func TestStreamScannerV2(t *testing.T) { - raw := `event: thread.created -data: {"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}} - -event: thread.message.delta -data: {"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}} +type TestCase struct { + Event string + Data string +} -event: done -data: [DONE] -` +func constructStreamInput(testCases []TestCase) io.Reader { + var sb bytes.Buffer + for _, tc := range testCases { + sb.WriteString("event: ") + sb.WriteString(tc.Event) + sb.WriteString("\n") + sb.WriteString("data: ") + sb.WriteString(tc.Data) + sb.WriteString("\n\n") + } + return &sb +} - scanner := openai.NewStreamerV2(strings.NewReader(raw)) - var events []any +func jsonEqual[T any](t *testing.T, data []byte, expected T) error { + var obj T + if err := json.Unmarshal(data, &obj); err != nil { + t.Fatalf("Error unmarshalling JSON: %v", err) + } - for scanner.Next() { - event := scanner.Event() - events = append(events, event) + if !reflect.DeepEqual(obj, expected) { + t.Fatalf("Expected %v, but got %v", expected, obj) } - expectedValues := []any{ - openai.StreamRawEvent{ - Type: "thread.created", - Data: json.RawMessage(`{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`), + return nil +} + +func TestStreamerV2(t *testing.T) { + testCases := []TestCase{ + { + Event: "thread.created", + Data: `{"id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","object":"thread","created_at":1715864046,"metadata":{},"tool_resources":{"code_interpreter":{"file_ids":[]}}}`, + }, + { + Event: "thread.run.created", + Data: `{"id":"run_ojU7pVxtTIaa4l1GgRmHVSbK","object":"thread.run","created_at":1715864046,"assistant_id":"asst_7xUrZ16RBU2BpaUOzLnc9HsD","thread_id":"thread_vMWb8sJ14upXpPO2VbRpGTYD","status":"queued","started_at":null,"expires_at":1715864646,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-3.5-turbo","instructions":null,"tools":[],"tool_resources":{"code_interpreter":{"file_ids":[]}},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto"}`, }, - openai.StreamThreadMessageDelta{ - ID: "msg_KFiZxHhXYQo6cGFnGjRDHSee", - Object: "thread.message.delta", - Delta: openai.Delta{ - Content: []openai.DeltaContent{ - { - Index: 0, - Type: "text", - Text: &openai.DeltaText{ - Value: "hello", - }, - }, - }, - }, + { + Event: "thread.message.delta", + Data: `{"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}}`, + }, + { + Event: "done", + Data: "[DONE]", }, - openai.StreamDone{}, } - if len(events) != len(expectedValues) { - t.Fatalf("Expected %d events but got %d", len(expectedValues), len(events)) - } + streamer := openai.NewStreamerV2(constructStreamInput(testCases)) - for i, event := range events { - expectedValue := expectedValues[i] - if !reflect.DeepEqual(event, expectedValue) { - t.Errorf("Expected %v but got %v", expectedValue, event) + for _, tc := range testCases { + if !streamer.Next() { + t.Fatal("Expected Next() to return true, but got false") + } + + event := streamer.Event() + + if event.Event() != tc.Event { + t.Fatalf("Expected event type to be %s, but got %s", tc.Event, event.Event()) + } + + if tc.Event != "done" { + // compare the json data + jsondata := event.JSON() + if string(jsondata) != tc.Data { + t.Fatalf("Expected JSON data to be %s, but got %s", tc.Data, string(jsondata)) + } + } + + switch event := event.(type) { + case *openai.StreamThreadCreated: + jsonEqual(t, []byte(tc.Data), event.Thread) + case *openai.StreamThreadRunCreated: + jsonEqual(t, []byte(tc.Data), event.Run) + case *openai.StreamThreadMessageDelta: + fmt.Println(event) + + // reinitialize the delta object to avoid comparing the hidden streamEvent fields + delta := openai.StreamThreadMessageDelta{ + ID: event.ID, + Object: event.Object, + Delta: event.Delta, + } + + jsonEqual(t, []byte(tc.Data), delta) + case *openai.StreamDone: + if event.JSON() != nil { + t.Fatalf("Expected JSON data to be nil, but got %s", string(event.JSON())) + } } } } From c502adad3dec5435a35cde46faf56173a1f097f5 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Tue, 21 May 2024 17:00:17 +0800 Subject: [PATCH 19/31] implement sendRequestStreamV2 --- client.go | 20 ++++++++++++++++++++ run.go | 46 ++++++++++++---------------------------------- run_test.go | 4 ++-- stream_v2.go | 26 +++++++++++++++++++++----- 4 files changed, 55 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index c57ba17c7..e9e061902 100644 --- a/client.go +++ b/client.go @@ -156,6 +156,26 @@ func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err er return } +func sendRequestStreamV2(client *Client, req *http.Request) (stream *StreamerV2, err error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := client.config.HTTPClient.Do(req) + if err != nil { + return + } + + // TODO: how to handle error? + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return NewStreamerV2(resp.Body), nil +} + func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "text/event-stream") diff --git a/run.go b/run.go index 0463615b7..81bfabb46 100644 --- a/run.go +++ b/run.go @@ -82,12 +82,13 @@ const ( ) type RunRequest struct { - AssistantID string `json:"assistant_id"` - Model string `json:"model,omitempty"` - Instructions string `json:"instructions,omitempty"` - AdditionalInstructions string `json:"additional_instructions,omitempty"` - Tools []Tool `json:"tools,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + AssistantID string `json:"assistant_id"` + Model string `json:"model,omitempty"` + Instructions string `json:"instructions,omitempty"` + AdditionalInstructions string `json:"additional_instructions,omitempty"` + AdditionalMessages []ThreadMessage `json:"additional_messages,omitempty"` + Tools []Tool `json:"tools,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. @@ -469,7 +470,7 @@ type AssistantStream struct { *streamReader[AssistantStreamEvent] } -func (c *Client) CreateThreadAndStream( +func (c *Client) CreateThreadAndRunStream( ctx context.Context, request CreateThreadAndRunRequest) (stream *StreamerV2, err error) { urlSuffix := "/threads/runs" @@ -489,29 +490,13 @@ func (c *Client) CreateThreadAndStream( return } - // TODO: implement requestStreamV2 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Cache-Control", "no-cache") - req.Header.Set("Connection", "keep-alive") - - resp, err := c.config.HTTPClient.Do(req) - if err != nil { - return - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - - return NewStreamerV2(resp.Body), nil + return sendRequestStreamV2(c, req) } -func (c *Client) CreateRunStreaming( +func (c *Client) CreateRunStream( ctx context.Context, threadID string, - request RunRequest) (stream *AssistantStream, err error) { + request RunRequest) (stream *StreamerV2, err error) { urlSuffix := fmt.Sprintf("/threads/%s/runs", threadID) r := RunRequestStreaming{ @@ -530,14 +515,7 @@ func (c *Client) CreateRunStreaming( return } - resp, err := sendRequestStream[AssistantStreamEvent](c, req) - if err != nil { - return - } - stream = &AssistantStream{ - streamReader: resp, - } - return + return sendRequestStreamV2(c, req) } // RetrieveRunStep retrieves a run step. diff --git a/run_test.go b/run_test.go index f3445852e..606ac426a 100644 --- a/run_test.go +++ b/run_test.go @@ -219,7 +219,7 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndRun error") - _, err = client.CreateThreadAndStream(ctx, openai.CreateThreadAndRunRequest{ + _, err = client.CreateThreadAndRunStream(ctx, openai.CreateThreadAndRunRequest{ RunRequest: openai.RunRequest{ AssistantID: assistantID, }, @@ -234,7 +234,7 @@ func TestRun(t *testing.T) { }) checks.NoError(t, err, "CreateThreadAndStream error") - _, err = client.CreateRunStreaming(ctx, threadID, openai.RunRequest{ + _, err = client.CreateRunStream(ctx, threadID, openai.RunRequest{ AssistantID: assistantID, }) checks.NoError(t, err, "CreateRunStreaming error") diff --git a/stream_v2.go b/stream_v2.go index 3eca9cb0a..75ac8c3e7 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -61,14 +61,14 @@ func NewStreamerV2(r io.Reader) *StreamerV2 { } return &StreamerV2{ - r: rc, - scanner: NewSSEScanner(r, false), + readCloser: rc, + scanner: NewSSEScanner(r, false), } } type StreamerV2 struct { - // r is only used for closing the stream - r io.ReadCloser + // readCloser is only used for closing the stream + readCloser io.ReadCloser scanner *SSEScanner next StreamEvent @@ -77,9 +77,25 @@ type StreamerV2 struct { buffer []byte } +// TeeSSE tees the stream data with a io.TeeReader +func (s *StreamerV2) TeeSSE(w io.Writer) { + // readCloser is a helper struct that implements io.ReadCloser by combining an io.Reader and an io.Closer + type readCloser struct { + io.Reader + io.Closer + } + + s.readCloser = &readCloser{ + Reader: io.TeeReader(s.readCloser, w), + Closer: s.readCloser, + } + + s.scanner = NewSSEScanner(s.readCloser, false) +} + // Close closes the underlying io.ReadCloser. func (s *StreamerV2) Close() error { - return s.r.Close() + return s.readCloser.Close() } type StreamThreadCreated struct { From f8c9b69e70bf4815bb6f5fe0f6b2889785cb905d Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Tue, 21 May 2024 19:59:46 +0800 Subject: [PATCH 20/31] add more stream events support --- stream_v2.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++ stream_v2_test.go | 22 ++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/stream_v2.go b/stream_v2.go index 75ac8c3e7..6c495c2c3 100644 --- a/stream_v2.go +++ b/stream_v2.go @@ -14,6 +14,11 @@ type StreamDone struct { streamEvent } +type StreamThreadMessageCompleted struct { + Message + streamEvent +} + type StreamThreadMessageDelta struct { ID string `json:"id"` Object string `json:"object"` @@ -108,6 +113,21 @@ type StreamThreadRunCreated struct { streamEvent } +type StreamThreadRunRequiresAction struct { + Run + streamEvent +} + +type StreamThreadRunCompleted struct { + Run + streamEvent +} + +type StreamRunStepCompleted struct { + RunStep + streamEvent +} + type StreamEvent interface { Event() string JSON() json.RawMessage @@ -157,12 +177,45 @@ func (s *StreamerV2) Next() bool { streamEvent: streamEvent, } } + + case "thread.run.requires_action": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunRequiresAction{ + Run: run, + streamEvent: streamEvent, + } + } + case "thread.run.completed": + var run Run + if err := json.Unmarshal([]byte(event.Data), &run); err == nil { + s.next = &StreamThreadRunCompleted{ + Run: run, + streamEvent: streamEvent, + } + } case "thread.message.delta": var delta StreamThreadMessageDelta if err := json.Unmarshal([]byte(event.Data), &delta); err == nil { delta.streamEvent = streamEvent s.next = &delta } + case "thread.run.step.completed": + var runStep RunStep + if err := json.Unmarshal([]byte(event.Data), &runStep); err == nil { + s.next = &StreamRunStepCompleted{ + RunStep: runStep, + streamEvent: streamEvent, + } + } + case "thread.message.completed": + var msg Message + if err := json.Unmarshal([]byte(event.Data), &msg); err == nil { + s.next = &StreamThreadMessageCompleted{ + Message: msg, + streamEvent: streamEvent, + } + } case "done": streamEvent.data = nil s.next = &StreamDone{ diff --git a/stream_v2_test.go b/stream_v2_test.go index 63b6a074e..0a5d2b9f6 100644 --- a/stream_v2_test.go +++ b/stream_v2_test.go @@ -105,6 +105,22 @@ func TestStreamerV2(t *testing.T) { Event: "thread.message.delta", Data: `{"id":"msg_KFiZxHhXYQo6cGFnGjRDHSee","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"hello"}}]}}`, }, + { + Event: "thread.run.requires_action", + Data: `{"id":"run_oNjmoH9jHSQBSPkuVqfHSaLs","object":"thread.run","created_at":1716281751,"assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_4yCKEOWSRQRofNuzl7Ny3uNs","status":"requires_action","started_at":1716281751,"expires_at":1716282351,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":{"type":"submit_tool_outputs","submit_tool_outputs":{"tool_calls":[{"id":"call_q7J5q7taE0K0x83HRuJxJJjR","type":"function","function":{"name":"lookupDefinition","arguments":"{\"entry\":\"square root of pi\",\"language\":\"en\"}"}}]}},"last_error":null,"model":"gpt-3.5-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"lookupDefinition","description":"Lookup the definition of an entry. e.g. word, short phrase, person, place, or term","parameters":{"properties":{"entry":{"description":"The entry to lookup","type":"string"},"language":{"description":"ISO 639-1 language code, e.g., 'en' for English, 'zh' for Chinese","type":"string"}},"type":"object"}}}],"tool_resources":{"code_interpreter":{"file_ids":[]}},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto"}`, + }, + { + Event: "thread.run.completed", + Data: `{"id":"run_o14scUSKGFFRrwhsfGkh2pMJ","object":"thread.run","created_at":1716281844,"assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_732uu0FpoLAGrOlxAz8syqD0","status":"completed","started_at":1716281844,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1716281845,"required_action":null,"last_error":null,"model":"gpt-3.5-turbo","instructions":null,"tools":[{"type":"function","function":{"name":"lookupDefinition","description":"Lookup the definition of an entry. e.g. word, short phrase, person, place, or term","parameters":{"properties":{"entry":{"description":"The entry to lookup","type":"string"},"language":{"description":"ISO 639-1 language code, e.g., 'en' for English, 'zh' for Chinese","type":"string"}},"type":"object"}}}],"tool_resources":{"code_interpreter":{"file_ids":[]}},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":300,"completion_tokens":24,"total_tokens":324},"response_format":"auto","tool_choice":"auto"}`, + }, + { + Event: "thread.run.step.completed", + Data: `{"id":"step_9UKPyHGdL6VczTfigS5bdGQb","object":"thread.run.step","created_at":1716281845,"run_id":"run_o14scUSKGFFRrwhsfGkh2pMJ","assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_732uu0FpoLAGrOlxAz8syqD0","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1716281845,"expires_at":1716282444,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_Hb14QXWwPWEiMJ12L8Spa3T9"}},"usage":{"prompt_tokens":300,"completion_tokens":24,"total_tokens":324}}`, + }, + { + Event: "thread.message.completed", + Data: `{"id":"msg_Hb14QXWwPWEiMJ12L8Spa3T9","object":"thread.message","created_at":1716281845,"assistant_id":"asst_FDlm0qwiBOu65jhL95yNuRv3","thread_id":"thread_732uu0FpoLAGrOlxAz8syqD0","run_id":"run_o14scUSKGFFRrwhsfGkh2pMJ","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1716281845,"role":"assistant","content":[{"type":"text","text":{"value":"Sure! Here you go:\n\nWhy couldn't the leopard play hide and seek?\n\nBecause he was always spotted!","annotations":[]}}],"attachments":[],"metadata":{}}`, + }, { Event: "done", Data: "[DONE]", @@ -148,6 +164,12 @@ func TestStreamerV2(t *testing.T) { } jsonEqual(t, []byte(tc.Data), delta) + case *openai.StreamThreadRunRequiresAction: + jsonEqual(t, []byte(tc.Data), event.Run) + case *openai.StreamThreadRunCompleted: + jsonEqual(t, []byte(tc.Data), event.Run) + case *openai.StreamRunStepCompleted: + jsonEqual(t, []byte(tc.Data), event.RunStep) case *openai.StreamDone: if event.JSON() != nil { t.Fatalf("Expected JSON data to be nil, but got %s", string(event.JSON())) From 436a2236e4df6ec47e7b5004b0cbd2b073859405 Mon Sep 17 00:00:00 2001 From: Howard Yeh Date: Tue, 21 May 2024 21:01:35 +0800 Subject: [PATCH 21/31] submit tool output streaming --- run.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/run.go b/run.go index 81bfabb46..728fe853d 100644 --- a/run.go +++ b/run.go @@ -155,11 +155,6 @@ type CreateThreadAndRunRequest struct { Thread ThreadRequest `json:"thread"` } -type CreateThreadAndStreamRequest struct { - CreateThreadAndRunRequest - Stream bool `json:"stream"` -} - type RunStep struct { ID string `json:"id"` Object string `json:"object"` @@ -358,7 +353,7 @@ func (c *Client) SubmitToolOutputsStream( threadID string, runID string, request SubmitToolOutputsRequest, -) (stream *AssistantStream, err error) { +) (stream *StreamerV2, err error) { urlSuffix := fmt.Sprintf("/threads/%s/runs/%s/submit_tool_outputs", threadID, runID) r := SubmitToolOutputsStreamRequest{ SubmitToolOutputsRequest: request, @@ -375,14 +370,7 @@ func (c *Client) SubmitToolOutputsStream( return } - resp, err := sendRequestStream[AssistantStreamEvent](c, req) - if err != nil { - return - } - stream = &AssistantStream{ - streamReader: resp, - } - return + return sendRequestStreamV2(c, req) } // CancelRun cancels a run. @@ -473,8 +461,13 @@ type AssistantStream struct { func (c *Client) CreateThreadAndRunStream( ctx context.Context, request CreateThreadAndRunRequest) (stream *StreamerV2, err error) { + type createThreadAndStreamRequest struct { + CreateThreadAndRunRequest + Stream bool `json:"stream"` + } + urlSuffix := "/threads/runs" - sr := CreateThreadAndStreamRequest{ + sr := createThreadAndStreamRequest{ CreateThreadAndRunRequest: request, Stream: true, } From 211cb49fc22766f4174fef15301c4d39aef609d3 Mon Sep 17 00:00:00 2001 From: ando-masaki Date: Fri, 24 May 2024 16:18:47 +0900 Subject: [PATCH 22/31] Update client.go to get response header whether there is an error or not. (#751) Update client.go to get response header whether there is an error or not. Because 429 Too Many Requests error response has "Retry-After" header. --- client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c57ba17c7..7bc28e984 100644 --- a/client.go +++ b/client.go @@ -129,14 +129,14 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { defer res.Body.Close() - if isFailureStatusCode(res) { - return c.handleErrorResp(res) - } - if v != nil { v.SetHeader(res.Header) } + if isFailureStatusCode(res) { + return c.handleErrorResp(res) + } + return decodeResponse(res.Body, v) } From 30cf7b879cff5eb56f06fda19c51c9e92fce8b13 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:50:22 -0700 Subject: [PATCH 23/31] feat: add params to RunRequest (#754) --- run.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/run.go b/run.go index 094b0a4db..6bd3933b1 100644 --- a/run.go +++ b/run.go @@ -92,6 +92,7 @@ type RunRequest struct { // Sampling temperature between 0 and 2. Higher values like 0.8 are more random. // lower values are more focused and deterministic. Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. @@ -103,6 +104,11 @@ type RunRequest struct { // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` + + // This can be either a string or a ToolChoice object. + ToolChoice any `json:"tool_choice,omitempty"` + // This can be either a string or a ResponseFormat object. + ResponseFormat any `json:"response_format,omitempty"` } // ThreadTruncationStrategy defines the truncation strategy to use for the thread. @@ -124,6 +130,13 @@ const ( TruncationStrategyLastMessages = TruncationStrategy("last_messages") ) +// ReponseFormat specifies the format the model must output. +// https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-response_format. +// Type can either be text or json_object. +type ReponseFormat struct { + Type string `json:"type"` +} + type RunModifyRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } From 8618492b98bb91edbb43f8080b3a68275e183663 Mon Sep 17 00:00:00 2001 From: shosato0306 <38198918+shosato0306@users.noreply.github.com> Date: Wed, 5 Jun 2024 20:03:57 +0900 Subject: [PATCH 24/31] feat: add incomplete run status (#763) --- run.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/run.go b/run.go index 6bd3933b1..5598f1dfb 100644 --- a/run.go +++ b/run.go @@ -30,10 +30,10 @@ type Run struct { Temperature *float32 `json:"temperature,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. TruncationStrategy *ThreadTruncationStrategy `json:"truncation_strategy,omitempty"` @@ -50,6 +50,7 @@ const ( RunStatusCancelling RunStatus = "cancelling" RunStatusFailed RunStatus = "failed" RunStatusCompleted RunStatus = "completed" + RunStatusIncomplete RunStatus = "incomplete" RunStatusExpired RunStatus = "expired" RunStatusCancelled RunStatus = "cancelled" ) @@ -95,11 +96,11 @@ type RunRequest struct { TopP *float32 `json:"top_p,omitempty"` // The maximum number of prompt tokens that may be used over the course of the run. - // If the run exceeds the number of prompt tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of prompt tokens specified, the run will end with status 'incomplete'. MaxPromptTokens int `json:"max_prompt_tokens,omitempty"` // The maximum number of completion tokens that may be used over the course of the run. - // If the run exceeds the number of completion tokens specified, the run will end with status 'complete'. + // If the run exceeds the number of completion tokens specified, the run will end with status 'incomplete'. MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // ThreadTruncationStrategy defines the truncation strategy to use for the thread. From fd41f7a5f49e6723d97642c186e5e090abaebfe2 Mon Sep 17 00:00:00 2001 From: Adam Smith <62568604+TheAdamSmith@users.noreply.github.com> Date: Thu, 13 Jun 2024 06:23:07 -0700 Subject: [PATCH 25/31] Fix integration test (#762) * added TestCompletionStream test moved completion stream testing to seperate function added NoErrorF fixes nil pointer reference on stream object * update integration test models --- api_integration_test.go | 64 ++++++++++++++++++++-------------- completion.go | 31 ++++++++-------- embeddings.go | 2 +- internal/test/checks/checks.go | 7 ++++ 4 files changed, 62 insertions(+), 42 deletions(-) diff --git a/api_integration_test.go b/api_integration_test.go index 736040c50..f34685188 100644 --- a/api_integration_test.go +++ b/api_integration_test.go @@ -26,7 +26,7 @@ func TestAPI(t *testing.T) { _, err = c.ListEngines(ctx) checks.NoError(t, err, "ListEngines error") - _, err = c.GetEngine(ctx, "davinci") + _, err = c.GetEngine(ctx, openai.GPT3Davinci002) checks.NoError(t, err, "GetEngine error") fileRes, err := c.ListFiles(ctx) @@ -42,7 +42,7 @@ func TestAPI(t *testing.T) { "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: openai.AdaEmbeddingV2, } _, err = c.CreateEmbeddings(ctx, embeddingReq) checks.NoError(t, err, "Embedding error") @@ -77,31 +77,6 @@ func TestAPI(t *testing.T) { ) checks.NoError(t, err, "CreateChatCompletion (with name) returned error") - stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ - Prompt: "Ex falso quodlibet", - Model: openai.GPT3Ada, - MaxTokens: 5, - Stream: true, - }) - checks.NoError(t, err, "CreateCompletionStream returned error") - defer stream.Close() - - counter := 0 - for { - _, err = stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - t.Errorf("Stream error: %v", err) - } else { - counter++ - } - } - if counter == 0 { - t.Error("Stream did not return any responses") - } - _, err = c.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -134,6 +109,41 @@ func TestAPI(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion (with functions) returned error") } +func TestCompletionStream(t *testing.T) { + apiToken := os.Getenv("OPENAI_TOKEN") + if apiToken == "" { + t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") + } + + c := openai.NewClient(apiToken) + ctx := context.Background() + + stream, err := c.CreateCompletionStream(ctx, openai.CompletionRequest{ + Prompt: "Ex falso quodlibet", + Model: openai.GPT3Babbage002, + MaxTokens: 5, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + counter := 0 + for { + _, err = stream.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + t.Errorf("Stream error: %v", err) + } else { + counter++ + } + } + if counter == 0 { + t.Error("Stream did not return any responses") + } +} + func TestAPIError(t *testing.T) { apiToken := os.Getenv("OPENAI_TOKEN") if apiToken == "" { diff --git a/completion.go b/completion.go index ced8e0606..024f09b14 100644 --- a/completion.go +++ b/completion.go @@ -39,30 +39,33 @@ const ( GPT3Dot5Turbo16K0613 = "gpt-3.5-turbo-16k-0613" GPT3Dot5Turbo = "gpt-3.5-turbo" GPT3Dot5TurboInstruct = "gpt-3.5-turbo-instruct" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci003 = "text-davinci-003" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci002 = "text-davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextCurie001 = "text-curie-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextBabbage001 = "text-babbage-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextAda001 = "text-ada-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3TextDavinci001 = "text-davinci-001" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3Davinci002 = "davinci-002" - // Deprecated: Will be shut down on January 04, 2024. Use gpt-3.5-turbo-instruct instead. + // Deprecated: Model is shutdown. Use davinci-002 instead. + GPT3Davinci = "davinci" + GPT3Davinci002 = "davinci-002" + // Deprecated: Model is shutdown. Use gpt-3.5-turbo-instruct instead. GPT3CurieInstructBeta = "curie-instruct-beta" GPT3Curie = "curie" GPT3Curie002 = "curie-002" - GPT3Ada = "ada" - GPT3Ada002 = "ada-002" - GPT3Babbage = "babbage" - GPT3Babbage002 = "babbage-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Ada = "ada" + GPT3Ada002 = "ada-002" + // Deprecated: Model is shutdown. Use babbage-002 instead. + GPT3Babbage = "babbage" + GPT3Babbage002 = "babbage-002" ) // Codex Defines the models provided by OpenAI. diff --git a/embeddings.go b/embeddings.go index c5633a313..b513ba6a7 100644 --- a/embeddings.go +++ b/embeddings.go @@ -16,7 +16,7 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") type EmbeddingModel string const ( - // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + // Deprecated: The following block is shut down. Use text-embedding-ada-002 instead. AdaSimilarity EmbeddingModel = "text-similarity-ada-001" BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" CurieSimilarity EmbeddingModel = "text-similarity-curie-001" diff --git a/internal/test/checks/checks.go b/internal/test/checks/checks.go index 713369157..6bd0964c6 100644 --- a/internal/test/checks/checks.go +++ b/internal/test/checks/checks.go @@ -12,6 +12,13 @@ func NoError(t *testing.T, err error, message ...string) { } } +func NoErrorF(t *testing.T, err error, message ...string) { + t.Helper() + if err != nil { + t.Fatal(err, message) + } +} + func HasError(t *testing.T, err error, message ...string) { t.Helper() if err == nil { From 7e96c712cbdad50b9cf67324b1ca5ef6541b6235 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:15:27 +0400 Subject: [PATCH 26/31] run integration tests (#769) --- .github/workflows/integration-tests.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/integration-tests.yml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml new file mode 100644 index 000000000..19f158e40 --- /dev/null +++ b/.github/workflows/integration-tests.yml @@ -0,0 +1,19 @@ +name: Integration tests + +on: + push: + branches: + - master + +jobs: + integration_tests: + name: Run integration tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Run integration tests + run: go test -v -tags=integration ./api_integration_test.go From c69c3bb1d259375d5de801f890aca40c0b2a8867 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Thu, 13 Jun 2024 19:21:25 +0400 Subject: [PATCH 27/31] integration tests: pass openai secret (#770) * pass openai secret * only run in master branch --- .github/workflows/integration-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 19f158e40..7260b00b4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -16,4 +16,6 @@ jobs: with: go-version: '1.21' - name: Run integration tests + env: + OPENAI_TOKEN: ${{ secrets.OPENAI_TOKEN }} run: go test -v -tags=integration ./api_integration_test.go From 99cc170b5414bd21fc1c55bccba1d6c1bad04516 Mon Sep 17 00:00:00 2001 From: eiixy <990656271@qq.com> Date: Thu, 13 Jun 2024 23:24:37 +0800 Subject: [PATCH 28/31] feat: support batches api (#746) * feat: support batches api * update batch_test.go * fix golangci-lint check * fix golangci-lint check * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix tests coverage * fix: create batch api * update batch_test.go * feat: add `CreateBatchWithUploadFile` * feat: add `UploadBatchFile` * optimize variable and type naming * expose `BatchLineItem` interface * update batches const --- batch.go | 275 ++++++++++++++++++++++++++++++++++++ batch_test.go | 368 +++++++++++++++++++++++++++++++++++++++++++++++++ client_test.go | 11 ++ files.go | 1 + 4 files changed, 655 insertions(+) create mode 100644 batch.go create mode 100644 batch_test.go diff --git a/batch.go b/batch.go new file mode 100644 index 000000000..4aba966bc --- /dev/null +++ b/batch.go @@ -0,0 +1,275 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" +) + +const batchesSuffix = "/batches" + +type BatchEndpoint string + +const ( + BatchEndpointChatCompletions BatchEndpoint = "/v1/chat/completions" + BatchEndpointCompletions BatchEndpoint = "/v1/completions" + BatchEndpointEmbeddings BatchEndpoint = "/v1/embeddings" +) + +type BatchLineItem interface { + MarshalBatchLineItem() []byte +} + +type BatchChatCompletionRequest struct { + CustomID string `json:"custom_id"` + Body ChatCompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchChatCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchCompletionRequest struct { + CustomID string `json:"custom_id"` + Body CompletionRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchCompletionRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type BatchEmbeddingRequest struct { + CustomID string `json:"custom_id"` + Body EmbeddingRequest `json:"body"` + Method string `json:"method"` + URL BatchEndpoint `json:"url"` +} + +func (r BatchEmbeddingRequest) MarshalBatchLineItem() []byte { + marshal, _ := json.Marshal(r) + return marshal +} + +type Batch struct { + ID string `json:"id"` + Object string `json:"object"` + Endpoint BatchEndpoint `json:"endpoint"` + Errors *struct { + Object string `json:"object,omitempty"` + Data struct { + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param *string `json:"param,omitempty"` + Line *int `json:"line,omitempty"` + } `json:"data"` + } `json:"errors"` + InputFileID string `json:"input_file_id"` + CompletionWindow string `json:"completion_window"` + Status string `json:"status"` + OutputFileID *string `json:"output_file_id"` + ErrorFileID *string `json:"error_file_id"` + CreatedAt int `json:"created_at"` + InProgressAt *int `json:"in_progress_at"` + ExpiresAt *int `json:"expires_at"` + FinalizingAt *int `json:"finalizing_at"` + CompletedAt *int `json:"completed_at"` + FailedAt *int `json:"failed_at"` + ExpiredAt *int `json:"expired_at"` + CancellingAt *int `json:"cancelling_at"` + CancelledAt *int `json:"cancelled_at"` + RequestCounts BatchRequestCounts `json:"request_counts"` + Metadata map[string]any `json:"metadata"` +} + +type BatchRequestCounts struct { + Total int `json:"total"` + Completed int `json:"completed"` + Failed int `json:"failed"` +} + +type CreateBatchRequest struct { + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` +} + +type BatchResponse struct { + httpHeader + Batch +} + +var ErrUploadBatchFileFailed = errors.New("upload batch file failed") + +// CreateBatch — API call to Create batch. +func (c *Client) CreateBatch( + ctx context.Context, + request CreateBatchRequest, +) (response BatchResponse, err error) { + if request.CompletionWindow == "" { + request.CompletionWindow = "24h" + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} + +type UploadBatchFileRequest struct { + FileName string + Lines []BatchLineItem +} + +func (r *UploadBatchFileRequest) MarshalJSONL() []byte { + buff := bytes.Buffer{} + for i, line := range r.Lines { + if i != 0 { + buff.Write([]byte("\n")) + } + buff.Write(line.MarshalBatchLineItem()) + } + return buff.Bytes() +} + +func (r *UploadBatchFileRequest) AddChatCompletion(customerID string, body ChatCompletionRequest) { + r.Lines = append(r.Lines, BatchChatCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointChatCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddCompletion(customerID string, body CompletionRequest) { + r.Lines = append(r.Lines, BatchCompletionRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointCompletions, + }) +} + +func (r *UploadBatchFileRequest) AddEmbedding(customerID string, body EmbeddingRequest) { + r.Lines = append(r.Lines, BatchEmbeddingRequest{ + CustomID: customerID, + Body: body, + Method: "POST", + URL: BatchEndpointEmbeddings, + }) +} + +// UploadBatchFile — upload batch file. +func (c *Client) UploadBatchFile(ctx context.Context, request UploadBatchFileRequest) (File, error) { + if request.FileName == "" { + request.FileName = "@batchinput.jsonl" + } + return c.CreateFileBytes(ctx, FileBytesRequest{ + Name: request.FileName, + Bytes: request.MarshalJSONL(), + Purpose: PurposeBatch, + }) +} + +type CreateBatchWithUploadFileRequest struct { + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + UploadBatchFileRequest +} + +// CreateBatchWithUploadFile — API call to Create batch with upload file. +func (c *Client) CreateBatchWithUploadFile( + ctx context.Context, + request CreateBatchWithUploadFileRequest, +) (response BatchResponse, err error) { + var file File + file, err = c.UploadBatchFile(ctx, UploadBatchFileRequest{ + FileName: request.FileName, + Lines: request.Lines, + }) + if err != nil { + err = errors.Join(ErrUploadBatchFileFailed, err) + return + } + return c.CreateBatch(ctx, CreateBatchRequest{ + InputFileID: file.ID, + Endpoint: request.Endpoint, + CompletionWindow: request.CompletionWindow, + Metadata: request.Metadata, + }) +} + +// RetrieveBatch — API call to Retrieve batch. +func (c *Client) RetrieveBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +// CancelBatch — API call to Cancel batch. +func (c *Client) CancelBatch( + ctx context.Context, + batchID string, +) (response BatchResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s/cancel", batchesSuffix, batchID) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix)) + if err != nil { + return + } + err = c.sendRequest(req, &response) + return +} + +type ListBatchResponse struct { + httpHeader + Object string `json:"object"` + Data []Batch `json:"data"` + FirstID string `json:"first_id"` + LastID string `json:"last_id"` + HasMore bool `json:"has_more"` +} + +// ListBatch API call to List batch. +func (c *Client) ListBatch(ctx context.Context, after *string, limit *int) (response ListBatchResponse, err error) { + urlValues := url.Values{} + if limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *limit)) + } + if after != nil { + urlValues.Add("after", *after) + } + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", batchesSuffix, encodedValues) + req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 000000000..4b2261e0e --- /dev/null +++ b/batch_test.go @@ -0,0 +1,368 @@ +package openai_test + +import ( + "context" + "fmt" + "net/http" + "reflect" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestUploadBatchFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/files", handleCreateFile) + req := openai.UploadBatchFileRequest{} + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.UploadBatchFile(context.Background(), req) + checks.NoError(t, err, "UploadBatchFile error") +} + +func TestCreateBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + _, err := client.CreateBatch(context.Background(), openai.CreateBatchRequest{ + InputFileID: "file-abc", + Endpoint: openai.BatchEndpointChatCompletions, + CompletionWindow: "24h", + }) + checks.NoError(t, err, "CreateBatch error") +} + +func TestCreateBatchWithUploadFile(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/files", handleCreateFile) + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + req := openai.CreateBatchWithUploadFileRequest{ + Endpoint: openai.BatchEndpointChatCompletions, + } + req.AddChatCompletion("req-1", openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + _, err := client.CreateBatchWithUploadFile(context.Background(), req) + checks.NoError(t, err, "CreateBatchWithUploadFile error") +} + +func TestRetrieveBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1", handleRetrieveBatchEndpoint) + _, err := client.RetrieveBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestCancelBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches/file-id-1/cancel", handleCancelBatchEndpoint) + _, err := client.CancelBatch(context.Background(), "file-id-1") + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestListBatch(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/batches", handleBatchEndpoint) + after := "batch_abc123" + limit := 10 + _, err := client.ListBatch(context.Background(), &after, &limit) + checks.NoError(t, err, "RetrieveBatch error") +} + +func TestUploadBatchFileRequest_AddChatCompletion(t *testing.T) { + type args struct { + customerID string + body openai.ChatCompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + { + customerID: "req-2", + body: openai.ChatCompletionRequest{ + MaxTokens: 5, + Model: openai.GPT3Dot5Turbo, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello!\"}],\"max_tokens\":5},\"method\":\"POST\",\"url\":\"/v1/chat/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddChatCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddCompletion(t *testing.T) { + type args struct { + customerID string + body openai.CompletionRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + { + customerID: "req-2", + body: openai.CompletionRequest{ + Model: openai.GPT3Dot5Turbo, + User: "Hello", + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}\n{\"custom_id\":\"req-2\",\"body\":{\"model\":\"gpt-3.5-turbo\",\"user\":\"Hello\"},\"method\":\"POST\",\"url\":\"/v1/completions\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddCompletion(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUploadBatchFileRequest_AddEmbedding(t *testing.T) { + type args struct { + customerID string + body openai.EmbeddingRequest + } + tests := []struct { + name string + args []args + want []byte + }{ + {"", []args{ + { + customerID: "req-1", + body: openai.EmbeddingRequest{ + Model: openai.GPT3Dot5Turbo, + Input: []string{"Hello", "World"}, + }, + }, + { + customerID: "req-2", + body: openai.EmbeddingRequest{ + Model: openai.AdaEmbeddingV2, + Input: []string{"Hello", "World"}, + }, + }, + }, []byte("{\"custom_id\":\"req-1\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"gpt-3.5-turbo\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}\n{\"custom_id\":\"req-2\",\"body\":{\"input\":[\"Hello\",\"World\"],\"model\":\"text-embedding-ada-002\",\"user\":\"\"},\"method\":\"POST\",\"url\":\"/v1/embeddings\"}")}, //nolint:lll + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &openai.UploadBatchFileRequest{} + for _, arg := range tt.args { + r.AddEmbedding(arg.customerID, arg.body) + } + got := r.MarshalJSONL() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Marshal() got = %v, want %v", got, tt.want) + } + }) + } +} + +func handleBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } else if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job" + } + } + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + }`) + } +} + +func handleRetrieveBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} + +func handleCancelBatchEndpoint(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + _, _ = fmt.Fprintln(w, `{ + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job" + } + }`) + } +} diff --git a/client_test.go b/client_test.go index a08d10f21..e49da9b3d 100644 --- a/client_test.go +++ b/client_test.go @@ -396,6 +396,17 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"CreateSpeech", func() (any, error) { return client.CreateSpeech(ctx, CreateSpeechRequest{Model: TTSModel1, Voice: VoiceAlloy}) }}, + {"CreateBatch", func() (any, error) { + return client.CreateBatch(ctx, CreateBatchRequest{}) + }}, + {"CreateBatchWithUploadFile", func() (any, error) { + return client.CreateBatchWithUploadFile(ctx, CreateBatchWithUploadFileRequest{}) + }}, + {"RetrieveBatch", func() (any, error) { + return client.RetrieveBatch(ctx, "") + }}, + {"CancelBatch", func() (any, error) { return client.CancelBatch(ctx, "") }}, + {"ListBatch", func() (any, error) { return client.ListBatch(ctx, nil, nil) }}, } for _, testCase := range testCases { diff --git a/files.go b/files.go index b40a44f15..26ad6bd70 100644 --- a/files.go +++ b/files.go @@ -22,6 +22,7 @@ const ( PurposeFineTuneResults PurposeType = "fine-tune-results" PurposeAssistants PurposeType = "assistants" PurposeAssistantsOutput PurposeType = "assistants_output" + PurposeBatch PurposeType = "batch" ) // FileBytesRequest represents a file upload request. From 68acf22a43903c1b460006e7c4b883ce73e35857 Mon Sep 17 00:00:00 2001 From: Pawel Kosiec Date: Thu, 13 Jun 2024 17:26:37 +0200 Subject: [PATCH 29/31] Support Tool Resources properties for Threads (#760) * Support Tool Resources properties for Threads * Add Chunking Strategy for Threads vector stores --- thread.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 7 deletions(-) diff --git a/thread.go b/thread.go index 900e3f2ea..6f7521454 100644 --- a/thread.go +++ b/thread.go @@ -10,21 +10,74 @@ const ( ) type Thread struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Metadata map[string]any `json:"metadata"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Metadata map[string]any `json:"metadata"` + ToolResources ToolResources `json:"tool_resources,omitempty"` httpHeader } type ThreadRequest struct { - Messages []ThreadMessage `json:"messages,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Messages []ThreadMessage `json:"messages,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *ToolResourcesRequest `json:"tool_resources,omitempty"` } +type ToolResources struct { + CodeInterpreter *CodeInterpreterToolResources `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResources `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResources struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` +} + +type ToolResourcesRequest struct { + CodeInterpreter *CodeInterpreterToolResourcesRequest `json:"code_interpreter,omitempty"` + FileSearch *FileSearchToolResourcesRequest `json:"file_search,omitempty"` +} + +type CodeInterpreterToolResourcesRequest struct { + FileIDs []string `json:"file_ids,omitempty"` +} + +type FileSearchToolResourcesRequest struct { + VectorStoreIDs []string `json:"vector_store_ids,omitempty"` + VectorStores []VectorStoreToolResources `json:"vector_stores,omitempty"` +} + +type VectorStoreToolResources struct { + FileIDs []string `json:"file_ids,omitempty"` + ChunkingStrategy *ChunkingStrategy `json:"chunking_strategy,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ChunkingStrategy struct { + Type ChunkingStrategyType `json:"type"` + Static *StaticChunkingStrategy `json:"static,omitempty"` +} + +type StaticChunkingStrategy struct { + MaxChunkSizeTokens int `json:"max_chunk_size_tokens"` + ChunkOverlapTokens int `json:"chunk_overlap_tokens"` +} + +type ChunkingStrategyType string + +const ( + ChunkingStrategyTypeAuto ChunkingStrategyType = "auto" + ChunkingStrategyTypeStatic ChunkingStrategyType = "static" +) + type ModifyThreadRequest struct { - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` + ToolResources *ToolResources `json:"tool_resources,omitempty"` } type ThreadMessageRole string From 0a421308993425afed7796da8f8e0e1abafd4582 Mon Sep 17 00:00:00 2001 From: Peng Guan-Cheng Date: Wed, 19 Jun 2024 16:37:21 +0800 Subject: [PATCH 30/31] feat: provide vector store (#772) * implement vectore store feature * fix after integration testing * fix golint error * improve test to increare code coverage * fix golint anc code coverage problem * add tool_resource in assistant response * chore: code style * feat: use pagination param * feat: use pagination param * test: use pagination param * test: rm unused code --------- Co-authored-by: Denny Depok <61371551+kodernubie@users.noreply.github.com> Co-authored-by: eric.p --- assistant.go | 50 ++++--- config.go | 2 +- vector_store.go | 345 ++++++++++++++++++++++++++++++++++++++++++ vector_store_test.go | 349 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 728 insertions(+), 18 deletions(-) create mode 100644 vector_store.go create mode 100644 vector_store_test.go diff --git a/assistant.go b/assistant.go index 661681e83..cc13a3020 100644 --- a/assistant.go +++ b/assistant.go @@ -14,16 +14,17 @@ const ( ) type Assistant struct { - ID string `json:"id"` - Object string `json:"object"` - CreatedAt int64 `json:"created_at"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Model string `json:"model"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"tools"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Model string `json:"model"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"tools"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` httpHeader } @@ -34,6 +35,7 @@ const ( AssistantToolTypeCodeInterpreter AssistantToolType = "code_interpreter" AssistantToolTypeRetrieval AssistantToolType = "retrieval" AssistantToolTypeFunction AssistantToolType = "function" + AssistantToolTypeFileSearch AssistantToolType = "file_search" ) type AssistantTool struct { @@ -41,19 +43,33 @@ type AssistantTool struct { Function *FunctionDefinition `json:"function,omitempty"` } +type AssistantToolFileSearch struct { + VectorStoreIDs []string `json:"vector_store_ids"` +} + +type AssistantToolCodeInterpreter struct { + FileIDs []string `json:"file_ids"` +} + +type AssistantToolResource struct { + FileSearch *AssistantToolFileSearch `json:"file_search,omitempty"` + CodeInterpreter *AssistantToolCodeInterpreter `json:"code_interpreter,omitempty"` +} + // AssistantRequest provides the assistant request parameters. // When modifying the tools the API functions as the following: // If Tools is undefined, no changes are made to the Assistant's tools. // If Tools is empty slice it will effectively delete all of the Assistant's tools. // If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools. type AssistantRequest struct { - Model string `json:"model"` - Name *string `json:"name,omitempty"` - Description *string `json:"description,omitempty"` - Instructions *string `json:"instructions,omitempty"` - Tools []AssistantTool `json:"-"` - FileIDs []string `json:"file_ids,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` + Model string `json:"model"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []AssistantTool `json:"-"` + FileIDs []string `json:"file_ids,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + ToolResources *AssistantToolResource `json:"tool_resources,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases diff --git a/config.go b/config.go index bb437c97f..1347567d7 100644 --- a/config.go +++ b/config.go @@ -24,7 +24,7 @@ const ( const AzureAPIKeyHeader = "api-key" -const defaultAssistantVersion = "v1" // This will be deprecated by the end of 2024. +const defaultAssistantVersion = "v2" // upgrade to v2 to support vector store // ClientConfig is a configuration of a client. type ClientConfig struct { diff --git a/vector_store.go b/vector_store.go new file mode 100644 index 000000000..5c364362a --- /dev/null +++ b/vector_store.go @@ -0,0 +1,345 @@ +package openai + +import ( + "context" + "fmt" + "net/http" + "net/url" +) + +const ( + vectorStoresSuffix = "/vector_stores" + vectorStoresFilesSuffix = "/files" + vectorStoresFileBatchesSuffix = "/file_batches" +) + +type VectorStoreFileCount struct { + InProgress int `json:"in_progress"` + Completed int `json:"completed"` + Failed int `json:"failed"` + Cancelled int `json:"cancelled"` + Total int `json:"total"` +} + +type VectorStore struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + Name string `json:"name"` + UsageBytes int `json:"usage_bytes"` + FileCounts VectorStoreFileCount `json:"file_counts"` + Status string `json:"status"` + ExpiresAfter *VectorStoreExpires `json:"expires_after"` + ExpiresAt *int `json:"expires_at"` + Metadata map[string]any `json:"metadata"` + + httpHeader +} + +type VectorStoreExpires struct { + Anchor string `json:"anchor"` + Days int `json:"days"` +} + +// VectorStoreRequest provides the vector store request parameters. +type VectorStoreRequest struct { + Name string `json:"name,omitempty"` + FileIDs []string `json:"file_ids,omitempty"` + ExpiresAfter *VectorStoreExpires `json:"expires_after,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// VectorStoresList is a list of vector store. +type VectorStoresList struct { + VectorStores []VectorStore `json:"data"` + LastID *string `json:"last_id"` + FirstID *string `json:"first_id"` + HasMore bool `json:"has_more"` + httpHeader +} + +type VectorStoreDeleteResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Deleted bool `json:"deleted"` + + httpHeader +} + +type VectorStoreFile struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + UsageBytes int `json:"usage_bytes"` + Status string `json:"status"` + + httpHeader +} + +type VectorStoreFileRequest struct { + FileID string `json:"file_id"` +} + +type VectorStoreFilesList struct { + VectorStoreFiles []VectorStoreFile `json:"data"` + + httpHeader +} + +type VectorStoreFileBatch struct { + ID string `json:"id"` + Object string `json:"object"` + CreatedAt int64 `json:"created_at"` + VectorStoreID string `json:"vector_store_id"` + Status string `json:"status"` + FileCounts VectorStoreFileCount `json:"file_counts"` + + httpHeader +} + +type VectorStoreFileBatchRequest struct { + FileIDs []string `json:"file_ids"` +} + +// CreateVectorStore creates a new vector store. +func (c *Client) CreateVectorStore(ctx context.Context, request VectorStoreRequest) (response VectorStore, err error) { + req, _ := c.newRequest( + ctx, + http.MethodPost, + c.fullURL(vectorStoresSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion), + ) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStore retrieves an vector store. +func (c *Client) RetrieveVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ModifyVectorStore modifies a vector store. +func (c *Client) ModifyVectorStore( + ctx context.Context, + vectorStoreID string, + request VectorStoreRequest, +) (response VectorStore, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStore deletes an vector store. +func (c *Client) DeleteVectorStore( + ctx context.Context, + vectorStoreID string, +) (response VectorStoreDeleteResponse, err error) { + urlSuffix := fmt.Sprintf("%s/%s", vectorStoresSuffix, vectorStoreID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStores Lists the currently available vector store. +func (c *Client) ListVectorStores( + ctx context.Context, + pagination Pagination, +) (response VectorStoresList, err error) { + urlValues := url.Values{} + + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s%s", vectorStoresSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFile creates a new vector store file. +func (c *Client) CreateVectorStoreFile( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileRequest, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFile retrieves a vector store file. +func (c *Client) RetrieveVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (response VectorStoreFile, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// DeleteVectorStoreFile deletes an existing file. +func (c *Client) DeleteVectorStoreFile( + ctx context.Context, + vectorStoreID string, + fileID string, +) (err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, fileID) + req, _ := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, nil) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFiles( + ctx context.Context, + vectorStoreID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFilesSuffix, encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CreateVectorStoreFileBatch creates a new vector store file batch. +func (c *Client) CreateVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + request VectorStoreFileBatchRequest, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix) + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBody(request), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// RetrieveVectorStoreFileBatch retrieves a vector store file batch. +func (c *Client) RetrieveVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s", vectorStoresSuffix, vectorStoreID, vectorStoresFileBatchesSuffix, batchID) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// CancelVectorStoreFileBatch cancel a new vector store file batch. +func (c *Client) CancelVectorStoreFileBatch( + ctx context.Context, + vectorStoreID string, + batchID string, +) (response VectorStoreFileBatch, err error) { + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/cancel") + req, _ := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} + +// ListVectorStoreFiles Lists the currently available files for a vector store. +func (c *Client) ListVectorStoreFilesInBatch( + ctx context.Context, + vectorStoreID string, + batchID string, + pagination Pagination, +) (response VectorStoreFilesList, err error) { + urlValues := url.Values{} + if pagination.After != nil { + urlValues.Add("after", *pagination.After) + } + if pagination.Limit != nil { + urlValues.Add("limit", fmt.Sprintf("%d", *pagination.Limit)) + } + if pagination.Before != nil { + urlValues.Add("before", *pagination.Before) + } + if pagination.Order != nil { + urlValues.Add("order", *pagination.Order) + } + + encodedValues := "" + if len(urlValues) > 0 { + encodedValues = "?" + urlValues.Encode() + } + + urlSuffix := fmt.Sprintf("%s/%s%s/%s%s%s", vectorStoresSuffix, + vectorStoreID, vectorStoresFileBatchesSuffix, batchID, "/files", encodedValues) + req, _ := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix), + withBetaAssistantVersion(c.config.AssistantVersion)) + + err = c.sendRequest(req, &response) + return +} diff --git a/vector_store_test.go b/vector_store_test.go new file mode 100644 index 000000000..58b9a857e --- /dev/null +++ b/vector_store_test.go @@ -0,0 +1,349 @@ +package openai_test + +import ( + "context" + + openai "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" + + "encoding/json" + "fmt" + "net/http" + "testing" +) + +// TestVectorStore Tests the vector store endpoint of the API using the mocked server. +func TestVectorStore(t *testing.T) { + vectorStoreID := "vs_abc123" + vectorStoreName := "TestStore" + vectorStoreFileID := "file-wB6RM6wHdA49HfS2DJ9fEyrH" + vectorStoreFileBatchID := "vsfb_abc123" + limit := 20 + order := "desc" + after := "vs_abc122" + before := "vs_abc123" + + client, server, teardown := setupOpenAITestServer() + defer teardown() + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files/"+vectorStoreFileID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodDelete { + fmt.Fprintln(w, `{ + id: "file-wB6RM6wHdA49HfS2DJ9fEyrH", + object: "vector_store.file.deleted", + deleted: true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + var request openai.VectorStoreFileRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFile{ + ID: request.FileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/files", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFilesList{ + VectorStoreFiles: []openai.VectorStoreFile{ + { + ID: vectorStoreFileID, + Object: "vector_store.file", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID+"/cancel", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 1, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches/"+vectorStoreFileBatchID, + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodPost { + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "cancelling", + FileCounts: openai.VectorStoreFileCount{ + Completed: 1, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID+"/file_batches", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreFileBatchRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStoreFileBatch{ + ID: vectorStoreFileBatchID, + Object: "vector_store.file_batch", + CreatedAt: 1234567890, + VectorStoreID: vectorStoreID, + Status: "completed", + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: len(request.FileIDs), + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores/"+vectorStoreID, + func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodPost: + var request openai.VectorStore + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + }) + fmt.Fprintln(w, string(resBytes)) + case http.MethodDelete: + fmt.Fprintln(w, `{ + "id": "vectorstore_abc123", + "object": "vector_store.deleted", + "deleted": true + }`) + } + }, + ) + + server.RegisterHandler( + "/v1/vector_stores", + func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var request openai.VectorStoreRequest + err := json.NewDecoder(r.Body).Decode(&request) + checks.NoError(t, err, "Decode error") + + resBytes, _ := json.Marshal(openai.VectorStore{ + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: request.Name, + FileCounts: openai.VectorStoreFileCount{ + InProgress: 0, + Completed: 0, + Failed: 0, + Cancelled: 0, + Total: 0, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } else if r.Method == http.MethodGet { + resBytes, _ := json.Marshal(openai.VectorStoresList{ + LastID: &vectorStoreID, + FirstID: &vectorStoreID, + VectorStores: []openai.VectorStore{ + { + ID: vectorStoreID, + Object: "vector_store", + CreatedAt: 1234567890, + Name: vectorStoreName, + }, + }, + }) + fmt.Fprintln(w, string(resBytes)) + } + }, + ) + + ctx := context.Background() + + t.Run("create_vector_store", func(t *testing.T) { + _, err := client.CreateVectorStore(ctx, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "CreateVectorStore error") + }) + + t.Run("retrieve_vector_store", func(t *testing.T) { + _, err := client.RetrieveVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "RetrieveVectorStore error") + }) + + t.Run("delete_vector_store", func(t *testing.T) { + _, err := client.DeleteVectorStore(ctx, vectorStoreID) + checks.NoError(t, err, "DeleteVectorStore error") + }) + + t.Run("list_vector_store", func(t *testing.T) { + _, err := client.ListVectorStores(context.TODO(), openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStores error") + }) + + t.Run("create_vector_store_file", func(t *testing.T) { + _, err := client.CreateVectorStoreFile(context.TODO(), vectorStoreID, openai.VectorStoreFileRequest{ + FileID: vectorStoreFileID, + }) + checks.NoError(t, err, "CreateVectorStoreFile error") + }) + + t.Run("list_vector_store_files", func(t *testing.T) { + _, err := client.ListVectorStoreFiles(ctx, vectorStoreID, openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFiles error") + }) + + t.Run("retrieve_vector_store_file", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "RetrieveVectorStoreFile error") + }) + + t.Run("delete_vector_store_file", func(t *testing.T) { + err := client.DeleteVectorStoreFile(ctx, vectorStoreID, vectorStoreFileID) + checks.NoError(t, err, "DeleteVectorStoreFile error") + }) + + t.Run("modify_vector_store", func(t *testing.T) { + _, err := client.ModifyVectorStore(ctx, vectorStoreID, openai.VectorStoreRequest{ + Name: vectorStoreName, + }) + checks.NoError(t, err, "ModifyVectorStore error") + }) + + t.Run("create_vector_store_file_batch", func(t *testing.T) { + _, err := client.CreateVectorStoreFileBatch(ctx, vectorStoreID, openai.VectorStoreFileBatchRequest{ + FileIDs: []string{vectorStoreFileID}, + }) + checks.NoError(t, err, "CreateVectorStoreFileBatch error") + }) + + t.Run("retrieve_vector_store_file_batch", func(t *testing.T) { + _, err := client.RetrieveVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "RetrieveVectorStoreFileBatch error") + }) + + t.Run("list_vector_store_files_in_batch", func(t *testing.T) { + _, err := client.ListVectorStoreFilesInBatch( + ctx, + vectorStoreID, + vectorStoreFileBatchID, + openai.Pagination{ + Limit: &limit, + Order: &order, + After: &after, + Before: &before, + }) + checks.NoError(t, err, "ListVectorStoreFilesInBatch error") + }) + + t.Run("cancel_vector_store_file_batch", func(t *testing.T) { + _, err := client.CancelVectorStoreFileBatch(ctx, vectorStoreID, vectorStoreFileBatchID) + checks.NoError(t, err, "CancelVectorStoreFileBatch error") + }) +} From e31185974c45949cc58c24a6cbf5ca969fb0f622 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:06:52 +0100 Subject: [PATCH 31/31] remove errors.Join (#778) --- batch.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/batch.go b/batch.go index 4aba966bc..a43d401ab 100644 --- a/batch.go +++ b/batch.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -109,8 +108,6 @@ type BatchResponse struct { Batch } -var ErrUploadBatchFileFailed = errors.New("upload batch file failed") - // CreateBatch — API call to Create batch. func (c *Client) CreateBatch( ctx context.Context, @@ -202,7 +199,6 @@ func (c *Client) CreateBatchWithUploadFile( Lines: request.Lines, }) if err != nil { - err = errors.Join(ErrUploadBatchFileFailed, err) return } return c.CreateBatch(ctx, CreateBatchRequest{