// Package websocketserver serves the agent-facing WebSocket endpoint together
// with the HTTP endpoints used to list connected agents and to dispatch
// commands to them.
package websocketserver

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"gontrol/src/logger"
	"gontrol/src/randomname"
	"gontrol/src/server/api"

	"github.com/PuerkitoBio/goquery"
	_ "github.com/go-sql-driver/mysql"
	"github.com/gorilla/websocket"
)

const (
	// agentAPIBaseURL is the base URL of the agent REST API called by this package.
	agentAPIBaseURL = "http://localhost:3333"

	// agentAPITimeout bounds every outbound call to the agent REST API so a
	// stalled API cannot hang a WebSocket handshake forever.
	agentAPITimeout = 10 * time.Second

	// commandTimeout is how long executeCommand waits for an agent's reply.
	commandTimeout = 10 * time.Second
)

// agentAPIClient is the shared HTTP client for calls to the agent REST API.
var agentAPIClient = &http.Client{Timeout: agentAPITimeout}

var responseChannels sync.Map // Key: agentName, Value: chan string

// db is not referenced anywhere in this file.
var db *sql.DB

// webSocketHandler upgrades incoming HTTP requests to WebSocket connections
// and serves the resulting agent connection.
type webSocketHandler struct {
	upgrader websocket.Upgrader
}

// agentSockets maps an agent name to its live WebSocket connection.
// All access must hold agentSocketsMutex.
var agentSockets = make(map[string]*websocket.Conn)
var agentSocketsMutex sync.Mutex

// logErrorf writes the formatted message to the standard logger and records
// it through logger.InsertLog at Error level.
func logErrorf(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	log.Print(msg)
	logger.InsertLog(logger.Error, msg)
}

// logInfof writes the formatted message to the standard logger and records
// it through logger.InsertLog at Info level.
func logInfof(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	log.Print(msg)
	logger.InsertLog(logger.Info, msg)
}

// logDebugf writes the formatted message to the standard logger and records
// it through logger.InsertLog at Debug level.
func logDebugf(format string, args ...any) {
	msg := fmt.Sprintf(format, args...)
	log.Print(msg)
	logger.InsertLog(logger.Debug, msg)
}

// getAgentNames responds with a JSON array holding the names of all agents
// that currently have a registered WebSocket connection.
var getAgentNames http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
	agentSocketsMutex.Lock()
	agentNames := make([]string, 0, len(agentSockets))
	for agentName := range agentSockets {
		agentNames = append(agentNames, agentName)
	}
	agentSocketsMutex.Unlock()

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
	if err := json.NewEncoder(w).Encode(agentNames); err != nil {
		logErrorf("Error encoding agent names: %v", err)
	}
}

// registerAgent posts the agent's attributes as a form to the agent API.
//
// It returns nil when the API answers 201 Created (newly registered) or
// 200 OK (already registered). A transport failure or any other status
// yields a non-nil error.
func registerAgent(agentName, agentId, agentIp, agentType, addPort, hostname string) error {
	form := url.Values{}
	form.Add("agentId", agentId)
	form.Add("agentName", agentName)
	form.Add("agentType", agentType)
	form.Add("IPv4Address", agentIp)
	form.Add("addPort", addPort)
	form.Add("hostname", hostname)

	resp, err := agentAPIClient.PostForm(agentAPIBaseURL+"/agents", form)
	if err != nil {
		logErrorf("Error registering agent: %v", err)
		return err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusCreated:
		logInfof("Agent %s successfully registered.", agentName)
		return nil
	case http.StatusOK:
		logInfof("Agent %s already registered.", agentName)
		return nil
	default:
		logErrorf("Failed to register agent, status: %v", resp.Status)
		// err is nil on this path, so build a real error; returning err
		// here would report the failed registration as a success.
		return fmt.Errorf("registering agent %q: unexpected status %s", agentName, resp.Status)
	}
}

// getAgentDetails fetches the agent's HTML detail page from the agent API and
// extracts the ID, name and type from the "#agent-detail p" elements.
//
// It returns an error when the request fails, the API answers with a status
// other than 200 OK, or the body cannot be parsed as HTML. An unparsable
// "ID:" value is logged and leaves AgentID at zero.
func getAgentDetails(agentId string) (*api.Agent, error) {
	// agentId can originate from a request query parameter, so escape it
	// before placing it in the URL path.
	agentURL := agentAPIBaseURL + "/agents/" + url.PathEscape(agentId)

	resp, err := agentAPIClient.Get(agentURL)
	if err != nil {
		logErrorf("Failed to make GET request: %s", err)
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		logErrorf("Unexpected status code: %d", resp.StatusCode)
		return nil, fmt.Errorf("fetching agent %q: unexpected status %s", agentId, resp.Status)
	}

	doc, err := goquery.NewDocumentFromReader(resp.Body)
	if err != nil {
		logErrorf("Failed to parse HTML: %s", err)
		return nil, err
	}

	agent := &api.Agent{}
	doc.Find("#agent-detail p").Each(func(_ int, s *goquery.Selection) {
		text := s.Text()
		switch {
		case strings.HasPrefix(text, "ID:"):
			id, convErr := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(text, "ID:")))
			if convErr != nil {
				logErrorf("Converting string to integer failed in getAgentDetails(): %s", convErr)
			}
			agent.AgentID = id
		case strings.HasPrefix(text, "Name:"):
			agent.AgentName = strings.TrimSpace(strings.TrimPrefix(text, "Name:"))
		case strings.HasPrefix(text, "Type:"):
			agent.AgentType = strings.TrimSpace(strings.TrimPrefix(text, "Type:"))
		}
	})
	return agent, nil
}

