// gontrol/src/server/websocket/websocketServer.go
// (file-listing metadata: 385 lines, 11 KiB, Go)

package websocketserver
import (
"database/sql"
"encoding/json"
"fmt"
"gontrol/src/logger"
"gontrol/src/randomname"
"gontrol/src/server/api"
"io"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/PuerkitoBio/goquery"
_ "github.com/go-sql-driver/mysql"
"github.com/gorilla/websocket"
)
var (
	// responseChannels routes websocket replies to the executeCommand
	// goroutine that is waiting for them.
	// Key: agent name (string); value: chan string.
	responseChannels sync.Map

	// db is not referenced by any code visible in this file.
	// NOTE(review): possibly used elsewhere in the package — confirm or remove.
	db *sql.DB
)

// webSocketHandler upgrades incoming HTTP requests to websocket connections
// using the configured upgrader and then services one agent per connection.
type webSocketHandler struct {
	upgrader websocket.Upgrader
}

var (
	// agentSockets maps each connected agent's name to its live websocket.
	// Every read or write must hold agentSocketsMutex.
	agentSockets = make(map[string]*websocket.Conn)

	// agentSocketsMutex serializes access to agentSockets.
	agentSocketsMutex sync.Mutex
)
// getAgentNames responds with a JSON array holding the name of every agent
// that currently has an open websocket connection. With no agents connected
// the body is [] (never null). Order follows Go map iteration and is
// therefore unspecified.
var getAgentNames http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
	// Snapshot the names under the lock so encoding does not hold it.
	agentSocketsMutex.Lock()
	names := make([]string, 0, len(agentSockets))
	for name := range agentSockets {
		names = append(names, name)
	}
	agentSocketsMutex.Unlock()

	h := w.Header()
	h.Set("Content-Type", "application/json")
	h.Set("Access-Control-Allow-Origin", "*")
	h.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
	h.Set("Access-Control-Allow-Headers", "Content-Type")

	if err := json.NewEncoder(w).Encode(names); err != nil {
		// The response has already started, so the failure can only be logged.
		log.Printf("Error encoding agent names: %v", err)
		logger.InsertLog(logger.Error, fmt.Sprintf("Error encoding agent names: %v", err))
	}
}
func registerAgent(agentName, agentId, agentIp, agentType, addPort, hostname string) error {
registerURL := "http://localhost:3333/agents"
form := url.Values{}
form.Add("agentId", agentId)
form.Add("agentName", agentName)
form.Add("agentType", agentType)
form.Add("IPv4Address", agentIp)
form.Add("addPort", addPort)
form.Add("hostname", hostname)
resp, err := http.PostForm(registerURL, form)
if err != nil {
log.Printf("Error registering agent: %v", err)
logger.InsertLog(logger.Error, fmt.Sprintf("Error registering agent: %v", err))
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusCreated {
log.Printf("Agent %s successfully registered.", agentName)
logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s successfully registered.", agentName))
return nil
} else if resp.StatusCode == http.StatusOK {
log.Printf("Agent %s already registered.", agentName)
logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s already registered.", agentName))
return nil
} else {
log.Printf("Failed to register agent, status: %v", resp.Status)
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to register agent, status: %v", resp.Status))
return err
}
}
// getAgentDetails fetches the agent's HTML detail page from the REST API at
// http://localhost:3333/agents/{agentId}. It scrapes the ID, name and type
// out of the <p> elements under #agent-detail.
//
// A transport failure, a timeout, or an HTML-parse failure returns
// (nil, err). The response status code is not checked: whatever page comes
// back is scraped. Fields whose paragraph is absent keep their zero value. A
// non-numeric "ID:" value is logged and AgentID is set to whatever
// strconv.Atoi returned (0 for malformed input).
func getAgentDetails(agentId string) (*api.Agent, error) {
	// Escape the id so it always forms exactly one path segment.
	agentURL := "http://localhost:3333/agents/" + url.PathEscape(agentId)

	// Bound the call so a stalled API cannot block the websocket handler forever.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(agentURL)
	if err != nil {
		log.Printf("Failed to make GET request: %s", err)
		logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err))
		return nil, err
	}
	defer resp.Body.Close()

	doc, err := goquery.NewDocumentFromReader(resp.Body)
	if err != nil {
		log.Printf("Failed to parse HTML: %s", err)
		logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse HTML: %s", err))
		return nil, err
	}

	agent := &api.Agent{}
	doc.Find("#agent-detail p").Each(func(_ int, s *goquery.Selection) {
		text := s.Text()
		switch {
		case strings.HasPrefix(text, "ID:"):
			// Use a local error so the closure does not overwrite the function's err.
			id, convErr := strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(text, "ID:")))
			if convErr != nil {
				log.Printf("Converting string to integer failed in getAgentDetails(): %s", convErr)
				logger.InsertLog(logger.Error, fmt.Sprintf("Converting string to integer failed in getAgentDetails(): %s", convErr))
			}
			agent.AgentID = id
		case strings.HasPrefix(text, "Name:"):
			agent.AgentName = strings.TrimSpace(strings.TrimPrefix(text, "Name:"))
		case strings.HasPrefix(text, "Type:"):
			agent.AgentType = strings.TrimSpace(strings.TrimPrefix(text, "Type:"))
		}
	})
	return agent, nil
}
// containsId reports whether agentId appears among ids.
// A nil or empty ids slice never contains anything.
func containsId(ids []string, agentId string) bool {
	found := false
	for i := 0; i < len(ids) && !found; i++ {
		found = ids[i] == agentId
	}
	return found
}
// getAgentIds fetches the IDs of all registered agents from the REST API at
// http://localhost:3333/agentIds. The endpoint is expected to return a JSON
// array of strings.
//
// It returns a non-nil error for a transport failure or timeout, a non-200
// status, an unreadable body, or a body that is not a JSON string array.
func getAgentIds() ([]string, error) {
	const idURL = "http://localhost:3333/agentIds"

	// Bound the call so a stalled API cannot block the websocket handler forever.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(idURL)
	if err != nil {
		log.Printf("Failed to make GET request: %s", err)
		logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err))
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("Unexpected status code: %d", resp.StatusCode)
		logger.InsertLog(logger.Error, fmt.Sprintf("Unexpected status code: %d", resp.StatusCode))
		// Returning (nil, nil) would make every connecting agent look
		// unregistered, so ServeHTTP would give it a new random name.
		return nil, fmt.Errorf("fetching agent IDs: unexpected status %s", resp.Status)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Printf("Failed to read response body: %s", err)
		logger.InsertLog(logger.Error, fmt.Sprintf("Failed to read response body: %s", err))
		return nil, err
	}

	var agentIds []string
	if err := json.Unmarshal(body, &agentIds); err != nil {
		log.Printf("Failed to parse JSON response: %s", err)
		logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse JSON response: %s", err))
		return nil, err
	}
	return agentIds, nil
}
func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){
c, err := wsh.upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("Error %s when upgrading connection to websocket", err)
logger.InsertLog(logger.Error, fmt.Sprintf("Error %s when upgrading connection to websocket", err))
return
}
agentIP := r.URL.Query().Get("IPv4Address")
agentId := r.URL.Query().Get("agentId")
agentType := r.URL.Query().Get("agentType")
hostName := r.URL.Query().Get("hostname")
agentName := ""
addPort := ""
if len(r.URL.Query().Get("addPort")) > 0 {
addPort = r.URL.Query().Get("addPort")
} else {
addPort = "None"
}
agentIds, err := getAgentIds()
if err != nil {
log.Printf("Error %v\n", err)
return
}
if !containsId(agentIds, agentId) {
agentName = randomname.GenerateRandomName()
registerAgent(agentName, agentId, agentIP, agentType, addPort, hostName)
} else {
agentDetails, _ := getAgentDetails(agentId)
agentName = agentDetails.AgentName
}
if agentName == "" || agentIP == "" {
log.Printf("Missing agentName or IPv4Address in query parameters")
logger.InsertLog(logger.Info, fmt.Sprintf("Missing agentName or IPv4Address in query parameters"))
c.Close()
return
}
log.Printf("Agent %s connected: %s (%s)", agentId, agentName, agentIP)
logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s connected: %s (%s)", agentId, agentName, agentIP))
agentSocketsMutex.Lock()
agentSockets[agentName] = c
agentSocketsMutex.Unlock()
defer func() {
agentSocketsMutex.Lock()
delete(agentSockets, agentName)
agentSocketsMutex.Unlock()
c.Close()
log.Printf("Agent disconnected: %s (%s)", agentName, agentIP)
logger.InsertLog(logger.Info, fmt.Sprintf("Agent disconnected: %s (%s)", agentName, agentIP))
}()
for {
_, message, err := c.ReadMessage()
if err != nil {
log.Printf("Error reading from agent %s: %v", agentName, err)
logger.InsertLog(logger.Error, fmt.Sprintf("Error reading from agent %s: %v", agentName, err))
break
}
log.Printf("Message from agent %s: %s", agentName, message)
logger.InsertLog(logger.Debug, fmt.Sprintf("Message from agent %s: %s", agentName, message))
if ch, ok := responseChannels.Load(agentName); ok {
responseChan := ch.(chan string)
responseChan <- string(message)
}
}
}
// Message is the JSON envelope written to an agent's websocket.
// It serializes as {"type": ..., "payload": ...}, in that field order.
type Message struct {
	Type    string `json:"type"`    // message kind; executeCommand uses "command"
	Payload string `json:"payload"` // message body, e.g. the command text
}
// executeCommand sends a command to one or more connected agents and returns
// their replies as a single text/plain body.
//
// Form fields:
//   - agentNames: comma-separated agent names (takes precedence), or
//   - agentName: a single agent name
//   - command: sent as the payload of a {"type":"command"} Message
//
// Names are trimmed; blank and duplicate names are ignored. Each agent is
// contacted concurrently and has 10 seconds to reply. The reply must be a JSON
// object whose "payload" member is a string.
//
// Results are written in the order the agents were named. A success is
// written as "\n[name]\n<payload>"; a failure as "\n[name] ERROR: <reason>\n".
// A 400 is returned when no usable agent name or no command is supplied.
var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Invalid form data", http.StatusBadRequest)
		logger.InsertLog(logger.Info, "Invalid form data")
		return
	}

	var requested []string
	if s := r.FormValue("agentNames"); s != "" {
		requested = strings.Split(s, ",")
	} else if s := r.FormValue("agentName"); s != "" {
		requested = []string{s}
	}

	// Trim, drop blanks and de-duplicate. gorilla/websocket allows only one
	// concurrent writer per connection, and two goroutines for the same agent
	// would also overwrite each other's entry in responseChannels.
	seen := make(map[string]bool, len(requested))
	agentNames := make([]string, 0, len(requested))
	for _, name := range requested {
		name = strings.TrimSpace(name)
		if name == "" || seen[name] {
			continue
		}
		seen[name] = true
		agentNames = append(agentNames, name)
	}

	command := r.FormValue("command")
	if len(agentNames) == 0 || command == "" {
		http.Error(w, "Missing agent or command", http.StatusBadRequest)
		logger.InsertLog(logger.Error, "Missing agent or command")
		return
	}

	type result struct {
		Payload string
		Err     error
	}

	// dispatch sends the command to one agent and waits for its reply.
	dispatch := func(agent string) result {
		agentSocketsMutex.Lock()
		conn, ok := agentSockets[agent]
		agentSocketsMutex.Unlock()
		if !ok {
			return result{Err: fmt.Errorf("Agent not connected")}
		}

		// Register for the reply before sending so a fast answer is not missed.
		responseChan := make(chan string, 1)
		responseChannels.Store(agent, responseChan)
		defer responseChannels.Delete(agent)

		// Marshalling a struct of two strings cannot fail.
		msgBytes, _ := json.Marshal(Message{Type: "command", Payload: command})
		// NOTE(review): two concurrent executeCommand requests for the same
		// agent still write to conn concurrently, which gorilla/websocket does
		// not allow; consider a per-connection write lock.
		if err := conn.WriteMessage(websocket.TextMessage, msgBytes); err != nil {
			return result{Err: fmt.Errorf("Send failed")}
		}

		select {
		case resp := <-responseChan:
			var parsed map[string]string
			if err := json.Unmarshal([]byte(resp), &parsed); err != nil {
				return result{Err: fmt.Errorf("Invalid response")}
			}
			payload, ok := parsed["payload"]
			if !ok {
				return result{Err: fmt.Errorf("No payload")}
			}
			return result{Payload: payload}
		case <-time.After(10 * time.Second):
			return result{Err: fmt.Errorf("Timeout")}
		}
	}

	// Each goroutine writes only its own slot, so no extra locking is needed;
	// the index also keeps the output in request order.
	results := make([]result, len(agentNames))
	var wg sync.WaitGroup
	for i, name := range agentNames {
		wg.Add(1)
		go func(i int, agent string) {
			defer wg.Done()
			results[i] = dispatch(agent)
		}(i, name)
	}
	wg.Wait()

	var combined strings.Builder
	for i, res := range results {
		if res.Err != nil {
			fmt.Fprintf(&combined, "\n[%s] ERROR: %s\n", agentNames[i], res.Err.Error())
		} else {
			fmt.Fprintf(&combined, "\n[%s]\n%s", agentNames[i], res.Payload)
		}
	}

	// Headers must be set before WriteHeader; later changes are ignored.
	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte(combined.String())); err != nil {
		log.Printf("Error writing executeCommand response: %v", err)
	}
}
// Server builds (but does not start) the HTTP server for agent traffic on
// port 5555. Routes:
//   - /data: websocket endpoint that agents connect to
//   - /executeCommand: run a command on connected agents (CORS-enabled)
//   - /agentNames: JSON list of connected agent names (CORS-enabled)
func Server() *http.Server {
	wsHandler := webSocketHandler{
		upgrader: websocket.Upgrader{
			// NOTE(review): accepts every Origin, so any web page can open a
			// websocket to /data. gorilla's default check already admits
			// clients that send no Origin header; consider restricting.
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		},
	}

	corsMiddleware := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set("Access-Control-Allow-Origin", "*") // Allow the WebUI origin, this needs to be changed before prod
			w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
			w.Header().Set("Access-Control-Allow-Headers", "Content-Type, HX-Current-URL, HX-Request, HX-Target, HX-Trigger, HX-Trigger-Name")
			if r.Method == "OPTIONS" {
				// Answer CORS preflight requests directly.
				w.WriteHeader(http.StatusOK)
				return
			}
			next.ServeHTTP(w, r)
		})
	}

	webSocketMux := http.NewServeMux()
	webSocketMux.Handle("/data", wsHandler)
	webSocketMux.Handle("/executeCommand", corsMiddleware(http.HandlerFunc(executeCommand)))
	webSocketMux.Handle("/agentNames", corsMiddleware(http.HandlerFunc(getAgentNames)))

	return &http.Server{
		Addr:    ":5555",
		Handler: webSocketMux,
		// Bound header reading to stop clients that trickle headers (slowloris).
		// NOTE(review): ReadTimeout/WriteTimeout are deliberately left unset;
		// their deadlines would also apply to the long-lived, hijacked
		// websocket connections.
		ReadHeaderTimeout: 10 * time.Second,
	}
}