diff --git a/main.go b/main.go index 895b791..9534568 100644 --- a/main.go +++ b/main.go @@ -6,7 +6,7 @@ import ( "os" "strings" "time" - + "slices" "database/sql" "fmt" "html/template" @@ -118,17 +118,11 @@ func listAgents(w http.ResponseWriter, r *http.Request) { agents, err := api.GetAgents(db) currentAgents := getAgentsStatus() - for _, currAgent := range currentAgents { - for i, agent := range agents { - if currAgent == agent.AgentName { - // log.Printf("%s online", agent.AgentName) - // logger.InsertLog(logger.Debug, fmt.Sprintf("%s online after page refresh", agent.AgentName)) - // agents[i].Status = fmt.Sprint("Connected") - agents[i].Status = "Connected" - } else { - // agent.Status = fmt.Sprintf("Disconnected") - agents[i].Status = "Disconnected" - } + for i := range agents { + if slices.Contains(currentAgents, agents[i].AgentName) { + agents[i].Status = "Connected" + } else { + agents[i].Status = "Disconnected" } } diff --git a/src/server/websocket/websocketServer.go b/src/server/websocket/websocketServer.go index ed127aa..7e373c2 100644 --- a/src/server/websocket/websocketServer.go +++ b/src/server/websocket/websocketServer.go @@ -235,6 +235,67 @@ type Message struct { Payload string `json:"payload"` } +// var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Request){ +// err := r.ParseForm() +// if err != nil { +// http.Error(w, "Invalid form data", http.StatusBadRequest) +// logger.InsertLog(logger.Info, "Invalid form data") +// return +// } + +// agentName := r.FormValue("agentName") +// command := r.FormValue("command") + +// agentSocketsMutex.Lock() +// conn, ok := agentSockets[agentName] +// agentSocketsMutex.Unlock() + +// if !ok { +// http.Error(w, "Agent not connected", http.StatusNotFound) +// logger.InsertLog(logger.Info, "Agent not connected") +// return +// } + +// responseChan := make(chan string, 1) +// responseChannels.Store(agentName, responseChan) +// defer responseChannels.Delete(agentName) + +// message := Message { +// Type: "command", +// Payload: command, +// } + +// messageBytes, _ := json.Marshal(message) + +// err = conn.WriteMessage(websocket.TextMessage, messageBytes) +// if err != nil { +// http.Error(w, "Failed to send command to the agent", http.StatusInternalServerError) +// logger.InsertLog(logger.Error, "Failed to send command to the agent") +// return +// } + +// select { +// case response := <-responseChan: +// var parsedResponse map[string]string +// if err := json.Unmarshal([]byte(response), &parsedResponse); err != nil { +// http.Error(w, "Failed to parse response", http.StatusInternalServerError) +// return +// } +// payload, ok := parsedResponse["payload"] +// if !ok { +// http.Error(w, "Invalid response structure", http.StatusInternalServerError) +// logger.InsertLog(logger.Error, "Invalid response structure") +// return +// } +// w.WriteHeader(http.StatusOK) +// w.Header().Set("Content-Type", "text/plain") +// w.Write([]byte(payload)) +// case <- time.After(10 * time.Second): +// http.Error(w, "Agent response timed out", http.StatusGatewayTimeout) +// logger.InsertLog(logger.Info, "Agent response timed out") +// } +// } + var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Request){ err := r.ParseForm() if err != nil { @@ -243,59 +304,103 @@ var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Reques return } - agentName := r.FormValue("agentName") + agentNameStr := r.FormValue("agentNames") + var agentNames []string + + if agentNameStr != "" { + agentNames = strings.Split(agentNameStr, ",") + } else { + agentName := r.FormValue("agentName") + if agentName != "" { + agentNames = []string{agentName} + } + } + command := r.FormValue("command") - agentSocketsMutex.Lock() - conn, ok := agentSockets[agentName] - agentSocketsMutex.Unlock() - - if !ok { - http.Error(w, "Agent not connected", http.StatusNotFound) - logger.InsertLog(logger.Info, "Agent not connected") + if len(agentNames) == 0 || command == "" { + http.Error(w, "Missing agent or command", http.StatusBadRequest) + logger.InsertLog(logger.Error, "Missing agent or command") return } - responseChan := make(chan string, 1) - responseChannels.Store(agentName, responseChan) - defer responseChannels.Delete(agentName) - - message := Message { - Type: "command", - Payload: command, + type result struct { + AgentName string + Type string + Payload string + Err error } - messageBytes, _ := json.Marshal(message) + resultsChan := make(chan result, len(agentNames)) - err = conn.WriteMessage(websocket.TextMessage, messageBytes) - if err != nil { - http.Error(w, "Failed to send command to the agent", http.StatusInternalServerError) - logger.InsertLog(logger.Error, "Failed to send command to the agent") - return + for _, agentName := range agentNames { + agentName := strings.TrimSpace(agentName) + + go func(agent string) { + agentSocketsMutex.Lock() + conn, ok := agentSockets[agentName] + agentSocketsMutex.Unlock() + + if !ok { + resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Agent not connected")} + return + } + + responseChan := make(chan string, 1) + responseChannels.Store(agent, responseChan) + defer responseChannels.Delete(agent) + + msg := Message { + Type: "command", + Payload: command, + } + + msgBytes, _ := json.Marshal(msg) + err := conn.WriteMessage(websocket.TextMessage, msgBytes) + if err != nil { + resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Send failed")} + return + } + + select { + case resp := <- responseChan: + var parsed map[string]string + if err:= json.Unmarshal([]byte(resp), &parsed); err != nil { + resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Invalid response")} + return + } + + payload, ok := parsed["payload"] + if !ok { + resultsChan <- result{AgentName: agent, Err: fmt.Errorf("No payload")} + return + } + + resultsChan <- result{AgentName: agent, Payload: payload} + case <-time.After(10 * time.Second): + resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Timeout")} + } + } (agentName) } - select { - case response := <-responseChan: - var parsedResponse map[string]string - if err := json.Unmarshal([]byte(response), &parsedResponse); err != nil { - http.Error(w, "Failed to parse response", http.StatusInternalServerError) - return - } - payload, ok := parsedResponse["payload"] - if !ok { - http.Error(w, "Invalid response structure", http.StatusInternalServerError) - logger.InsertLog(logger.Error, "Invalid response structure") - return + var combined strings.Builder + for i := 0; i < len(agentNames); i++ { + res := <- resultsChan + if res.Err != nil { + combined.WriteString(fmt.Sprintf("[%s] ERROR: %s\n", res.AgentName, res.Err.Error())) + } else { + combined.WriteString(fmt.Sprintf("[%s] %s\n", res.AgentName, res.Payload)) + } } + w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "text/plain") - w.Write([]byte(payload)) - case <- time.After(10 * time.Second): - http.Error(w, "Agent response timed out", http.StatusGatewayTimeout) - logger.InsertLog(logger.Info, "Agent response timed out") - } + w.Write([]byte(combined.String())) + + } + func Server() (*http.Server) { webSocketHandler := webSocketHandler { upgrader: websocket.Upgrader{ diff --git a/templates/index.html b/templates/index.html index 30c094c..a261aa2 100644 --- a/templates/index.html +++ b/templates/index.html @@ -10,101 +10,84 @@