Sfoglia il codice sorgente

Merge branch 'aj_ai' into debug

tuoling805 1 anno fa
parent
commit
dd34ad4793
6 ha cambiato i file con 526 aggiunte e 0 eliminazioni
  1. 314 0
      controllers/ai/ai.go
  2. 146 0
      models/aimod/ai.go
  3. 9 0
      models/db.go
  4. 6 0
      routers/router.go
  5. 41 0
      services/aiser/ai.go
  6. 10 0
      utils/config.go

+ 314 - 0
controllers/ai/ai.go

@@ -0,0 +1,314 @@
+package ai
+
+import (
+	"encoding/json"
+	"eta/eta_api/controllers"
+	"eta/eta_api/models"
+	"eta/eta_api/models/aimod"
+	"eta/eta_api/services/aiser"
+	"eta/eta_api/utils"
+	"fmt"
+	"strconv"
+	"time"
+)
+
+// AI
+type AiController struct {
+	controllers.BaseAuthController
+}
+
+// @Title 聊天接口
+// @Description 聊天接口
+// @Param	request	body aimod.ChatReq true "type json string"
+// @Success 200 {object} response.ListResp
+// @router /chat [post]
+func (this *AiController) List() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	var req aimod.ChatReq
+	err := json.Unmarshal(this.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+
+	if req.Ask == "" {
+		br.Msg = "请输入提问内容!"
+		br.ErrMsg = "请输入提问内容"
+		return
+	}
+
+	if utils.Re == nil {
+		key := "CACHE_CHAT_" + strconv.Itoa(this.SysUser.AdminId)
+		cacheVal, err := utils.Rc.RedisInt(key)
+		fmt.Println("RedisString:", cacheVal, "err:", err)
+		if err != nil && err.Error() != "redigo: nil returned" {
+			br.Msg = "获取数据失败!"
+			br.ErrMsg = "获取数据失败,Err:" + err.Error()
+			return
+		}
+		putVal := 0
+		if cacheVal <= 0 {
+			putVal = utils.AiChatLimit
+		} else {
+			putVal = cacheVal - 1
+		}
+
+		if putVal <= 0 {
+			br.Msg = "您今日50次问答已达上限,请明天再来!"
+			br.ErrMsg = "您今日50次问答已达上限,请明天再来!"
+			return
+		}
+		lastSecond := utils.GetTodayLastSecond()
+		utils.Rc.Put(key, putVal, lastSecond)
+	}
+
+	//根据提问,获取信息
+	askUuid := utils.MD5(req.Ask)
+	chatMode, err := aimod.GetAiChatByAsk(askUuid)
+	if err != nil && err.Error() != utils.ErrNoRow() {
+		br.Msg = "获取数据失败!"
+		br.ErrMsg = "获取数据失败,GetAiChatByAsk,Err:" + err.Error()
+		return
+	}
+	resp := new(aimod.ChatResp)
+	var answer string
+	//answerArr := []string{
+	//	"周度数据显示,成品油现货市场价格跟随原油下跌,但近期相对抗跌,裂解价差走扩。批零价差方面汽油收窄,柴油走扩",
+	//	"出口利润在原油下跌海外成品油矛盾更大的情况下汽柴油出口窗口完全关闭",
+	//	"汽油需求在经历五一假期的一段高峰后将回归平稳,总体没有明显矛盾,后期我们担心更多的还是柴油。"}
+	if chatMode != nil && chatMode.Answer != "" {
+		answer = chatMode.Answer
+	} else {
+		answer, _ = aiser.ChatAutoMsg(req.Ask)
+	}
+	resp.Ask = req.Ask
+	resp.Answer = answer
+
+	if req.AiChatTopicId <= 0 { //新增
+		topic := new(aimod.AiChatTopic)
+		topic.TopicName = req.Ask
+		topic.SysUserId = this.SysUser.AdminId
+		topic.SysUserRealName = this.SysUser.RealName
+		topic.CreateTime = time.Now()
+		topic.ModifyTime = time.Now()
+		topicId, err := aimod.AddAiChatTopic(topic)
+		if err != nil {
+			br.Msg = "获取数据失败!"
+			br.ErrMsg = "生成话题失败,Err:" + err.Error()
+			return
+		}
+		resp.AiChatTopicId = int(topicId)
+		chatItem := new(aimod.AiChat)
+		chatItem.AiChatTopicId = resp.AiChatTopicId
+		chatItem.Ask = req.Ask
+		chatItem.AskUuid = utils.MD5(req.Ask)
+		chatItem.Answer = answer
+		chatItem.Model = "gpt-4-1106-preview"
+		chatItem.SysUserId = this.SysUser.AdminId
+		chatItem.SysUserRealName = this.SysUser.RealName
+		chatItem.CreateTime = time.Now()
+		chatItem.ModifyTime = time.Now()
+		_, err = aimod.AddAiChat(chatItem)
+		if err != nil {
+			br.Msg = "获取数据失败!"
+			br.ErrMsg = "生成话题记录失败,Err:" + err.Error()
+			return
+		}
+	} else {
+		resp.AiChatTopicId = req.AiChatTopicId
+		chatItem := new(aimod.AiChat)
+		chatItem.AiChatTopicId = resp.AiChatTopicId
+		chatItem.Ask = req.Ask
+		chatItem.AskUuid = utils.MD5(req.Ask)
+		chatItem.Answer = answer
+		chatItem.Model = "gpt-4-1106-preview"
+		chatItem.SysUserId = this.SysUser.AdminId
+		chatItem.SysUserRealName = this.SysUser.RealName
+		chatItem.CreateTime = time.Now()
+		chatItem.ModifyTime = time.Now()
+		_, err = aimod.AddAiChat(chatItem)
+		if err != nil {
+			br.Msg = "获取数据失败!"
+			br.ErrMsg = "生成话题记录失败,Err:" + err.Error()
+			return
+		}
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+	br.Data = resp
+}
+
+// @Title 获取话题列表
+// @Description 获取话题列表接口
+// @Success 200 {object} aimod.AiChatTopicListResp
+// @router /topic/list [get]
+func (this *AiController) TopicList() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	sysUser := this.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+	list, err := aimod.GetAiChatTopicList(sysUser.AdminId)
+	if err != nil {
+		br.Msg = "获取数据失败!"
+		br.ErrMsg = "获取主题记录信息失败,Err:" + err.Error()
+		return
+	}
+	resp := new(aimod.AiChatTopicListResp)
+	resp.List = list
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+	br.Data = resp
+}
+
+// @Title 获取话题详情
+// @Description 获取话题详情接口
+// @Param   AiChatTopicId   query   int  true       "主题id"
+// @Success 200 {object} aimod.AiChatDetailResp
+// @router /topic/detail [get]
+func (this *AiController) TopicDetail() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	sysUser := this.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+
+	aiChatTopicId, _ := this.GetInt("AiChatTopicId")
+	list, err := aimod.GetAiChatList(aiChatTopicId)
+	if err != nil {
+		br.Msg = "获取数据失败!"
+		br.ErrMsg = "获取主题记录信息失败,Err:" + err.Error()
+		return
+	}
+	resp := new(aimod.AiChatDetailResp)
+	resp.List = list
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "获取成功"
+	br.Data = resp
+}
+
+// @Title 删除话题接口
+// @Description 删除话题接口
+// @Param	request	body aimod.TopicDeleteReq true "type json string"
+// @Success Ret=200 删除成功
+// @router /topic/delete [post]
+func (this *AiController) TopicDelete() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	sysUser := this.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+
+	var req aimod.TopicDeleteReq
+	err := json.Unmarshal(this.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+
+	if req.AiChatTopicId <= 0 {
+		br.Msg = "参数错误!"
+		br.ErrMsg = "参数错误!AiChatTopicId:" + strconv.Itoa(req.AiChatTopicId)
+		return
+	}
+	err = aimod.DeleteTopic(req.AiChatTopicId)
+	if err != nil {
+		br.Msg = "删除失败!"
+		br.ErrMsg = "删除失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "删除成功"
+	br.IsAddLog = true
+}
+
+// @Title 编辑话题接口
+// @Description 编辑话题接口
+// @Param	request	body aimod.TopicEditReq true "type json string"
+// @Success Ret=200 编辑成功
+// @router /topic/edit [post]
+func (this *AiController) TopicEdit() {
+	br := new(models.BaseResponse).Init()
+	defer func() {
+		this.Data["json"] = br
+		this.ServeJSON()
+	}()
+
+	sysUser := this.SysUser
+	if sysUser == nil {
+		br.Msg = "请登录"
+		br.ErrMsg = "请登录,SysUser Is Empty"
+		br.Ret = 408
+		return
+	}
+
+	var req aimod.TopicEditReq
+	err := json.Unmarshal(this.Ctx.Input.RequestBody, &req)
+	if err != nil {
+		br.Msg = "参数解析异常!"
+		br.ErrMsg = "参数解析失败,Err:" + err.Error()
+		return
+	}
+
+	if req.AiChatTopicId <= 0 {
+		br.Msg = "参数错误!"
+		br.ErrMsg = "参数错误!AiChatTopicId:" + strconv.Itoa(req.AiChatTopicId)
+		return
+	}
+	topic, err := aimod.GetAiChatTopicByTopicName(req.TopicName)
+	if err != nil && err.Error() != utils.ErrNoRow() {
+		br.Msg = "编辑失败!"
+		br.ErrMsg = "获取数据失败!Err:" + err.Error()
+		return
+	}
+	if topic != nil && topic.AiChatTopicId != req.AiChatTopicId {
+		br.Msg = "话题名称已存在,请重新修改!"
+		return
+	}
+
+	err = aimod.EditTopic(req.AiChatTopicId, req.TopicName)
+	if err != nil {
+		br.Msg = "编辑失败!"
+		br.ErrMsg = "编辑失败,Err:" + err.Error()
+		return
+	}
+	br.Ret = 200
+	br.Success = true
+	br.Msg = "编辑成功"
+	br.IsAddLog = true
+}

+ 146 - 0
models/aimod/ai.go

@@ -0,0 +1,146 @@
+package aimod
+
+import (
+	"github.com/beego/beego/v2/client/orm"
+	"time"
+)
+
+type AiChatTopic struct {
+	AiChatTopicId   int `orm:"column(ai_chat_topic_id);pk"`
+	TopicName       string
+	SysUserId       int
+	SysUserRealName string
+	CreateTime      time.Time
+	ModifyTime      time.Time
+}
+
+type AiChat struct {
+	AiChatId        int `orm:"column(ai_chat_id);pk"`
+	AiChatTopicId   int
+	Ask             string
+	AskUuid         string
+	Answer          string
+	Model           string
+	SysUserId       int
+	SysUserRealName string
+	CreateTime      time.Time
+	ModifyTime      time.Time
+}
+
+type ChatReq struct {
+	AiChatTopicId int    `description:"主题id"`
+	Ask           string `description:"提问"`
+}
+
+func GetAiChatByAsk(askUuid string) (item *AiChat, err error) {
+	sql := `SELECT * FROM ai_chat WHERE ask_uuid=?`
+	o := orm.NewOrmUsingDB("ai")
+	err = o.Raw(sql, askUuid).QueryRow(&item)
+	return
+}
+
+type ChatResp struct {
+	AiChatTopicId int    `description:"主题id"`
+	Ask           string `description:"提问"`
+	Answer        string `description:"回答"`
+	Model         string
+}
+
+// AddAiChatTopic 新增主题
+func AddAiChatTopic(item *AiChatTopic) (lastId int64, err error) {
+	o := orm.NewOrmUsingDB("ai")
+	lastId, err = o.Insert(item)
+	return
+}
+
+// AddAiChat 新增聊天
+func AddAiChat(item *AiChat) (lastId int64, err error) {
+	o := orm.NewOrmUsingDB("ai")
+	lastId, err = o.Insert(item)
+	return
+}
+
+type AiChatTopicView struct {
+	AiChatTopicId int    `description:"主题id"`
+	TopicName     string `description:"主题名称"`
+	CreateTime    string `description:"创建时间"`
+	ModifyTime    string `description:"修改时间"`
+}
+
+func GetAiChatTopicList(sysUserId int) (item []*AiChatTopicView, err error) {
+	sql := ` SELECT * FROM ai_chat_topic WHERE sys_user_id=? ORDER BY create_time DESC `
+	o := orm.NewOrmUsingDB("ai")
+	_, err = o.Raw(sql, sysUserId).QueryRows(&item)
+	return
+}
+
+type AiChatTopicListResp struct {
+	List []*AiChatTopicView
+}
+
+type AiChatView struct {
+	AiChatId      int    `description:"记录id"`
+	AiChatTopicId int    `description:"主题id"`
+	Ask           string `description:"提问"`
+	Answer        string `description:"答案"`
+	Model         string
+	CreateTime    string `description:"创建时间"`
+	ModifyTime    string `description:"修改时间"`
+}
+
+func GetAiChatList(aiChatTopicId int) (item []*AiChatView, err error) {
+	sql := ` SELECT * FROM ai_chat WHERE ai_chat_topic_id=? ORDER BY create_time ASC `
+	o := orm.NewOrmUsingDB("ai")
+	_, err = o.Raw(sql, aiChatTopicId).QueryRows(&item)
+	return
+}
+
+type AiChatDetailResp struct {
+	List []*AiChatView
+}
+
+type TopicDeleteReq struct {
+	AiChatTopicId int `description:"主题id"`
+}
+
+func DeleteTopic(topicId int) (err error) {
+	o := orm.NewOrmUsingDB("ai")
+	tx, err := o.Begin()
+	defer func() {
+		if err != nil {
+			tx.Rollback()
+		} else {
+			tx.Commit()
+		}
+	}()
+	sql := ` DELETE FROM ai_chat_topic WHERE  ai_chat_topic_id=? `
+	_, err = tx.Raw(sql, topicId).Exec()
+	if err != nil {
+		return err
+	}
+	sql = ` DELETE FROM ai_chat WHERE  ai_chat_topic_id=? `
+	_, err = tx.Raw(sql, topicId).Exec()
+	if err != nil {
+		return err
+	}
+	return err
+}
+
+type TopicEditReq struct {
+	AiChatTopicId int    `description:"主题id"`
+	TopicName     string `description:"主题名称"`
+}
+
+func GetAiChatTopicByTopicName(topicName string) (item *AiChatTopicView, err error) {
+	sql := ` SELECT * FROM ai_chat_topic WHERE topic_name=? `
+	o := orm.NewOrmUsingDB("ai")
+	err = o.Raw(sql, topicName).QueryRow(&item)
+	return
+}
+
+func EditTopic(topicId int, topicName string) (err error) {
+	o := orm.NewOrmUsingDB("ai")
+	sql := ` UPDATE ai_chat_topic SET topic_name=? WHERE  ai_chat_topic_id=? `
+	_, err = o.Raw(sql, topicName, topicId).Exec()
+	return err
+}

+ 9 - 0
models/db.go

@@ -70,6 +70,15 @@ func init() {
 		weeklyDb.SetConnMaxLifetime(10 * time.Minute)
 	}
 
+	if utils.MYSQL_AI_URL != "" {
+		_ = orm.RegisterDataBase("ai", "mysql", utils.MYSQL_AI_URL)
+		orm.SetMaxIdleConns("ai", 50)
+		orm.SetMaxOpenConns("ai", 100)
+
+		weeklyDb, _ := orm.GetDB("ai")
+		weeklyDb.SetConnMaxLifetime(10 * time.Minute)
+	}
+
 	orm.Debug = true
 	orm.DebugLog = orm.NewLog(utils.Binlog)
 

+ 6 - 0
routers/router.go

@@ -9,6 +9,7 @@ package routers
 
 import (
 	"eta/eta_api/controllers"
+	"eta/eta_api/controllers/ai"
 	"eta/eta_api/controllers/data_manage"
 	"eta/eta_api/controllers/data_manage/correlation"
 	"eta/eta_api/controllers/data_manage/cross_variety"
@@ -334,6 +335,11 @@ func init() {
 				&report_approve.ReportApproveFlowController{},
 			),
 		),
+		web.NSNamespace("/ai",
+			web.NSInclude(
+				&ai.AiController{},
+			),
+		),
 	)
 	web.AddNamespace(ns)
 }

