Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \
-e SENSITIVE_WORDS="aa,bb" \
-e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \
-e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \
-e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \
-e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/)
,觉得不错你可以来波素质三连." \
--restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
Expand Down Expand Up @@ -505,6 +506,15 @@ azure_resource_name: "xxxx"
azure_deployment_name: "xxxx"
azure_openai_token: "xxxx"

# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息
# 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
# 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力
credentials:
-
client_id: "put-your-client-id-here"
client_secret: "put-your-client-secret-here"

```

## 常见问题
Expand Down
7 changes: 7 additions & 0 deletions config.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,10 @@ azure_resource_name: "xxxx"
azure_deployment_name: "xxxx"
azure_openai_token: "xxxx"

# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息
# 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
#credentials:
# -
# client_id: "put-your-client-id-here"
# client_secret: "put-your-client-secret-here"
19 changes: 19 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import (
"gopkg.in/yaml.v2"
)

type Credential struct {
ClientID string `yaml:"client_id"`
ClientSecret string `yaml:"client_secret"`
}

// Configuration 项目配置
type Configuration struct {
// 日志级别,info或者debug
Expand Down Expand Up @@ -62,6 +67,8 @@ type Configuration struct {
AzureResourceName string `yaml:"azure_resource_name"`
AzureDeploymentName string `yaml:"azure_deployment_name"`
AzureOpenAIToken string `yaml:"azure_openai_token"`
// 钉钉应用鉴权凭据
Credentials []Credential `yaml:"credentials"`
}

var config *Configuration
Expand Down Expand Up @@ -190,6 +197,18 @@ func LoadConfig() *Configuration {
if azureOpenaiToken != "" {
config.AzureOpenAIToken = azureOpenaiToken
}
credentials := os.Getenv("DINGTALK_CREDENTIALS")
if credentials != "" {
if config.Credentials == nil {
config.Credentials = []Credential{}
}
for _, idSecret := range strings.Split(credentials, ",") {
items := strings.SplitN(idSecret, ":", 2)
if len(items) == 2 {
config.Credentials = append(config.Credentials, Credential{ClientID: items[0], ClientSecret: items[1]})
}
}
}

})

Expand Down
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ services:
AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai"
AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai"
AZURE_OPENAI_TOKEN: "" # Azure token
DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2"
HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义
volumes:
- ./data:/app/data
Expand Down
10 changes: 9 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ func Start() {
return
}
// 先校验回调是否合法
clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
if !checkOk {
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
return
}
// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
c.Set(public.DingTalkClientIdKeyName, clientId)
// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials
if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
return
Expand Down Expand Up @@ -114,7 +122,7 @@ func Start() {
// 除去帮助之外的逻辑分流在这里处理
switch {
case strings.HasPrefix(msgObj.Text.Content, "#图片"):
err := process.ImageGenerate(&msgObj)
err := process.ImageGenerate(c, &msgObj)
if err != nil {
logger.Warning(fmt.Errorf("process request: %v", err))
return
Expand Down
20 changes: 17 additions & 3 deletions pkg/chatgpt/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package chatgpt

import (
"bytes"
"context"
"encoding/base64"
"encoding/gob"
"errors"
"fmt"
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
"github.com/pandodao/tokenizer-go"
"image/png"
"os"
Expand Down Expand Up @@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
return resp.Choices[0].Text, nil
}
}
func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
model := public.Config.Model
if model == openai.GPT3Dot5Turbo0301 ||
model == openai.GPT3Dot5Turbo ||
Expand Down Expand Up @@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
}

imageName := time.Now().Format("20060102-150405") + ".png"
clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
if client != nil {
mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
}

err = os.MkdirAll("data/images", 0755)
if err != nil {
return "", err
Expand All @@ -260,8 +271,11 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
if err := png.Encode(file, imgData); err != nil {
return "", err
}

return public.Config.ServiceURL + "/images/" + imageName, nil
if uploadErr == nil {
return mediaResult.MediaID, nil
} else {
return public.Config.ServiceURL + "/images/" + imageName, nil
}
}
return "", nil
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/chatgpt/export.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package chatgpt

import (
"context"
"time"

"github.com/avast/retry-go"
Expand Down Expand Up @@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error
}

// ImageQa 生成图片
func ImageQa(question, userId string) (answer string, err error) {
func ImageQa(ctx context.Context, question, userId string) (answer string, err error) {
chat := New(userId)
defer chat.Close()
// 定义一个重试策略
Expand All @@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) {
// 使用重试策略进行重试
err = retry.Do(
func() error {
answer, err = chat.GenreateImage(question)
answer, err = chat.GenreateImage(ctx, question)
if err != nil {
return err
}
Expand Down
213 changes: 213 additions & 0 deletions pkg/dingbot/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package dingbot

import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/eryajf/chatgpt-dingtalk/config"
"io"
"mime/multipart"
"net/http"
url2 "net/url"
"sync"
"time"
)

// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
const (
MediaTypeImage string = "image"
MediaTypeVoice string = "voice"
MediaTypeVideo string = "video"
MediaTypeFile string = "file"
)
const (
MimeTypeImagePng string = "image/png"
)

type MediaUploadResult struct {
ErrorCode int64 `json:"errcode"`
ErrorMessage string `json:"errmsg"`
MediaID string `json:"media_id"`
CreatedAt int64 `json:"created_at"`
Type string `json:"type"`
}

type OAuthTokenResult struct {
ErrorCode int `json:"errcode"`
ErrorMessage string `json:"errmsg"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
}

