Browse Source

fix:问答结束

Roc 3 ngày trước cách đây
mục cha
commit
3661aac69a
3 tập tin đã thay đổi với 121 bổ sung66 xóa
  1. 39 26
      utils/llm/eta_llm/eta_llm_client.go
  2. 24 19
      utils/ws/session.go
  3. 58 21
      utils/ws/session_manager.go

+ 39 - 26
utils/llm/eta_llm/eta_llm_client.go

@@ -341,48 +341,61 @@ func parseResponse(response *http.Response) (baseResp eta_llm_http.BaseResponse,
 	baseResp.Data = bodyBytes
 	return
 }
-func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}) {
+func ParseStreamResponse(response *http.Response) (contentChan chan string, errChan chan error, closeChan chan struct{}, closeLlmChan chan bool) {
 	contentChan = make(chan string, 10)
 	errChan = make(chan error, 10)
 	closeChan = make(chan struct{})
+	closeLlmChan = make(chan bool, 1)
 	go func() {
-		defer close(contentChan)
-		defer close(errChan)
-		defer close(closeChan)
+		defer func() {
+			close(contentChan)
+			close(errChan)
+			close(closeChan)
+			close(closeLlmChan)
+		}()
+
 		scanner := bufio.NewScanner(response.Body)
 		scanner.Split(bufio.ScanLines)
+
 		for scanner.Scan() {
-			line := scanner.Text()
-			if line == "" {
-				continue
-			}
-			// 忽略 "ping" 行
-			if strings.HasPrefix(line, ": ping") {
-				continue
-			}
-			// 去除 "data: " 前缀
-			if strings.HasPrefix(line, "data: ") {
-				line = strings.TrimPrefix(line, "data: ")
-			}
-			var chunk eta_llm_http.ChunkResponse
-			if err := json.Unmarshal([]byte(line), &chunk); err != nil {
-				fmt.Println("解析错误的line:" + line)
-				errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
+			select {
+			case <-closeLlmChan:
 				return
-			}
-			// 处理每个 chunk
-			if chunk.Choices != nil && len(chunk.Choices) > 0 {
-				for _, choice := range chunk.Choices {
-					if choice.Delta.Content != "" {
-						contentChan <- choice.Delta.Content
+			default:
+				line := scanner.Text()
+				if line == "" {
+					continue
+				}
+				// 忽略 "ping" 行
+				if strings.HasPrefix(line, ": ping") {
+					continue
+				}
+				// 去除 "data: " 前缀
+				if strings.HasPrefix(line, "data: ") {
+					line = strings.TrimPrefix(line, "data: ")
+				}
+				var chunk eta_llm_http.ChunkResponse
+				if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+					fmt.Println("解析错误的line:" + line)
+					errChan <- fmt.Errorf("解析 JSON 块失败: %w", err)
+					return
+				}
+				// 处理每个 chunk
+				if chunk.Choices != nil && len(chunk.Choices) > 0 {
+					for _, choice := range chunk.Choices {
+						if choice.Delta.Content != "" {
+							contentChan <- choice.Delta.Content
+						}
 					}
 				}
 			}
+
 		}
 		if err := scanner.Err(); err != nil {
 			errChan <- fmt.Errorf("读取响应体失败: %w", err)
 			return
 		}
+
 	}()
 	return
 }

+ 24 - 19
utils/ws/session.go

@@ -12,22 +12,25 @@ import (
 
 // Session 会话结构
 type Session struct {
-	Id          string
-	UserId      int
-	Conn        *websocket.Conn
-	LastActive  time.Time
-	Latency     *LatencyMeasurer
-	History     []json.RawMessage
-	CloseChan   chan struct{}
-	MessageChan chan string
-	mu          sync.RWMutex
-	sessionOnce sync.Once
+	Id           string
+	UserId       int
+	Conn         *websocket.Conn
+	LastActive   time.Time
+	Latency      *LatencyMeasurer
+	History      []json.RawMessage
+	CloseChan    chan struct{}
+	MessageChan  chan string
+	mu           sync.RWMutex
+	sessionOnce  sync.Once
+	CloseLlmChan *chan bool
+	LLMStatus    int8 // llm提问状态,0:未提问,1:提问中,-1:暂停提问
 }
 
 type Message struct {
-	KbName string `json:"KbName"`
-	Query  string `json:"Query"`
-	ChatId int    `json:"ChatId"`
+	KbName      string `json:"KbName"`
+	Query       string `json:"Query"`
+	ChatId      int    `json:"ChatId"`
+	MessageType string `json:"MessageType"`
 	//LastTopics []json.RawMessage `json:"LastTopics"`
 }
 
@@ -48,13 +51,15 @@ func (s *Session) readPump() {
 		}
 		// 更新活跃时间
 		s.UpdateActivity()
+
 		// 处理消息
-		if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
-			//写应答
-			_ = s.writeWithTimeout("<think></think>")
-			_ = s.writeWithTimeout(err.Error())
-			_ = s.writeWithTimeout("<EOF/>")
-		}
+		//if err = manager.HandleMessage(s.UserId, s.Id, message); err != nil {
+		//	//写应答
+		//	_ = s.writeWithTimeout("<think></think>")
+		//	_ = s.writeWithTimeout(err.Error())
+		//	_ = s.writeWithTimeout("<EOF/>")
+		//}
+		go manager.HandleMessage(s.UserId, s.Id, message)
 	}
 }
 

+ 58 - 21
utils/ws/session_manager.go

@@ -55,29 +55,53 @@ func Manager() *ConnectionManager {
 }
 
 // HandleMessage 消息处理核心逻辑
-func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) error {
-
+func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, message []byte) {
+	var err error
 	session, exists := manager.GetSession(sessionID)
 	if !exists {
-		return errors.New("session not found")
+		err = errors.New("session not found")
+		return
 	}
+
 	if strings.ToLower(string(message)) == "pong" {
 		session.UpdateActivity()
 		fmt.Printf("收到心跳消息,续期长连接:%v", session.LastActive)
-		return nil
-	}
-	if !Allow(userID, QA_LIMITER) {
-		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
-		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
-		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
-		return nil
+		return
 	}
+	defer func() {
+		if err != nil {
+			//写应答
+			_ = session.writeWithTimeout("<think></think>")
+			_ = session.writeWithTimeout(err.Error())
+			_ = session.writeWithTimeout("<EOF/>")
+		}
+	}()
 	var userMessage Message
-	err := json.Unmarshal(message, &userMessage)
+	err = json.Unmarshal(message, &userMessage)
 	if err != nil {
 		utils.FileLog.Error(fmt.Sprintf("消息格式错误:%s", string(message)))
 		fmt.Printf("消息格式错误:%s", string(message))
-		return errors.New("消息格式错误:" + err.Error())
+		err = errors.New("消息格式错误:" + err.Error())
+		return
+	}
+
+	if userMessage.MessageType == `stop` {
+		if session.LLMStatus == 1 {
+			// 标记llm提问状态:暂停提问
+			session.LLMStatus = -1
+		}
+		if session.CloseLlmChan != nil {
+			*session.CloseLlmChan <- true
+		}
+		return
+	}
+
+	// 限流
+	if !Allow(userID, QA_LIMITER) {
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("您提问的太频繁了,请稍后再试"))
+		_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
+		return
 	}
 	// 处理业务逻辑
 	//session.History = append(session.History, userMessage.LastTopics...)