// containsId reports whether agentId is present in ids.
func containsId(ids []string, agentId string) bool {
	for _, id := range ids {
		if id == agentId {
			return true
		}
	}
	return false
}

// getAgentIds fetches the list of known agent IDs from the agent API.
//
// A status other than 200 OK is logged and reported as (nil, nil), so callers
// see it as an empty ID list rather than a failure. Transport, read and JSON
// decoding failures are returned as errors.
func getAgentIds() ([]string, error) {
	resp, err := agentAPIClient.Get(agentAPIBaseURL + "/agentIds")
	if err != nil {
		logErrorf("Failed to make GET request: %s", err)
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		logInfof("Unexpected status code: %d", resp.StatusCode)
		return nil, nil
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		logErrorf("Failed to read response body: %s", err)
		return nil, err
	}

	var agentIds []string
	if err := json.Unmarshal(body, &agentIds); err != nil {
		logErrorf("Failed to parse JSON response: %s", err)
		return nil, err
	}
	return agentIds, nil
}

// ServeHTTP upgrades the request to a WebSocket, resolves the agent's name
// (registering the agent under a random name when its ID is unknown), stores
// the connection in agentSockets and then reads messages until the
// connection fails.
//
// Each received message is forwarded to the response channel registered for
// the agent in responseChannels, when one exists.
func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	c, err := wsh.upgrader.Upgrade(w, r, nil)
	if err != nil {
		logErrorf("Error %s when upgrading connection to websocket", err)
		return
	}
	// Close the socket on every exit path, including the early returns below.
	defer c.Close()

	query := r.URL.Query()
	agentIP := query.Get("IPv4Address")
	agentId := query.Get("agentId")
	agentType := query.Get("agentType")
	hostName := query.Get("hostname")
	addPort := query.Get("addPort")
	if addPort == "" {
		addPort = "None"
	}

	agentIds, err := getAgentIds()
	if err != nil {
		log.Printf("Error %v\n", err)
		return
	}

	var agentName string
	if containsId(agentIds, agentId) {
		agentDetails, err := getAgentDetails(agentId)
		if err != nil {
			// getAgentDetails has already logged the failure.
			return
		}
		agentName = agentDetails.AgentName
	} else {
		agentName = randomname.GenerateRandomName()
		// Registration stays best-effort: the agent is still served when the
		// API call fails, but the failure is no longer silently dropped.
		if err := registerAgent(agentName, agentId, agentIP, agentType, addPort, hostName); err != nil {
			log.Printf("Continuing with unregistered agent %s: %v", agentName, err)
		}
	}

	if agentName == "" || agentIP == "" {
		logInfof("Missing agentName or IPv4Address in query parameters")
		return
	}

	logInfof("Agent %s connected: %s (%s)", agentId, agentName, agentIP)

	agentSocketsMutex.Lock()
	agentSockets[agentName] = c
	agentSocketsMutex.Unlock()

	defer func() {
		agentSocketsMutex.Lock()
		// Only drop the entry if it is still ours; a reconnect under the same
		// name may already have replaced it with a newer connection.
		if agentSockets[agentName] == c {
			delete(agentSockets, agentName)
		}
		agentSocketsMutex.Unlock()
		logInfof("Agent disconnected: %s (%s)", agentName, agentIP)
	}()

	for {
		_, message, err := c.ReadMessage()
		if err != nil {
			logErrorf("Error reading from agent %s: %v", agentName, err)
			break
		}
		logDebugf("Message from agent %s: %s", agentName, message)

		ch, ok := responseChannels.Load(agentName)
		if !ok {
			continue
		}
		// Non-blocking send: if the waiter has already gone away and the
		// buffer is full, drop the message instead of stalling the read loop.
		select {
		case ch.(chan string) <- string(message):
		default:
		}
	}
}

// Message is the JSON envelope written to an agent over its WebSocket.
type Message struct {
	Type    string `json:"type"`
	Payload string `json:"payload"`
}

// commandResult is the outcome of sending one command to one agent.
type commandResult struct {
	AgentName string
	Payload   string
	Err       error
}

