Skip to content

Commit bc89139

Browse files
authored
Feat Implement threads API (sashabaranov#536)
* feat: implement threads API * fix * add tests * fix * trigger£ * trigger * chore: add beta header
1 parent 08c167f commit bc89139

File tree

3 files changed

+214
-0
lines changed

3 files changed

+214
-0
lines changed

client_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,18 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
301301
{"DeleteAssistantFile", func() (any, error) {
302302
return nil, client.DeleteAssistantFile(ctx, "", "")
303303
}},
304+
{"CreateThread", func() (any, error) {
305+
return client.CreateThread(ctx, ThreadRequest{})
306+
}},
307+
{"RetrieveThread", func() (any, error) {
308+
return client.RetrieveThread(ctx, "")
309+
}},
310+
{"ModifyThread", func() (any, error) {
311+
return client.ModifyThread(ctx, "", ModifyThreadRequest{})
312+
}},
313+
{"DeleteThread", func() (any, error) {
314+
return client.DeleteThread(ctx, "")
315+
}},
304316
}
305317

306318
for _, testCase := range testCases {

thread.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package openai
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
const (
9+
threadsSuffix = "/threads"
10+
)
11+
12+
type Thread struct {
13+
ID string `json:"id"`
14+
Object string `json:"object"`
15+
CreatedAt int64 `json:"created_at"`
16+
Metadata map[string]any `json:"metadata"`
17+
18+
httpHeader
19+
}
20+
21+
type ThreadRequest struct {
22+
Messages []ThreadMessage `json:"messages,omitempty"`
23+
Metadata map[string]any `json:"metadata,omitempty"`
24+
}
25+
26+
type ModifyThreadRequest struct {
27+
Metadata map[string]any `json:"metadata"`
28+
}
29+
30+
type ThreadMessageRole string
31+
32+
const (
33+
ThreadMessageRoleUser ThreadMessageRole = "user"
34+
)
35+
36+
type ThreadMessage struct {
37+
Role ThreadMessageRole `json:"role"`
38+
Content string `json:"content"`
39+
FileIDs []string `json:"file_ids,omitempty"`
40+
Metadata map[string]any `json:"metadata,omitempty"`
41+
}
42+
43+
type ThreadDeleteResponse struct {
44+
ID string `json:"id"`
45+
Object string `json:"object"`
46+
Deleted bool `json:"deleted"`
47+
48+
httpHeader
49+
}
50+
51+
// CreateThread creates a new thread.
52+
func (c *Client) CreateThread(ctx context.Context, request ThreadRequest) (response Thread, err error) {
53+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(threadsSuffix), withBody(request),
54+
withBetaAssistantV1())
55+
if err != nil {
56+
return
57+
}
58+
59+
err = c.sendRequest(req, &response)
60+
return
61+
}
62+
63+
// RetrieveThread retrieves a thread.
64+
func (c *Client) RetrieveThread(ctx context.Context, threadID string) (response Thread, err error) {
65+
urlSuffix := threadsSuffix + "/" + threadID
66+
req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix),
67+
withBetaAssistantV1())
68+
if err != nil {
69+
return
70+
}
71+
72+
err = c.sendRequest(req, &response)
73+
return
74+
}
75+
76+
// ModifyThread modifies a thread.
77+
func (c *Client) ModifyThread(
78+
ctx context.Context,
79+
threadID string,
80+
request ModifyThreadRequest,
81+
) (response Thread, err error) {
82+
urlSuffix := threadsSuffix + "/" + threadID
83+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix), withBody(request),
84+
withBetaAssistantV1())
85+
if err != nil {
86+
return
87+
}
88+
89+
err = c.sendRequest(req, &response)
90+
return
91+
}
92+
93+
// DeleteThread deletes a thread.
94+
func (c *Client) DeleteThread(
95+
ctx context.Context,
96+
threadID string,
97+
) (response ThreadDeleteResponse, err error) {
98+
urlSuffix := threadsSuffix + "/" + threadID
99+
req, err := c.newRequest(ctx, http.MethodDelete, c.fullURL(urlSuffix),
100+
withBetaAssistantV1())
101+
if err != nil {
102+
return
103+
}
104+
105+
err = c.sendRequest(req, &response)
106+
return
107+
}

thread_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package openai_test
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
"testing"
9+
10+
openai "github.com/sashabaranov/go-openai"
11+
"github.com/sashabaranov/go-openai/internal/test/checks"
12+
)
13+
14+
// TestThread Tests the thread endpoint of the API using the mocked server.
15+
func TestThread(t *testing.T) {
16+
threadID := "thread_abc123"
17+
client, server, teardown := setupOpenAITestServer()
18+
defer teardown()
19+
20+
server.RegisterHandler(
21+
"/v1/threads/"+threadID,
22+
func(w http.ResponseWriter, r *http.Request) {
23+
switch r.Method {
24+
case http.MethodGet:
25+
resBytes, _ := json.Marshal(openai.Thread{
26+
ID: threadID,
27+
Object: "thread",
28+
CreatedAt: 1234567890,
29+
})
30+
fmt.Fprintln(w, string(resBytes))
31+
case http.MethodPost:
32+
var request openai.ThreadRequest
33+
err := json.NewDecoder(r.Body).Decode(&request)
34+
checks.NoError(t, err, "Decode error")
35+
36+
resBytes, _ := json.Marshal(openai.Thread{
37+
ID: threadID,
38+
Object: "thread",
39+
CreatedAt: 1234567890,
40+
})
41+
fmt.Fprintln(w, string(resBytes))
42+
case http.MethodDelete:
43+
fmt.Fprintln(w, `{
44+
"id": "thread_abc123",
45+
"object": "thread.deleted",
46+
"deleted": true
47+
}`)
48+
}
49+
},
50+
)
51+
52+
server.RegisterHandler(
53+
"/v1/threads",
54+
func(w http.ResponseWriter, r *http.Request) {
55+
if r.Method == http.MethodPost {
56+
var request openai.ModifyThreadRequest
57+
err := json.NewDecoder(r.Body).Decode(&request)
58+
checks.NoError(t, err, "Decode error")
59+
60+
resBytes, _ := json.Marshal(openai.Thread{
61+
ID: threadID,
62+
Object: "thread",
63+
CreatedAt: 1234567890,
64+
Metadata: request.Metadata,
65+
})
66+
fmt.Fprintln(w, string(resBytes))
67+
}
68+
},
69+
)
70+
71+
ctx := context.Background()
72+
73+
_, err := client.CreateThread(ctx, openai.ThreadRequest{
74+
Messages: []openai.ThreadMessage{
75+
{
76+
Role: openai.ThreadMessageRoleUser,
77+
Content: "Hello, World!",
78+
},
79+
},
80+
})
81+
checks.NoError(t, err, "CreateThread error")
82+
83+
_, err = client.RetrieveThread(ctx, threadID)
84+
checks.NoError(t, err, "RetrieveThread error")
85+
86+
_, err = client.ModifyThread(ctx, threadID, openai.ModifyThreadRequest{
87+
Metadata: map[string]interface{}{
88+
"key": "value",
89+
},
90+
})
91+
checks.NoError(t, err, "ModifyThread error")
92+
93+
_, err = client.DeleteThread(ctx, threadID)
94+
checks.NoError(t, err, "DeleteThread error")
95+
}

0 commit comments

Comments
 (0)