From 9c0393f199d6a099b3b8281faeeecfc5d44873f7 Mon Sep 17 00:00:00 2001 From: beyond Date: Wed, 15 Nov 2023 22:37:25 +0800 Subject: [PATCH 1/2] add CreateCompletionStreamByCustom --- stream.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/stream.go b/stream.go index b277f3c29..806d92e45 100644 --- a/stream.go +++ b/stream.go @@ -47,3 +47,35 @@ func (c *Client) CreateCompletionStream( } return } + +// CreateCompletionStreamByCustom — more parameters supported +func (c *Client) CreateCompletionStreamByCustom( + ctx context.Context, + model, prompt string, + request any, +) (stream *CompletionStream, err error) { + urlSuffix := "/completions" + if !checkEndpointSupportsModel(urlSuffix, model) { + err = ErrCompletionUnsupportedModel + return + } + + if !checkPromptType(prompt) { + err = ErrCompletionRequestPromptTypeNotSupported + return + } + + req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, model), withBody(request)) + if err != nil { + return nil, err + } + + resp, err := sendRequestStream[CompletionResponse](c, req) + if err != nil { + return + } + stream = &CompletionStream{ + streamReader: resp, + } + return +} \ No newline at end of file From 7d856ea96da4ee87bfd0ba3a5e6739c83e82c921 Mon Sep 17 00:00:00 2001 From: beyond Date: Thu, 16 Nov 2023 12:23:38 +0800 Subject: [PATCH 2/2] update paramter --- completion.go | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++ stream.go | 9 +++--- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/completion.go b/completion.go index 2709c8b03..8bd5f9580 100644 --- a/completion.go +++ b/completion.go @@ -136,6 +136,58 @@ type CompletionRequest struct { User string `json:"user,omitempty"` } +type CompletionCustomRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + BestOf int `json:"best_of,omitempty"` + Echo bool `json:"echo,omitempty"` + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + Logprobs int `json:"logprobs,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float32 `json:"presence_penalty,omitempty"` + Stop []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + Suffix string `json:"suffix,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopP float32 `json:"top_p,omitempty"` + User string `json:"user,omitempty"` + Preset string `json:"preset,omitempty"` + MinP int `json:"min_p,omitempty"` + TopK int `json:"top_k,omitempty"` + RepetitionPenalty float32 `json:"repetition_penalty,omitempty"` + RepetitionPenaltyRange int `json:"repetition_penalty_range,omitempty"` + TypicalP int `json:"typical_p,omitempty"` + Tfs int `json:"tfs,omitempty"` + TopA int `json:"top_a,omitempty"` + EpsilonCutoff int `json:"epsilon_cutoff,omitempty"` + EtaCutoff int `json:"eta_cutoff,omitempty"` + GuidanceScale int `json:"guidance_scale,omitempty"` + NegativePrompt string `json:"negative_prompt,omitempty"` + PenaltyAlpha int `json:"penalty_alpha,omitempty"` + MirostatMode int `json:"mirostat_mode,omitempty"` + MirostatTau int `json:"mirostat_tau,omitempty"` + MirostatEta float32 `json:"mirostat_eta,omitempty"` + TemperatureLast bool `json:"temperature_last,omitempty"` + DoSample bool `json:"do_sample,omitempty"` + Seed int `json:"seed,omitempty"` + LogitBias map[string]int `json:"logit_bias,omitempty"` + EncoderRepetitionPenalty int `json:"encoder_repetition_penalty,omitempty"` + NoRepeatNgramSize int `json:"no_repeat_ngram_size,omitempty"` + MinLength int `json:"min_length,omitempty"` + NumBeams int `json:"num_beams,omitempty"` + LengthPenalty int `json:"length_penalty,omitempty"` + EarlyStopping bool `json:"early_stopping,omitempty"` + TruncationLength int `json:"truncation_length,omitempty"` + MaxTokensSecond int `json:"max_tokens_second,omitempty"` + CustomTokenBans string `json:"custom_token_bans,omitempty"` + AutoMaxNewTokens bool `json:"auto_max_new_tokens,omitempty"` + BanEosToken bool `json:"ban_eos_token,omitempty"` + AddBosToken bool `json:"add_bos_token,omitempty"` + SkipSpecialTokens bool `json:"skip_special_tokens,omitempty"` + GrammarString string `json:"grammar_string,omitempty"` +} + // CompletionChoice represents one of possible completions. type CompletionChoice struct { Text string `json:"text"` @@ -197,3 +249,32 @@ func (c *Client) CreateCompletion( err = c.sendRequest(req, &response) return } + +func (c *Client) CreateCompletionByCustom( + ctx context.Context, + request CompletionCustomRequest, +) (response CompletionResponse, err error) { + if request.Stream { + err = ErrCompletionStreamNotSupported + return + } + + urlSuffix := "/completions" + if !checkEndpointSupportsModel(urlSuffix, request.Model) { + err = ErrCompletionUnsupportedModel + return + } + + if !checkPromptType(request.Prompt) { + err = ErrCompletionRequestPromptTypeNotSupported + return + } + + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request)) + if err != nil { + return + } + + err = c.sendRequest(req, &response) + return +} \ No newline at end of file diff --git a/stream.go b/stream.go index 806d92e45..137ca4530 100644 --- a/stream.go +++ b/stream.go @@ -51,21 +51,20 @@ func (c *Client) CreateCompletionStream( // CreateCompletionStreamByCustom — more parameters supported func (c *Client) CreateCompletionStreamByCustom( ctx context.Context, - model, prompt string, - request any, + request CompletionCustomRequest, ) (stream *CompletionStream, err error) { urlSuffix := "/completions" - if !checkEndpointSupportsModel(urlSuffix, model) { + if !checkEndpointSupportsModel(urlSuffix, request.Model) { err = ErrCompletionUnsupportedModel return } - if !checkPromptType(prompt) { + if !checkPromptType(request.Prompt) { err = ErrCompletionRequestPromptTypeNotSupported return } - req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, model), withBody(request)) + req, err := c.newRequest(ctx, "POST", c.fullURL(urlSuffix, request.Model), withBody(request)) if err != nil { return nil, err }