diff --git a/agents/agent.go b/agents/agent.go index 54d7eb8..c9b919c 100644 --- a/agents/agent.go +++ b/agents/agent.go @@ -74,28 +74,63 @@ func registerAgent(agentName string, agentId string, agentIp string) error { } -func connectToWebSocket(agentName string, agentIp string) error { +func connectToWebSocket(agentName, agentIp string) error { wsURL := fmt.Sprintf("ws://%s/data?agentName=%s&IPv4Address=%s", webSocketAddr, url.QueryEscape(agentName), url.QueryEscape(agentIp)) var err error - conn, _, err = websocket.DefaultDialer.Dial(wsURL, nil) - if err != nil { - return fmt.Errorf("Failed to connect to WebSocket: %v", err) - } + for { + conn, _, err = websocket.DefaultDialer.Dial(wsURL, nil) + if err == nil { + log.Println("WeSocket connection established") + return nil + } - log.Println("WebSocket connection established") - return nil + log.Printf("Failed to connect to WebSocket: %v. Retrying in 5 seconds...", err) + time.Sleep(5 * time.Second) + } } -func listenForCommands() { +func reconnectToWebSocket(agentName, agentIp string) error { + backoff := 2 * time.Second + maxBackoff := 1 * time.Minute + + for { + log.Println("Attempting to reconnect to WebSocket...") + err := connectToWebSocket(agentName, agentIp) + if err == nil { + log.Println("Reconnection succesful.") + return nil + } + + log.Printf("Reconnection failed: %v", err) + + time.Sleep(backoff) + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } +} + +func listenForCommands(agentName, agentIp string) { defer conn.Close() for { _, rawMessage, err := conn.ReadMessage() if err != nil { - log.Printf("Error reading message: %v", err) - break + log.Printf("Connection lost: %v", err) + if reconnectErr := reconnectToWebSocket(agentName, agentIp); reconnectErr != nil { + log.Printf("Critical error during reconnection: %v", reconnectErr) + } + continue } + // for { + // _, rawMessage, err := conn.ReadMessage() + // if err != nil { + // log.Printf("Error reading message: %v", err) + // break + // } + var message Message if err := json.Unmarshal(rawMessage, &message); err != nil { log.Printf("Error unmarshalling message: %v", err) @@ -154,5 +189,5 @@ func main() { log.Fatalf("Websocket connection failed: %v", err) } - listenForCommands() + listenForCommands(agentName, agentIp) }