// runAgentCommand sends command to the named agent over its WebSocket and
// waits up to commandTimeout for a JSON reply containing a "payload" key.
//
// The error texts are written verbatim into the HTTP response by
// executeCommand, so their wording is kept stable.
//
// NOTE(review): nothing here serializes concurrent commands to the same
// agent. They would write to one connection simultaneously and share a single
// responseChannels key; confirm whether callers can trigger that.
func runAgentCommand(agent, command string) commandResult {
	agentSocketsMutex.Lock()
	conn, ok := agentSockets[agent]
	agentSocketsMutex.Unlock()
	if !ok {
		return commandResult{AgentName: agent, Err: fmt.Errorf("Agent not connected")}
	}

	responseChan := make(chan string, 1)
	responseChannels.Store(agent, responseChan)
	defer responseChannels.Delete(agent)

	msgBytes, err := json.Marshal(Message{Type: "command", Payload: command})
	if err != nil {
		return commandResult{AgentName: agent, Err: fmt.Errorf("Send failed")}
	}
	if err := conn.WriteMessage(websocket.TextMessage, msgBytes); err != nil {
		return commandResult{AgentName: agent, Err: fmt.Errorf("Send failed")}
	}

	select {
	case resp := <-responseChan:
		var parsed map[string]string
		if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
			return commandResult{AgentName: agent, Err: fmt.Errorf("Invalid response")}
		}
		payload, ok := parsed["payload"]
		if !ok {
			return commandResult{AgentName: agent, Err: fmt.Errorf("No payload")}
		}
		return commandResult{AgentName: agent, Payload: payload}
	case <-time.After(commandTimeout):
		return commandResult{AgentName: agent, Err: fmt.Errorf("Timeout")}
	}
}

// executeCommand sends the form value "command" to every agent named in the
// comma-separated "agentNames" form value (or the single "agentName" value
// when "agentNames" is empty) and responds with the combined plain-text
// output.
//
// Agents are contacted concurrently, so the per-agent sections appear in
// completion order. It answers 400 when the form is invalid or when no agent
// or no command is supplied.
var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Invalid form data", http.StatusBadRequest)
		logger.InsertLog(logger.Info, "Invalid form data")
		return
	}

	var agentNames []string
	if agentNameStr := r.FormValue("agentNames"); agentNameStr != "" {
		agentNames = strings.Split(agentNameStr, ",")
	} else if agentName := r.FormValue("agentName"); agentName != "" {
		agentNames = []string{agentName}
	}

	command := r.FormValue("command")
	if len(agentNames) == 0 || command == "" {
		http.Error(w, "Missing agent or command", http.StatusBadRequest)
		logger.InsertLog(logger.Error, "Missing agent or command")
		return
	}

	// Buffered to len(agentNames) so no worker can block on send.
	resultsChan := make(chan commandResult, len(agentNames))
	for _, name := range agentNames {
		go func(agent string) {
			resultsChan <- runAgentCommand(agent, command)
		}(strings.TrimSpace(name))
	}

	var combined strings.Builder
	for range agentNames {
		res := <-resultsChan
		if res.Err != nil {
			fmt.Fprintf(&combined, "\n[%s] ERROR: %s\n", res.AgentName, res.Err.Error())
		} else {
			fmt.Fprintf(&combined, "\n[%s]\n%s", res.AgentName, res.Payload)
		}
	}

	// Headers must be set before WriteHeader; anything set afterwards is
	// not sent to the client.
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte(combined.String())); err != nil {
		logErrorf("Error writing command response: %v", err)
	}
}

// Server builds the HTTP server listening on ":5555" that exposes:
//
//   - /data            the agent WebSocket endpoint
//   - /executeCommand  command dispatch to connected agents
//   - /agentNames      names of connected agents
//
// The two plain HTTP endpoints are wrapped in a permissive CORS middleware.
// The caller is responsible for starting the returned server.
func Server() *http.Server {
	handler := webSocketHandler{
		upgrader: websocket.Upgrader{
			// Every origin is accepted for the WebSocket handshake.
			CheckOrigin: func(r *http.Request) bool { return true },
		},
	}

	corsMiddleware := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Access-Control-Allow-Origin", "*") // Allow the WebUI origin, this needs to be changed before prod
			w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type, HX-Current-URL, HX-Request, HX-Target, HX-Trigger, HX-Trigger-Name")
			if r.Method == "OPTIONS" { // Handle preflight requests
				w.WriteHeader(http.StatusOK)
				return
			}
			next.ServeHTTP(w, r)
		})
	}

	webSocketMux := http.NewServeMux()
	webSocketMux.Handle("/data", handler)
	webSocketMux.Handle("/executeCommand", corsMiddleware(http.HandlerFunc(executeCommand)))
	webSocketMux.Handle("/agentNames", corsMiddleware(http.HandlerFunc(getAgentNames)))

	return &http.Server{
		Addr:    ":5555",
		Handler: webSocketMux,
		// Bound header reads only; a read/write timeout would have to exceed
		// commandTimeout, which executeCommand may spend waiting on agents.
		ReadHeaderTimeout: 10 * time.Second,
	}
}