type DingTalkClientInterface interface {
GetAccessToken() (string, error)
UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error)
}

type DingTalkClientManagerInterface interface {
GetClientByOAuthClientID(clientId string) DingTalkClientInterface
}

type DingTalkClient struct {
Credential config.Credential
AccessToken string
expireAt int64
mutex sync.Mutex
}

type DingTalkClientManager struct {
Credentials []config.Credential
Clients map[string]*DingTalkClient
mutex sync.Mutex
}

func NewDingTalkClient(credential config.Credential) *DingTalkClient {
return &DingTalkClient{
Credential: credential,
}
}

func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager {
clients := make(map[string]*DingTalkClient)

if conf != nil && conf.Credentials != nil {
for _, credential := range conf.Credentials {
clients[credential.ClientID] = NewDingTalkClient(credential)
}
}
return &DingTalkClientManager{
Credentials: conf.Credentials,
Clients: clients,
}
}

func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface {
m.mutex.Lock()
defer m.mutex.Unlock()
if client, ok := m.Clients[clientId]; ok {
return client
}
return nil
}

func (c *DingTalkClient) GetAccessToken() (string, error) {
accessToken := ""
{
// 先查询缓存
c.mutex.Lock()
now := time.Now().Unix()
if c.expireAt > 0 && c.AccessToken != "" && (now+60) < c.expireAt {
// 预留一分钟有效期避免在Token过期的临界点调用接口出现401错误
accessToken = c.AccessToken
}
c.mutex.Unlock()
}
if accessToken != "" {
return accessToken, nil
}

tokenResult, err := c.getAccessTokenFromDingTalk()
if err != nil {
return "", err
}

{
// 更新缓存
c.mutex.Lock()
c.AccessToken = tokenResult.AccessToken
c.expireAt = time.Now().Unix() + int64(tokenResult.ExpiresIn)
c.mutex.Unlock()
}
return tokenResult.AccessToken, nil
}

func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) {
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
accessToken, err := c.GetAccessToken()
if err != nil {
return nil, err
}
if len(accessToken) == 0 {
return nil, errors.New("empty access token")
}
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile("media", filename)
if err != nil {
return nil, err
}
_, err = part.Write(content)
writer.WriteField("type", mediaType)
err = writer.Close()
if err != nil {
return nil, err
}

// Create a new HTTP request to upload the media file
url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken))
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", writer.FormDataContentType())

// Send the HTTP request and parse the response
client := &http.Client{
Timeout: time.Second * 60,
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()

// Parse the response body as JSON and extract the media ID
media := &MediaUploadResult{}
bodyBytes, err := io.ReadAll(res.Body)
json.Unmarshal(bodyBytes, media)
if err != nil {
return nil, err
}
if media.ErrorCode != 0 {
return nil, errors.New(media.ErrorMessage)
}
return media, nil
}

func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) {
// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
apiUrl := "https://oapi.dingtalk.com/gettoken"
queryParams := url2.Values{}
queryParams.Add("appkey", c.Credential.ClientID)
queryParams.Add("appsecret", c.Credential.ClientSecret)

// Create a new HTTP request to get the AccessToken
req, err := http.NewRequest("GET", apiUrl+"?"+queryParams.Encode(), nil)
if err != nil {
return nil, err
}

// Send the HTTP request and parse the response body as JSON
client := http.Client{
Timeout: time.Second * 60,
}
res, err := client.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
tokenResult := &OAuthTokenResult{}
err = json.Unmarshal(body, tokenResult)
if err != nil {
return nil, err
}
if tokenResult.ErrorCode != 0 {
return nil, errors.New(tokenResult.ErrorMessage)
}
return tokenResult, nil
}
Loading