ai_controller.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package controllers
  2. import (
  3. "ems-backend/models"
  4. "ems-backend/services"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strconv"
  11. "strings"
  12. "time"
  13. "github.com/gin-gonic/gin"
  14. "gorm.io/gorm"
  15. )
  16. // GetAIConfig 获取AI配置
  17. func GetAIConfig(c *gin.Context) {
  18. configs := make(map[string]string)
  19. keys := []string{"ai_api_url", "ai_api_key", "ai_model", "ai_alert_threshold", "ai_auto_ack"}
  20. for _, k := range keys {
  21. var config models.SysConfig
  22. if err := models.DB.Where("config_key = ?", k).First(&config).Error; err == nil {
  23. configs[k] = config.ConfigValue
  24. } else {
  25. // 默认值
  26. if k == "ai_alert_threshold" {
  27. configs[k] = "200"
  28. } else if k == "ai_auto_ack" {
  29. configs[k] = "false"
  30. } else if k == "ai_model" {
  31. configs[k] = "gpt-3.5-turbo"
  32. } else {
  33. configs[k] = ""
  34. }
  35. }
  36. }
  37. c.JSON(http.StatusOK, configs)
  38. }
  39. // UpdateAIConfig 更新AI配置
  40. func UpdateAIConfig(c *gin.Context) {
  41. // 使用 interface{} 接收,防止前端传数字导致 BindJSON 失败
  42. var input map[string]interface{}
  43. if err := c.ShouldBindJSON(&input); err != nil {
  44. c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON format: " + err.Error()})
  45. return
  46. }
  47. tx := models.DB.Begin()
  48. defer func() {
  49. if r := recover(); r != nil {
  50. tx.Rollback()
  51. }
  52. }()
  53. for k, v := range input {
  54. // 强制转换为字符串
  55. valStr := fmt.Sprintf("%v", v)
  56. var config models.SysConfig
  57. err := tx.Where("config_key = ?", k).First(&config).Error
  58. if errors.Is(err, gorm.ErrRecordNotFound) {
  59. config = models.SysConfig{ConfigKey: k, ConfigValue: valStr, ConfigType: "Y"}
  60. if err := tx.Create(&config).Error; err != nil {
  61. tx.Rollback()
  62. c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create config " + k})
  63. return
  64. }
  65. } else if err != nil {
  66. tx.Rollback()
  67. c.JSON(http.StatusInternalServerError, gin.H{"error": "Database error checking " + k})
  68. return
  69. } else {
  70. config.ConfigValue = valStr
  71. if err := tx.Save(&config).Error; err != nil {
  72. tx.Rollback()
  73. c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update config " + k})
  74. return
  75. }
  76. }
  77. }
  78. if err := tx.Commit().Error; err != nil {
  79. c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to commit transaction"})
  80. return
  81. }
  82. c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
  83. }
  84. // TestAIConnection 测试AI连接
  85. func TestAIConnection(c *gin.Context) {
  86. var input struct {
  87. APIUrl string `json:"ai_api_url"`
  88. APIKey string `json:"ai_api_key"`
  89. Model string `json:"ai_model"`
  90. }
  91. if err := c.ShouldBindJSON(&input); err != nil {
  92. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  93. return
  94. }
  95. if input.APIUrl == "" || input.APIKey == "" {
  96. c.JSON(http.StatusBadRequest, gin.H{"error": "API URL和Key不能为空"})
  97. return
  98. }
  99. model := input.Model
  100. if model == "" {
  101. model = "gpt-3.5-turbo"
  102. }
  103. // 智能修正 URL
  104. chatUrl := services.NormalizeChatUrl(input.APIUrl)
  105. // 简单的 Ping 请求 (OpenAI 格式)
  106. if _, err := services.FetchAIResponse(chatUrl, input.APIKey, model, "Hello, are you there? Reply with 'Yes'."); err != nil {
  107. c.JSON(http.StatusInternalServerError, gin.H{"error": "连接失败: " + err.Error()})
  108. return
  109. }
  110. c.JSON(http.StatusOK, gin.H{"message": "连接成功"})
  111. }
  112. // GetRemoteModels 获取远程模型列表
  113. func GetRemoteModels(c *gin.Context) {
  114. var input struct {
  115. APIUrl string `json:"ai_api_url"`
  116. APIKey string `json:"ai_api_key"`
  117. }
  118. if err := c.ShouldBindJSON(&input); err != nil {
  119. c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  120. return
  121. }
  122. if input.APIUrl == "" || input.APIKey == "" {
  123. c.JSON(http.StatusBadRequest, gin.H{"error": "API URL和Key不能为空"})
  124. return
  125. }
  126. // 推断 Models API 地址
  127. modelsUrl := normalizeModelsUrl(input.APIUrl)
  128. req, err := http.NewRequest("GET", modelsUrl, nil)
  129. if err != nil {
  130. c.JSON(http.StatusInternalServerError, gin.H{"error": "Request creation failed: " + err.Error()})
  131. return
  132. }
  133. req.Header.Set("Authorization", "Bearer "+input.APIKey)
  134. client := &http.Client{Timeout: 10 * time.Second}
  135. resp, err := client.Do(req)
  136. if err != nil {
  137. c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch models: " + err.Error()})
  138. return
  139. }
  140. defer resp.Body.Close()
  141. bodyBytes, _ := io.ReadAll(resp.Body)
  142. if resp.StatusCode != 200 {
  143. c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("API Error (%d): %s", resp.StatusCode, string(bodyBytes))})
  144. return
  145. }
  146. // 解析 OpenAI 格式的模型列表
  147. var result struct {
  148. Data []struct {
  149. ID string `json:"id"`
  150. } `json:"data"`
  151. }
  152. if err := json.Unmarshal(bodyBytes, &result); err != nil {
  153. c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to parse models response: " + err.Error()})
  154. return
  155. }
  156. var models []string
  157. for _, m := range result.Data {
  158. models = append(models, m.ID)
  159. }
  160. c.JSON(http.StatusOK, gin.H{"models": models})
  161. }
  162. // GenerateAIReport 生成AI报表
  163. func GenerateAIReport(c *gin.Context) {
  164. // 1. 检查当前实时告警数量
  165. var count int64
  166. var alarms []models.AlarmLog
  167. models.DB.Model(&models.AlarmLog{}).Where("status = ?", "ACTIVE").Count(&count)
  168. // 获取部分告警内容用于生成 Prompt
  169. models.DB.Model(&models.AlarmLog{}).Where("status = ?", "ACTIVE").Limit(20).Order("start_time desc").Find(&alarms)
  170. // 2. 调用 Service 生成报表
  171. report, err := services.GenerateAIReportInternal(alarms, "实时告警分析报告 - " + time.Now().Format("2006-01-02 15:04:05"), "ALARM_ANALYSIS")
  172. if err != nil {
  173. c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
  174. return
  175. }
  176. c.JSON(http.StatusOK, report)
  177. }
  178. // GetAIReports 获取报表列表
  179. func GetAIReports(c *gin.Context) {
  180. page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
  181. pageSize, _ := strconv.Atoi(c.DefaultQuery("pageSize", "10"))
  182. keyword := c.Query("keyword")
  183. query := models.DB.Model(&models.AIAnalysisReport{})
  184. if keyword != "" {
  185. query = query.Where("title LIKE ?", "%"+keyword+"%")
  186. }
  187. var total int64
  188. query.Count(&total)
  189. var reports []models.AIAnalysisReport
  190. offset := (page - 1) * pageSize
  191. query.Order("created_at desc").Limit(pageSize).Offset(offset).Find(&reports)
  192. c.JSON(http.StatusOK, gin.H{"data": reports, "total": total})
  193. }
  194. // GetAIReportDetail 获取报表详情
  195. func GetAIReportDetail(c *gin.Context) {
  196. id := c.Param("id")
  197. var report models.AIAnalysisReport
  198. if err := models.DB.Where("id = ?", id).First(&report).Error; err != nil {
  199. c.JSON(http.StatusNotFound, gin.H{"error": "Report not found"})
  200. return
  201. }
  202. c.JSON(http.StatusOK, report)
  203. }
  // Internal helpers.

// normalizeModelsUrl derives an OpenAI-compatible "/models" endpoint from a
// user-entered API address.
//
// Surrounding whitespace and trailing slashes are removed first. Without that,
// ".../models/" would not be recognised and would become ".../models/models".
//   - contains "/chat/completions" -> first occurrence replaced by "/models"
//   - already ends in "/models"    -> returned as-is
//   - anything else (e.g. ".../v1") -> "/models" appended
func normalizeModelsUrl(inputUrl string) string {
	base := strings.TrimRight(strings.TrimSpace(inputUrl), "/")
	switch {
	case strings.Contains(base, "/chat/completions"):
		return strings.Replace(base, "/chat/completions", "/models", 1)
	case strings.HasSuffix(base, "/models"):
		return base
	default:
		return base + "/models"
	}
}