Pārlūkot izejas kodu

Merge branch 'rag/4.0' into debug

# Conflicts:
#	utils/redis.go
#	utils/redis/cluster_redis.go
#	utils/redis/standalone_redis.go
Roc 3 dienas atpakaļ
vecāks
revīzija
70d61b0da5

+ 1 - 1
models/rag/ai_task.go

@@ -111,7 +111,7 @@ func (m *AiTask) GetListByCondition(field, condition string, pars []interface{},
 	if field == "" {
 		field = "*"
 	}
-	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by AiTask_id desc LIMIT ?,?`, field, m.TableName(), condition)
+	sqlStr := fmt.Sprintf(`SELECT %s FROM %s WHERE 1=1 %s order by ai_task_id desc LIMIT ?,?`, field, m.TableName(), condition)
 	pars = append(pars, startSize, pageSize)
 	err = global.DbMap[utils.DbNameAI].Raw(sqlStr, pars...).Find(&items).Error
 

+ 28 - 6
services/task.go

@@ -796,7 +796,9 @@ func HandleAiArticleAbstractLlmOp() {
 	}
 }
 
-var aiTaskIdMap = map[int]bool{}
+var aiTaskHandleIdMap = map[int]bool{}
+
+// todo 任务开始时间
 
 // handleAiArticleAbstractLlmOp
 // @Description: 处理AI库的报告摘要生成(批量任务)
@@ -822,6 +824,27 @@ func handleAiArticleAbstractLlmOp(b []byte) {
 		return
 	}
 
+	// 如果没有处理过该任务,那么就标记该任务开始
+	if _, ok := aiTaskHandleIdMap[item.AiTaskID]; !ok {
+		aiTaskObj := rag.AiTask{}
+		aiTaskInfo, tmpErr := aiTaskObj.GetByID(item.AiTaskID)
+		if tmpErr != nil {
+			err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
+			return
+		}
+		// 如果任务是初始化,那么就标记开始
+		if aiTaskInfo.Status == `init` {
+			aiTaskInfo.StartTime = time.Now()
+			aiTaskInfo.Status = `processing`
+			aiTaskInfo.UpdateTime = time.Now()
+			tmpErr = aiTaskInfo.Update([]string{`start_time`, "status", "update_time"})
+			if tmpErr != nil {
+				utils.FileLog.Error("标记任务开始状态失败, err: %s", tmpErr.Error())
+			}
+		}
+
+	}
+
 	// 处理完成后标记任务状态
 	defer func() {
 		// 修改任务状态
@@ -837,12 +860,12 @@ func handleAiArticleAbstractLlmOp(b []byte) {
 				err = fmt.Errorf("查找任务失败, err: %s", tmpErr.Error())
 				return
 			}
+			aiTaskInfo.EndTime = time.Now()
 			aiTaskInfo.Status = `done`
 			aiTaskInfo.UpdateTime = time.Now()
-			tmpErr = aiTaskInfo.Update([]string{"status", "update_time"})
+			tmpErr = aiTaskInfo.Update([]string{`end_time`, "status", "update_time"})
 			if tmpErr != nil {
-				err = fmt.Errorf("标记任务状态失败, err: %s", tmpErr.Error())
-				return
+				utils.FileLog.Error("标记任务状态失败, err: %s", tmpErr.Error())
 			}
 		}
 
@@ -867,8 +890,7 @@ func handleAiArticleAbstractLlmOp(b []byte) {
 		item.ModifyTime = time.Now()
 		tmpErr := item.Update([]string{"status", "remark", "modify_time"})
 		if tmpErr != nil {
-			err = fmt.Errorf("标记任务记录状态失败, err: %s", tmpErr.Error())
-			return
+			utils.FileLog.Error("标记任务记录状态失败, err: %s", tmpErr.Error())
 		}
 	}()
 

+ 39 - 26
utils/llm/eta_llm/eta_llm_client.go

@@ -341,48 +341,61 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
 	baseResp.Data = bodyBytes
 	return
 }
-func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
+func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}, closeLlmChan chan bool) {
 	contentChan = make(chan string, 10)
 	errChan = make(chan error, 10)
 	closeChan = make(chan struct{})
+	closeLlmChan = make(chan bool, 1)
 	go func() {
-		defer close(contentChan)
-		defer close(errChan)
-		defer close(closeChan)
+		defer func() {
+			close(contentChan)
+			close(errChan)
+			close(closeChan)
+			close(closeLlmChan)
+		}()
+
 		scanner := bufio.NewScanner(response.Body)
 		scanner.Split(bufio.ScanLines)
+
 		for scanner.Scan() {
-			line := scanner.Text()
-			if line == "" {
-				continue
-			}
-			// 忽略 "ping" 行
-			if strings.HasPrefix(line, ": ping") {
-				continue
-			}
-			// 去除 "data: " 前缀
-			if strings.HasPrefix(line, "data: ") {
-				line = strings.TrimPrefix(line, "data: ")
-			}
-			var chunk eta_llm_http.ChunkResponse
-			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
-				fmt.Println("解析错误的line:" + line)
-				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
+			select {
+			case <-closeLlmChan:
 				return
-			}
-			// 处理每个 chunk
-			if chunk.Choices != nil && len(chunk.Choices) > 0 {
-				for _, choice := range chunk.Choices {
-					if choice.Delta.Content != "" {
-						contentChan <- choice.Delta.Content
+			default:
+				line := scanner.Text()
+				if line == "" {
+					continue
+				}
+				// 忽略 "ping" 行
+				if strings.HasPrefix(line, ": ping") {
+					continue
+				}
+				// 去除 "data: " 前缀
+				if strings.HasPrefix(line, "data: ") {
+					line = strings.TrimPrefix(line, "data: ")
+				}
+				var chunk eta_llm_http.ChunkResponse
+				if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+					fmt.Println("解析错误的line:" + line)
+					errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
+					return
+				}
+				// 处理每个 chunk
+				if chunk.Choices != nil && len(chunk.Choices) > 0 {
+					for _, choice := range chunk.Choices {
+						if choice.Delta.Content != "" {
+							contentChan <- choice.Delta.Content
+						}
 					}
 				}
 			}
+
 		}
 		if err := scanner.Err(); err != nil {
 			errChan <- fmt.Errorf("读取响应体失败: %w", err)
 			return
 		}
+
 	}()
 	return
 }

+ 1 - 0
utils/redis.go

@@ -20,6 +20,7 @@ type RedisClient interface {
 	LPush(key string, val interface{}) error
 	Brpop(key string, callback func([]byte))
 	BrpopWithTimeout(key string, timeout time.Duration, callback func([]byte)) error
+	LLen(key string) (int64, error)
 	GetRedisTTL(key string) time.Duration
 	Incrby(key string, num int) (interface{}, error)
 	Do(commandName string, args ...interface{}) (reply interface{}, err error)

+ 12 - 0
utils/redis/cluster_redis.go

@@ -269,6 +269,18 @@ func (rc *ClusterRedisClient) BrpopWithTimeout(key string, timeout time.Duration
 	return
 }
 
+// LLen
+// @Description: 获取list中剩余的数据数
+// @author: Roc
+// @receiver rc
+// @datetime 2025-04-25 10:58:25
+// @param key string
+// @return int64
+// @return error
+func (rc *ClusterRedisClient) LLen(key string) (int64, error) {
+	return rc.redisClient.LLen(context.TODO(), key).Result()
+}
+
 // GetRedisTTL
 // @Description: 获取key的过期时间
 // @receiver rc

+ 12 - 0
utils/redis/standalone_redis.go

@@ -256,6 +256,18 @@ func (rc *StandaloneRedisClient) BrpopWithTimeout(key string, timeout time.Durat
 	return
 }
 
+// LLen
+// @Description: 获取list中剩余的数据数
+// @author: Roc
+// @receiver rc
+// @datetime 2025-04-25 10:58:25
+// @param key string
+// @return int64
+// @return error
+func (rc *StandaloneRedisClient) LLen(key string) (int64, error) {
+	return rc.redisClient.LLen(context.TODO(), key).Result()
+}
+
 // GetRedisTTL
 // @Description: 获取key的过期时间
 // @receiver rc

+ 24 - 19
utils/ws/session.go

@@ -12,22 +12,25 @@ import (
 
 // Session 会话结构
 type Session struct {
-	Id          string
-	UserId      int
-	Conn        *websocket.Conn
-	LastActive  time.Time
-	Latency     *LatencyMeasurer
-	History     []json.RawMessage
-	CloseChan   chan struct{}
-	MessageChan chan string
-	mu          sync.RWMutex
-	sessionOnce sync.Once
+	Id           string
+	UserId       int
+	Conn         *websocket.Conn
+	LastActive   time.Time
+	Latency      *LatencyMeasurer
+	History      []json.RawMessage
+	CloseChan    chan struct{}
+	MessageChan  chan string
+	mu           sync.RWMutex
+	sessionOnce  sync.Once
+	CloseLlmChan *chan bool
+	LLMStatus    int8 // llm提问状态,0:未提问,1:提问中,-1:暂停提问
 }
 
 type Message struct {
-	KbName string `json:"KbName"`
-	Query  string `json:"Query"`
-	ChatId int    `json:"ChatId"`
+	KbName      string `json:"KbName"`
+	Query       string `json:"Query"`
+	ChatId      int    `json:"ChatId"`
+	MessageType string `json:"MessageType"`
 	//LastTopics []json.RawMessage `json:"LastTopics"`
 }
 
@@ -48,13 +51,15 @@ func (s *Session) readPump() {
 		}
 		// 更新活跃时间
 		s.UpdateActivity()
+
 		// 处理消息
-		if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
-			//写应答
-			_ = s.writeWithTimeout("<think></think>")
-			_ = s.writeWithTimeout(err.Error())
-			_ = s.writeWithTimeout("<EOF/>")
-		}
+		//if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
+		//	//写应答
+		//	_ = s.writeWithTimeout("<think></think>")
+		//	_ = s.writeWithTimeout(err.Error())
+		//	_ = s.writeWithTimeout("<EOF/>")
+		//}
+		go manager.HandleMessage(s.UserId, s.Id, message)
 	}
 }
 

+ 58 - 21
utils/ws/session_manager.go

@@ -55,29 +55,53 @@ func Manager() *ConnectionManager {
 }
 
 // HandleMessage 消息处理核心逻辑
-func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
-
+func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) {
+	var err error
 	session, exists := manager.GetSession(sessionID)
 	if !exists {
-		return errors.New("session not found")
+		err = errors.New("session not found")
+		return
 	}
+
 	if strings.ToLower(string(message)) == "pong" {
 		session.UpdateActivity()
 		fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
-		return nil
-	}
-	if !Allow(userID, QA_LIMITER) {
-		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
-		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
-		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
-		return nil
+		return
 	}
+	defer func() {
+		if err != nil {
+			//写应答
+			_ = session.writeWithTimeout("<think></think>")
+			_ = session.writeWithTimeout(err.Error())
+			_ = session.writeWithTimeout("<EOF/>")
+		}
+	}()
 	var userMessage Message
-	err := json.Unmarshal(message, &userMessage)
+	err = json.Unmarshal(message, &userMessage)
 	if err != nil {
 		utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
 		fmt.Printf("消息格式错误:%s", string(message))
-		return errors.New("消息格式错误:" + err.Error())
+		err = errors.New("消息格式错误:" + err.Error())
+		return
+	}
+
+	if userMessage.MessageType == `stop` {
+		if session.LLMStatus == 1 {
+			// 标记llm提问状态:暂停提问
+			session.LLMStatus = -1
+		}
+		if session.CloseLlmChan != nil {
+			*session.CloseLlmChan <- true
+		}
+		return
+	}
+
+	// 限流
+	if !Allow(userID, QA_LIMITER) {
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
+		return
 	}
 	// 处理业务逻辑
 	//session.History = append(session.History, userMessage.LastTopics...)
@@ -104,29 +128,37 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	}()
 	if resp == nil {
 		utils.FileLog.Error("知识库问答失败: 无应答")
-		return errors.New("知识库问答失败: 无应答")
+		err = errors.New("知识库问答失败: 无应答")
+		return
 	}
 	if err != nil {
 		utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
 		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
-		return err
+		return
 	}
 
 	if resp.StatusCode != http.StatusOK {
 		utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
 		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
-		return err
+		return
 	}
+
 	// 解析流式响应
-	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
+	contentChan, errChan, closeChan, closeLlmChan := eta_llm.ParseStreamResponse(resp)
+	session.CloseLlmChan = &closeLlmChan
+	// 标记llm提问状态:提问中
+	session.LLMStatus = 1
 	emptyContent := true
 	// 处理流式数据并发送到 WebSocket
 	for {
 		select {
 		case content, ok := <-contentChan:
-			if !ok {
-				err = errors.New("未知的错误异常")
-				return err
+			if !ok && session.LLMStatus != -1 {
+				err = errors.New("未知的内容错误异常")
+
+				// 标记llm提问状态:未提问
+				session.LLMStatus = 0
+				return
 			}
 			session.UpdateActivity()
 			if emptyContent {
@@ -140,15 +172,20 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 			} else {
 				err = errors.New(chanErr.Error())
 			}
+			// 标记llm提问状态:未提问
+			session.LLMStatus = 0
 			// 发送错误消息到 WebSocket
-			return err
+			return
 		case <-closeChan:
 			if emptyContent {
 				_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
 				_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("暂时找不到答案"))
 			}
 			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
-			return nil
+			// 标记llm提问状态:未提问
+			session.LLMStatus = 0
+
+			return
 		}
 	}
 	// 更新最后活跃时间