@@ -7,9 +7,11 @@ import (
77 "fmt"
88 "io"
99 "net/http"
10+ "reflect"
1011 "testing"
1112
1213 "github.com/sashabaranov/go-openai/internal/test"
14+ "github.com/sashabaranov/go-openai/internal/test/checks"
1315)
1416
1517var errTestRequestBuilderFailed = errors .New ("test request builder failed" )
@@ -43,38 +45,68 @@ func TestDecodeResponse(t *testing.T) {
4345 testCases := []struct {
4446 name string
4547 value interface {}
48+ expected interface {}
4649 body io.Reader
4750 hasError bool
4851 }{
4952 {
50- name : "nil input" ,
51- value : nil ,
52- body : bytes .NewReader ([]byte ("" )),
53+ name : "nil input" ,
54+ value : nil ,
55+ body : bytes .NewReader ([]byte ("" )),
56+ expected : nil ,
5357 },
5458 {
55- name : "string input" ,
56- value : & stringInput ,
57- body : bytes .NewReader ([]byte ("test" )),
59+ name : "string input" ,
60+ value : & stringInput ,
61+ body : bytes .NewReader ([]byte ("test" )),
62+ expected : "test" ,
5863 },
5964 {
6065 name : "map input" ,
6166 value : & map [string ]interface {}{},
6267 body : bytes .NewReader ([]byte (`{"test": "test"}` )),
68+ expected : map [string ]interface {}{
69+ "test" : "test" ,
70+ },
6371 },
6472 {
6573 name : "reader return error" ,
6674 value : & stringInput ,
6775 body : & errorReader {err : errors .New ("dummy" )},
6876 hasError : true ,
6977 },
78+ {
79+ name : "audio text input" ,
80+ value : & audioTextResponse {},
81+ body : bytes .NewReader ([]byte ("test" )),
82+ expected : audioTextResponse {
83+ Text : "test" ,
84+ },
85+ },
86+ }
87+
88+ assertEqual := func (t * testing.T , expected , actual interface {}) {
89+ t .Helper ()
90+ if expected == actual {
91+ return
92+ }
93+ v := reflect .ValueOf (actual ).Elem ().Interface ()
94+ if ! reflect .DeepEqual (v , expected ) {
95+ t .Fatalf ("Unexpected value: %v, expected: %v" , v , expected )
96+ }
7097 }
7198
7299 for _ , tc := range testCases {
73100 t .Run (tc .name , func (t * testing.T ) {
74101 err := decodeResponse (tc .body , tc .value )
75- if (err != nil ) != tc .hasError {
76- t .Errorf ("Unexpected error: %v" , err )
102+ if tc .hasError {
103+ checks .HasError (t , err , "Unexpected nil error" )
104+ return
105+ }
106+ if err != nil {
107+ t .Fatalf ("Unexpected error: %v" , err )
77108 }
109+ assertEqual (t , tc .expected , tc .value )
78110 })
79111 }
80112}
0 commit comments