385 lines
11 KiB
Go
385 lines
11 KiB
Go
package websocketserver
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"gontrol/src/logger"
|
|
"gontrol/src/randomname"
|
|
"gontrol/src/server/api"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/PuerkitoBio/goquery"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Package-level state shared by the websocket handler and the command API.
var (
	// responseChannels hands an agent's next websocket message to the
	// executeCommand call that is waiting for it.
	// Key: agent name (string); value: chan string.
	responseChannels sync.Map

	// db is a package-level database handle. NOTE(review): it is not
	// referenced in this part of the file; presumably it is initialised and
	// used elsewhere (the MySQL driver is imported for side effects) — confirm.
	db *sql.DB
)
|
|
|
|
type webSocketHandler struct {
|
|
upgrader websocket.Upgrader
|
|
}
|
|
|
|
var agentSockets = make(map[string]*websocket.Conn)
|
|
var agentSocketsMutex sync.Mutex
|
|
|
|
var getAgentNames http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
|
|
agentSocketsMutex.Lock()
|
|
agentNames := make([]string, 0, len(agentSockets))
|
|
for agentName := range agentSockets {
|
|
agentNames = append(agentNames, agentName)
|
|
}
|
|
agentSocketsMutex.Unlock()
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
|
json.NewEncoder(w).Encode(agentNames)
|
|
|
|
}
|
|
|
|
func registerAgent(agentName, agentId, agentIp, agentType, addPort, hostname string) error {
|
|
|
|
registerURL := "http://localhost:3333/agents"
|
|
|
|
form := url.Values{}
|
|
form.Add("agentId", agentId)
|
|
form.Add("agentName", agentName)
|
|
form.Add("agentType", agentType)
|
|
form.Add("IPv4Address", agentIp)
|
|
form.Add("addPort", addPort)
|
|
form.Add("hostname", hostname)
|
|
|
|
resp, err := http.PostForm(registerURL, form)
|
|
if err != nil {
|
|
log.Printf("Error registering agent: %v", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Error registering agent: %v", err))
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusCreated {
|
|
log.Printf("Agent %s successfully registered.", agentName)
|
|
logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s successfully registered.", agentName))
|
|
return nil
|
|
} else if resp.StatusCode == http.StatusOK {
|
|
log.Printf("Agent %s already registered.", agentName)
|
|
logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s already registered.", agentName))
|
|
return nil
|
|
} else {
|
|
log.Printf("Failed to register agent, status: %v", resp.Status)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to register agent, status: %v", resp.Status))
|
|
return err
|
|
}
|
|
|
|
}
|
|
|
|
func getAgentDetails(agentId string) (*api.Agent, error) {
|
|
agentURL := "http://localhost:3333/agents/" + agentId
|
|
|
|
resp, err := http.Get(agentURL)
|
|
if err != nil {
|
|
log.Printf("Failed to make GET request: %s", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err))
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
doc, err := goquery.NewDocumentFromReader(resp.Body)
|
|
if err != nil {
|
|
log.Printf("Failed to parse HTML: %s", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse HTML: %s", err))
|
|
return nil, err
|
|
}
|
|
|
|
agent := &api.Agent{}
|
|
doc.Find("#agent-detail p").Each(func(i int, s *goquery.Selection) {
|
|
text := s.Text()
|
|
if strings.HasPrefix(text, "ID:") {
|
|
agent.AgentID, err = strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(text, "ID:")))
|
|
if err != nil {
|
|
log.Printf("Converting string to integer failed in getAgentDetails(): %s", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Converting string to integer failed in getAgentDetails(): %s", err))
|
|
}
|
|
} else if strings.HasPrefix(text, "Name:") {
|
|
agent.AgentName = strings.TrimSpace(strings.TrimPrefix(text, "Name:"))
|
|
} else if strings.HasPrefix(text, "Type:") {
|
|
agent.AgentType = strings.TrimSpace(strings.TrimPrefix(text, "Type:"))
|
|
}
|
|
})
|
|
return agent, nil
|
|
}
|
|
|
|
// containsId reports whether agentId is an element of ids, using exact,
// case-sensitive string equality. A nil or empty slice contains nothing.
func containsId(ids []string, agentId string) bool {
	found := false
	for i := 0; i < len(ids) && !found; i++ {
		found = ids[i] == agentId
	}
	return found
}
|
|
|
|
func getAgentIds() ([]string, error) {
|
|
idURL := "http://localhost:3333/agentIds"
|
|
|
|
resp, err := http.Get(idURL)
|
|
if err != nil {
|
|
log.Printf("Failed to make GET request: %s", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err))
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
log.Printf("Unexpected status code: %d", resp.StatusCode)
|
|
logger.InsertLog(logger.Info, fmt.Sprintf("Unexpected status code: %d", resp.StatusCode))
|
|
return nil, nil
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
log.Printf("Failed to read response body: %s", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to read response body: %s", err))
|
|
return nil, err
|
|
}
|
|
|
|
var agentIds []string
|
|
if err := json.Unmarshal(body, &agentIds); err != nil {
|
|
log.Printf("Failed to parse JSON response: %s", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse JSON response: %s", err))
|
|
return nil, err
|
|
}
|
|
|
|
return agentIds, nil
|
|
}
|
|
|
|
|
|
func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){
|
|
c, err := wsh.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("Error %s when upgrading connection to websocket", err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Error %s when upgrading connection to websocket", err))
|
|
return
|
|
}
|
|
|
|
agentIP := r.URL.Query().Get("IPv4Address")
|
|
agentId := r.URL.Query().Get("agentId")
|
|
agentType := r.URL.Query().Get("agentType")
|
|
hostName := r.URL.Query().Get("hostname")
|
|
agentName := ""
|
|
|
|
addPort := ""
|
|
if len(r.URL.Query().Get("addPort")) > 0 {
|
|
addPort = r.URL.Query().Get("addPort")
|
|
} else {
|
|
addPort = "None"
|
|
}
|
|
|
|
agentIds, err := getAgentIds()
|
|
if err != nil {
|
|
log.Printf("Error %v\n", err)
|
|
return
|
|
}
|
|
|
|
if !containsId(agentIds, agentId) {
|
|
agentName = randomname.GenerateRandomName()
|
|
registerAgent(agentName, agentId, agentIP, agentType, addPort, hostName)
|
|
} else {
|
|
agentDetails, _ := getAgentDetails(agentId)
|
|
agentName = agentDetails.AgentName
|
|
}
|
|
|
|
if agentName == "" || agentIP == "" {
|
|
log.Printf("Missing agentName or IPv4Address in query parameters")
|
|
logger.InsertLog(logger.Info, fmt.Sprintf("Missing agentName or IPv4Address in query parameters"))
|
|
c.Close()
|
|
return
|
|
}
|
|
|
|
log.Printf("Agent %s connected: %s (%s)", agentId, agentName, agentIP)
|
|
logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s connected: %s (%s)", agentId, agentName, agentIP))
|
|
|
|
agentSocketsMutex.Lock()
|
|
agentSockets[agentName] = c
|
|
agentSocketsMutex.Unlock()
|
|
|
|
defer func() {
|
|
agentSocketsMutex.Lock()
|
|
delete(agentSockets, agentName)
|
|
agentSocketsMutex.Unlock()
|
|
c.Close()
|
|
log.Printf("Agent disconnected: %s (%s)", agentName, agentIP)
|
|
logger.InsertLog(logger.Info, fmt.Sprintf("Agent disconnected: %s (%s)", agentName, agentIP))
|
|
}()
|
|
|
|
for {
|
|
_, message, err := c.ReadMessage()
|
|
if err != nil {
|
|
log.Printf("Error reading from agent %s: %v", agentName, err)
|
|
logger.InsertLog(logger.Error, fmt.Sprintf("Error reading from agent %s: %v", agentName, err))
|
|
break
|
|
}
|
|
log.Printf("Message from agent %s: %s", agentName, message)
|
|
logger.InsertLog(logger.Debug, fmt.Sprintf("Message from agent %s: %s", agentName, message))
|
|
|
|
if ch, ok := responseChannels.Load(agentName); ok {
|
|
responseChan := ch.(chan string)
|
|
responseChan <- string(message)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Message is the JSON envelope written to an agent's websocket, for example
// {"type":"command","payload":"<command text>"}.
// encoding/json emits fields in declaration order, so keep Type before Payload.
type Message struct {
	// Type names the kind of message; executeCommand sends "command".
	Type string `json:"type"`
	// Payload carries the message body, e.g. the command text.
	Payload string `json:"payload"`
}
|
|
|
|
var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Request){
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
http.Error(w, "Invalid form data", http.StatusBadRequest)
|
|
logger.InsertLog(logger.Info, "Invalid form data")
|
|
return
|
|
}
|
|
|
|
agentNameStr := r.FormValue("agentNames")
|
|
var agentNames []string
|
|
|
|
if agentNameStr != "" {
|
|
agentNames = strings.Split(agentNameStr, ",")
|
|
} else {
|
|
agentName := r.FormValue("agentName")
|
|
if agentName != "" {
|
|
agentNames = []string{agentName}
|
|
}
|
|
}
|
|
|
|
command := r.FormValue("command")
|
|
|
|
if len(agentNames) == 0 || command == "" {
|
|
http.Error(w, "Missing agent or command", http.StatusBadRequest)
|
|
logger.InsertLog(logger.Error, "Missing agent or command")
|
|
return
|
|
}
|
|
|
|
type result struct {
|
|
AgentName string
|
|
Type string
|
|
Payload string
|
|
Err error
|
|
}
|
|
|
|
resultsChan := make(chan result, len(agentNames))
|
|
|
|
for _, agentName := range agentNames {
|
|
agentName := strings.TrimSpace(agentName)
|
|
|
|
go func(agent string) {
|
|
agentSocketsMutex.Lock()
|
|
conn, ok := agentSockets[agentName]
|
|
agentSocketsMutex.Unlock()
|
|
|
|
if !ok {
|
|
resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Agent not connected")}
|
|
return
|
|
}
|
|
|
|
responseChan := make(chan string, 1)
|
|
responseChannels.Store(agent, responseChan)
|
|
defer responseChannels.Delete(agent)
|
|
|
|
msg := Message {
|
|
Type: "command",
|
|
Payload: command,
|
|
}
|
|
|
|
msgBytes, _ := json.Marshal(msg)
|
|
err := conn.WriteMessage(websocket.TextMessage, msgBytes)
|
|
if err != nil {
|
|
resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Send failed")}
|
|
return
|
|
}
|
|
|
|
select {
|
|
case resp := <- responseChan:
|
|
var parsed map[string]string
|
|
if err:= json.Unmarshal([]byte(resp), &parsed); err != nil {
|
|
resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Invalid response")}
|
|
return
|
|
}
|
|
|
|
payload, ok := parsed["payload"]
|
|
if !ok {
|
|
resultsChan <- result{AgentName: agent, Err: fmt.Errorf("No payload")}
|
|
return
|
|
}
|
|
|
|
resultsChan <- result{AgentName: agent, Payload: payload}
|
|
case <-time.After(10 * time.Second):
|
|
resultsChan <- result{AgentName: agent, Err: fmt.Errorf("Timeout")}
|
|
}
|
|
} (agentName)
|
|
}
|
|
|
|
var combined strings.Builder
|
|
for i := 0; i < len(agentNames); i++ {
|
|
res := <- resultsChan
|
|
if res.Err != nil {
|
|
combined.WriteString(fmt.Sprintf("\n[%s] ERROR: %s\n", res.AgentName, res.Err.Error()))
|
|
} else {
|
|
combined.WriteString(fmt.Sprintf("\n[%s]\n%s", res.AgentName, res.Payload))
|
|
}
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.Write([]byte(combined.String()))
|
|
|
|
|
|
}
|
|
|
|
|
|
func Server() (*http.Server) {
|
|
webSocketHandler := webSocketHandler {
|
|
upgrader: websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
},
|
|
}
|
|
|
|
corsMiddleware := func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*") // Allow the WebUI origin, this needs to be changed before prod
|
|
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
|
|
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, HX-Current-URL, HX-Request, HX-Target, HX-Trigger, HX-Trigger-Name")
|
|
if r.Method == "OPTIONS" {
|
|
// Handle preflight requests
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
webSocketMux := http.NewServeMux()
|
|
webSocketMux.Handle("/data", webSocketHandler)
|
|
webSocketMux.Handle("/executeCommand", corsMiddleware(http.HandlerFunc(executeCommand)))
|
|
webSocketMux.Handle("/agentNames", corsMiddleware(http.HandlerFunc(getAgentNames)))
|
|
websocketServer := &http.Server{
|
|
Addr: ":5555",
|
|
Handler: webSocketMux,
|
|
}
|
|
|
|
return websocketServer
|
|
}
|