+ 41 - 0
services/aiser/ai.go

@@ -0,0 +1,41 @@
+package aiser
+
+import (
+	"encoding/json"
+	"eta/eta_api/utils"
+	"github.com/rdlucklib/rdluck_tools/http"
+)
+
+func ChatAutoMsg(prompt string) (result string, err error) {
+	chatUrl := utils.EtaAiUrl + `/chat/auto_msg`
+	param := make(map[string]interface{})
+	param["Prompt"] = prompt
+	postData, err := json.Marshal(param)
+	if err != nil {
+		return result, err
+	}
+
+	utils.FileLog.Info("postData:" + string(postData))
+	body, err := http.HttpPost(chatUrl, string(postData), "application/json; charset=utf-8")
+	if err != nil {
+		return result, err
+	}
+	utils.FileLog.Info("result:" + string(body))
+
+	resp := new(ChatAutoMsgResp)
+	err = json.Unmarshal(body, &resp)
+	if err != nil {
+		return result, err
+	}
+	if resp.Ret != 200 {
+		return resp.Msg, nil
+	}
+	result = resp.Data
+	return result, nil
+}
+
+type ChatAutoMsgResp struct {
+	Ret  int
+	Data string
+	Msg  string
+}

+ 10 - 0
utils/config.go

@@ -17,6 +17,7 @@ var (
 	MYSQL_URL_GL     string
 	MYSQL_LOG_URL    string
 	MYSQL_WEEKLY_URL string //用户主库
+	MYSQL_AI_URL     string //ETA-AI 数据库
 
 	REDIS_CACHE string       //缓存地址
 	Rc          *cache.Cache //redis缓存
@@ -204,6 +205,9 @@ var (
 // PythonUrlReport2Img 生成长图服务地址
 var PythonUrlReport2Img string
 
+// ETA-AI服务
+var EtaAiUrl string
+
 func init() {
 	tmpRunMode, err := web.AppConfig.String("run_mode")
 	if err != nil {
@@ -243,6 +247,11 @@ func init() {
 	// 用户主库
 	MYSQL_WEEKLY_URL = config["mysql_url_weekly"]
 
+	// 用户主库
+	MYSQL_WEEKLY_URL = config["mysql_url_weekly"]
+	//ETA-AI
+	MYSQL_AI_URL = config["mysql_url_ai"]
+
 	REDIS_CACHE = config["beego_cache"]
 	if len(REDIS_CACHE) <= 0 {
 		panic(any("redis链接参数没有配置"))
@@ -460,6 +469,7 @@ func init() {
 	// 生成长图服务地址
 	PythonUrlReport2Img = config["python_url_report2img"]
 
+	EtaAiUrl = config["eta_ai_url"]
 	// 初始化ES
 	initEs()
 }