@@ -104,29 +128,37 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 	}()
 	if resp == nil {
 		utils.FileLog.Error("知识库问答失败: 无应答")
-		return errors.New("知识库问答失败: 无应答")
+		err = errors.New("知识库问答失败: 无应答")
+		return
 	}
 	if err != nil {
 		utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
 		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
-		return err
+		return
 	}
 
 	if resp.StatusCode != http.StatusOK {
 		utils.FileLog.Error(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
 		err = errors.New(fmt.Sprintf("知识库问答失败: httpCode:%d,错误信息:%s", resp.StatusCode, http.StatusText(resp.StatusCode)))
-		return err
+		return
 	}
+
 	// 解析流式响应
-	contentChan, errChan, closeChan := eta_llm.ParseStreamResponse(resp)
+	contentChan, errChan, closeChan, closeLlmChan := eta_llm.ParseStreamResponse(resp)
+	session.CloseLlmChan = &closeLlmChan
+	// 标记llm提问状态:提问中
+	session.LLMStatus = 1
 	emptyContent := true
 	// 处理流式数据并发送到 WebSocket
 	for {
 		select {
 		case content, ok := <-contentChan:
-			if !ok {
-				err = errors.New("未知的错误异常")
-				return err
+			if !ok && session.LLMStatus != -1 {
+				err = errors.New("未知的内容错误异常")
+
+				// 标记llm提问状态:未提问
+				session.LLMStatus = 0
+				return
 			}
 			session.UpdateActivity()
 			if emptyContent {
@@ -140,15 +172,20 @@ func (manager *ConnectionManager) HandleMessage(userID int, sessionID string, me
 			} else {
 				err = errors.New(chanErr.Error())
 			}
+			// 标记llm提问状态:未提问
+			session.LLMStatus = 0
 			// 发送错误消息到 WebSocket
-			return err
+			return
 		case <-closeChan:
 			if emptyContent {
 				_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<think></think>"))
 				_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("暂时找不到答案"))
 			}
 			_ = session.Conn.WriteMessage(websocket.TextMessage, []byte("<EOF/>"))
-			return nil
+			// 标记llm提问状态:未提问
+			session.LLMStatus = 0
+
+			return
 		}
 	}
 	// 更新最后活跃时间