diff --git a/agents/agent.go b/agents/agent.go index 4a2b8cb..4e5cbf0 100644 --- a/agents/agent.go +++ b/agents/agent.go @@ -35,6 +35,7 @@ type Agent struct { type Message struct { Type string `json:"type"` + Level string `json:"level"` Payload string `json:"payload"` } @@ -84,6 +85,12 @@ func reconnectToWebSocket(agentName, agentId, agentIp, agentType, hostName strin backoff := 2 * time.Second maxBackoff := 1 * time.Minute + err := sendLog(conn) + + if err != nil { + log.Println("sendLog() error") + } + for { log.Println("Attempting to reconnect to WebSocket...") err := connectToWebSocket(agentName, agentId, agentIp, agentType, hostName) @@ -102,11 +109,28 @@ func reconnectToWebSocket(agentName, agentId, agentIp, agentType, hostName strin } } +func sendLog(conn *websocket.Conn) error { + + response := Message{ + Type: "response", + Level: "Fatal", + Payload: "Remote logging is working!", + } + responseBytes, _ := json.Marshal(response) + if err := conn.WriteMessage(websocket.TextMessage, responseBytes); err != nil { + log.Printf("Error sending output: %v", err) + return err + } + + return nil +} + func listenForCommands(agentName, agentId, agentIp, agentType, hostName string) { defer conn.Close() for { _, rawMessage, err := conn.ReadMessage() + if err != nil { log.Printf("Connection lost: %v", err) if reconnectErr := reconnectToWebSocket(agentName, agentId, agentIp, agentType, hostName); reconnectErr != nil { @@ -203,5 +227,10 @@ func main() { log.Fatalf("Websocket connection failed: %v", err) } + err := sendLog(conn) + if err != nil { + log.Println("sendLog() error") + } + listenForCommands(agentName, agentId, agentIp, agentType, hostName) } diff --git a/main.go b/main.go index 486bb26..ef7ac67 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,9 @@ import ( "gontrol/src/logger" "gontrol/src/server/database" "gontrol/src/server/webapp" - "gontrol/src/server/websocket" + websocketserver "gontrol/src/server/websocket" + + // _ "net/http/pprof" ) func main() { @@ -24,12 +26,13 @@ func main() { db := database.InitSQLiteDB("/tmp/gontrol_agents.db") defer db.Close() - if err := logger.InitDB("/tmp/gontrol_logs.db"); err != nil { + logService, err := logger.Init("/tmp/gontrol_logs.db") + if err != nil { log.Fatalf("Init log db: %v", err) } - defer logger.CloseDB() + defer logService.Close() - app := &webapp.App{Tmpl: tmpl, DB: db} + app := &webapp.App{Tmpl: tmpl, DB: db, Logger: logService} srv := &http.Server { Addr: ":3333", @@ -38,6 +41,10 @@ func main() { WriteTimeout: 10 * time.Second, } + // go func () { + // log.Println(http.ListenAndServe("localhost:6060", nil)) + // }() + go func() { log.Println("Web server is running on port :3333") if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { @@ -48,14 +55,12 @@ func main() { go func() { logLine := "Websocket server is running on port :5555" log.Println(logLine) + websocketserver.SetLogger(logService) websocketServer := websocketserver.Server() if err := websocketServer.ListenAndServe(); err != http.ErrServerClosed { log.Fatalf("Websocket server failed: %s", err) } - err := logger.InsertLog(logger.Info, logLine) - if err != nil { - log.Println("Error inserting log:", err) - } + app.Logger.Log(logger.Info, logLine) }() stop := make(chan os.Signal, 1) diff --git a/src/logger/logger.go b/src/logger/logger.go index 3988602..abf8a79 100644 --- a/src/logger/logger.go +++ b/src/logger/logger.go @@ -3,8 +3,6 @@ package logger import ( "database/sql" "fmt" - "log" - // "net/http" "strings" "sync" "time" @@ -12,14 +10,16 @@ import ( _ "github.com/mattn/go-sqlite3" ) - -var ( - Lite_db *sql.DB - lite_dbMutex sync.Mutex - - logMutex sync.Mutex - logLimit = 100 -) +type LoggerInterface interface { + Log(level LogLevel, message string) + FetchLogs(limit int, levels []string) ([]LogEntry, error) + Error(format string, args ...any) + Info(format string, args ...any) + Warning(format string, args ...any) + Debug(format string, args ...any) + Fatal(format string, args ...any) + Close() error +} const ( Debug LogLevel = "debug" @@ -37,16 +37,20 @@ type LogEntry struct { Level LogLevel } -func ToLog(logLine string) string { - log := fmt.Sprintf("%s",time.Now().Format(time.RFC3339) + " " + logLine) - return log +type Logger struct { + db *sql.DB + lock sync.Mutex } -func InitDB(dbPath string) error { - var err error - Lite_db, err = sql.Open("sqlite3", dbPath) +func New(db *sql.DB) *Logger { + return &Logger{db: db} +} + + +func Init(dbPath string) (*Logger, error) { + db, err := sql.Open("sqlite3", dbPath) if err != nil { - return fmt.Errorf("Error opening DB: %w", err) + return nil, fmt.Errorf("Error opening DB: %w", err) } CreateTableQuery := `CREATE TABLE IF NOT EXISTS logs ( @@ -56,58 +60,55 @@ func InitDB(dbPath string) error { level TEXT );` - _, err = Lite_db.Exec(CreateTableQuery) - if err != nil { - return fmt.Errorf("Error creating table: %w", err) + if _, err = db.Exec(CreateTableQuery); err != nil { + return nil, err } - return nil + return New(db), nil } -func InsertLog(level LogLevel, message string) error { - lite_dbMutex.Lock() - defer lite_dbMutex.Unlock() +func (l *Logger) InsertLog(level LogLevel, message string) error { + l.lock.Lock() + defer l.lock.Unlock() - // Future use may fulfill multiple transactions - tx, err := Lite_db.Begin() + tx, err := l.db.Begin() if err != nil { - return fmt.Errorf("Error starting transaction: %w", err) + return fmt.Errorf("InsertLog: start tx: %w", err) } + insertQuery := `INSERT INTO logs (message, level) VALUES(?, ?)` message = strings.ReplaceAll(message, `"`, `\"`) - insertQuery := `INSERT INTO logs (message, level) VALUES (?, ?)` _, err = tx.Exec(insertQuery, message, level) - if err != nil { tx.Rollback() - return fmt.Errorf("Error inserting log: %v", err) + return fmt.Errorf("InsertLog: exec: %w", err) } - err = tx.Commit() - if err != nil { - return fmt.Errorf("Error committing transaction: %w", err) + if err := tx.Commit(); err != nil { + return fmt.Errorf("InsertLog: commit: %w", err) } return nil } -func FetchLogs(limit int, levels []string) ([]LogEntry, error) { - lite_dbMutex.Lock() - defer lite_dbMutex.Unlock() +func (l *Logger) FetchLogs(limit int, levels []string) ([]LogEntry, error) { + l.lock.Lock() + defer l.lock.Unlock() if len(levels) == 0 { levels = []string{"%"} } - var args []interface{} + var args[] interface {} placeholders := make([]string, len(levels)) - - for i, level := range levels { + for i, level := range levels{ placeholders[i] = "level LIKE ?" args = append(args, level) } + args = append(args, limit) + query := fmt.Sprintf(` SELECT timestamp, level, message FROM logs @@ -115,31 +116,54 @@ func FetchLogs(limit int, levels []string) ([]LogEntry, error) { ORDER BY timestamp DESC LIMIT ?`, strings.Join(placeholders, " OR ")) - args = append(args, limit) - - rows, err := Lite_db.Query(query, args...) + rows, err := l.db.Query(query, args...) if err != nil { - return nil, fmt.Errorf("Error fetching logs: %w", err) + return nil, fmt.Errorf("FetchLogs: query: %w", err) } defer rows.Close() var logs []LogEntry for rows.Next() { - var logEntry LogEntry - if err := rows.Scan( &logEntry.Timestamp, &logEntry.Level, &logEntry.Message); err != nil { - return nil, fmt.Errorf("Error scanning row: %w", err) + var entry LogEntry + if err := rows.Scan(&entry.Timestamp, &entry.Level, &entry.Message); err != nil { + return nil, fmt.Errorf("FetchLogs: scan: %w", err) } - logs = append(logs, logEntry) + logs = append(logs, entry) } return logs, nil } -func CloseDB() { - if Lite_db != nil { - err := Lite_db.Close() - if err != nil { - log.Printf("Error closing database: %v", err) - } - } + +func (l * Logger) Close() error { + return l.db.Close() +} + +func ToLog(logLine string) string { + log := fmt.Sprintf("%s",time.Now().Format(time.RFC3339) + " " + logLine) + return log +} + +func (l *Logger) Log(level LogLevel, message string) { + _ = l.InsertLog(level, message) +} + +func (l *Logger) Error(format string, args ...any) { + l.Log(Error, fmt.Sprintf(format, args...)) +} + +func (l *Logger) Info(format string, args ...any) { + l.Log(Info, fmt.Sprintf(format, args...)) +} + +func (l *Logger) Warning(format string, args ...any) { + l.Log(Warning, fmt.Sprintf(format, args...)) +} + +func (l *Logger) Debug(format string, args ...any) { + l.Log(Debug, fmt.Sprintf(format, args...)) +} + +func (l *Logger) Fatal(format string, args ...any) { + l.Log(Fatal, fmt.Sprintf(format, args...)) } diff --git a/src/server/webapp/handlers.go b/src/server/webapp/handlers.go index 3d6dd7e..935407b 100644 --- a/src/server/webapp/handlers.go +++ b/src/server/webapp/handlers.go @@ -25,6 +25,8 @@ import ( type App struct { Tmpl *template.Template DB *sql.DB + // Logger *logger.Logger + Logger logger.LoggerInterface } func (a *App) renderTemplate(w http.ResponseWriter, name string, data any) { @@ -139,7 +141,7 @@ func (a *App) logsHandler(w http.ResponseWriter, r *http.Request) { } // Call the police... I mean logger - logs, err := logger.FetchLogs(limit, levels) + logs, err := a.Logger.FetchLogs(limit, levels) if err != nil { http.Error(w, "Error fetching logs", http.StatusInternalServerError) @@ -165,7 +167,7 @@ func (a *App) getAgentsStatus() []string { resp, err := http.Get("http://localhost:5555/agentNames") if err != nil { log.Println("Error fetching agent names:", err) - logger.InsertLog(logger.Error, "Error fetching agent names from websocketServer") + a.Logger.Log(logger.Error, "Error fetching agent names from websocketServer") } defer resp.Body.Close() diff --git a/src/server/webapp/handlers_test.go b/src/server/webapp/handlers_test.go index f8036bc..100acc8 100644 --- a/src/server/webapp/handlers_test.go +++ b/src/server/webapp/handlers_test.go @@ -1,116 +1,106 @@ package webapp import ( - "net/http" - "net/http/httptest" - "os" - "testing" - "strings" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" - "github.com/DATA-DOG/go-sqlmock" - "gontrol/src/logger" + "gontrol/src/logger" ) -func newTestApp(t *testing.T) (*App, sqlmock.Sqlmock) { - t.Helper() +func newTestAppWithSQLiteLogger(t *testing.T) *App { + t.Helper() - tmpl, err := ParseTemplates() - if err != nil { - t.Fatalf("ParseTemplates: %v", err) - } + tmpl, err := ParseTemplates() + if err != nil { + t.Fatalf("ParseTemplates: %v", err) + } - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("sqlmock: %v", err) - } - return &App{Tmpl: tmpl, DB: db}, mock -} - -// initFakeLogs puts one record in the in‑memory log DB used by logger.FetchLogs. -func initFakeLogs(t *testing.T) { - t.Helper() - if err := logger.InsertLog(logger.Info, "unit‑test"); err != nil { - t.Fatalf("InsertLog: %v", err) - } + tmp, err := os.CreateTemp("", "logs_*.db") + if err != nil { + t.Fatalf("Temp DB: %v", err) + } + tmp.Close() + t.Cleanup(func() { + os.Remove(tmp.Name()) + }) + + logInstance, err := logger.Init(tmp.Name()) + if err != nil { + t.Fatalf("logger.Init: %v", err) + } + t.Cleanup(func() { + logInstance.Close() + }) + + return &App{ + Tmpl: tmpl, + Logger: logInstance, + } } +// Test that the logsHandler returns HTML with logs correctly func TestLogsHandler_HTML(t *testing.T) { - tmp, err := os.CreateTemp("", "logs_*.db") - if err != nil { - t.Fatalf("Temp DB: %v", err) - } - tmp.Close() - defer os.Remove(tmp.Name()) + app := newTestAppWithSQLiteLogger(t) - if err := logger.InitDB(tmp.Name()); err != nil { - t.Fatalf("logger.InitDB: %v", err) - } - defer logger.CloseDB() + // Insert a log entry to be fetched + app.Logger.Error("fake-log-1") - logger.InsertLog(logger.Error, "fake-log-1") + req := httptest.NewRequest(http.MethodGet, "/logs/error?limit=1", nil) + rec := httptest.NewRecorder() - app, mock := newTestApp(t) - defer app.DB.Close() + app.logsHandler(rec, req) - if err := mock.ExpectationsWereMet(); err != nil { - t.Fatalf("sqlmock expectations: %v", err) - } - - - req := httptest.NewRequest(http.MethodGet, "/logs/error?limit=1", nil) - rec := httptest.NewRecorder() - - app.logsHandler(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("status = %d; want 200", rec.Code) - } - if ct := rec.Header().Get("Content-Type"); ct != "text/html; charset=utf-8" { - t.Errorf("Content-Type = %s; want text/html", ct) - } - if !strings.Contains(rec.Body.String(), "fake-log-1") { - t.Errorf("rendered HTML missing our fake log line") - } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 200", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "text/html; charset=utf-8" { + t.Errorf("Content-Type = %s; want text/html", ct) + } + if !strings.Contains(rec.Body.String(), "fake-log-1") { + t.Errorf("rendered HTML missing our fake log line") + } } -// Test that `limit=abc` is rejected with 400 and FetchLogs is **not** hit. +// Test that invalid 'limit' query param returns 400 Bad Request func TestLogsHandler_InvalidLimit(t *testing.T) { - initFakeLogs(t) + app := newTestAppWithSQLiteLogger(t) - app, _ := newTestApp(t) // helper from previous file + req := httptest.NewRequest(http.MethodGet, "/logs?limit=abc", nil) + rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/logs?limit=abc", nil) - rec := httptest.NewRecorder() + app.logsHandler(rec, req) - app.logsHandler(rec, req) - - if rec.Code != http.StatusBadRequest { - t.Fatalf("status = %d; want 400", rec.Code) - } - if !strings.Contains(rec.Body.String(), "Invalid count value") { - t.Errorf("missing error message") - } + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d; want 400", rec.Code) + } + if !strings.Contains(rec.Body.String(), "Invalid count value") { + t.Errorf("missing error message") + } } -// Test JSON response when client sends Accept: application/json. +// Test JSON output when client sends Accept: application/json func TestLogsHandler_JSON(t *testing.T) { - initFakeLogs(t) + app := newTestAppWithSQLiteLogger(t) - app, _ := newTestApp(t) + // Insert a log entry to be fetched + app.Logger.Info("unit-test-json") - req := httptest.NewRequest(http.MethodGet, "/logs?limit=1", nil) - req.Header.Set("Accept", "application/json") - rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?limit=1", nil) + req.Header.Set("Accept", "application/json") + rec := httptest.NewRecorder() - app.logsHandler(rec, req) + app.logsHandler(rec, req) - if rec.Code != http.StatusOK { - t.Fatalf("status = %d; want 200", rec.Code) - } - if ct := rec.Header().Get("Content-Type"); ct != "application/json" { - t.Errorf("Content‑Type = %q; want application/json", ct) - } - if !strings.Contains(rec.Body.String(), "unit‑test") { - t.Errorf("JSON body missing expected log entry") - } + if rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 200", rec.Code) + } + if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q; want application/json", ct) + } + if !strings.Contains(rec.Body.String(), "unit-test-json") { + t.Errorf("JSON body missing expected log entry") + } } diff --git a/src/server/webapp/static_files_test.go b/src/server/webapp/static_files_test.go index 0389f6c..1d923ab 100644 --- a/src/server/webapp/static_files_test.go +++ b/src/server/webapp/static_files_test.go @@ -1,54 +1,54 @@ -// src/server/webapp/static_files_test.go package webapp import ( - "io" - "io/fs" - "net/http" - "net/http/httptest" - "testing" + "io" + "io/fs" + "net/http" + "net/http/httptest" + "testing" ) // findFirstStaticFile walks the embedded FS and returns the first file path. func findFirstStaticFile() (webPath string, content []byte, ok bool) { - _ = fs.WalkDir(assets, "static", func(path string, d fs.DirEntry, err error) error { - if !d.IsDir() && ok == false { - data, _ := assets.ReadFile(path) // ignore err; test will skip if nil - webPath = "/" + path // e.g. /static/css/main.css - content = data - ok = true - } - return nil - }) - return + _ = fs.WalkDir(assets, "static", func(path string, d fs.DirEntry, err error) error { + if !d.IsDir() && !ok { + data, _ := assets.ReadFile(path) // ignore err; test will skip if nil + webPath = "/" + path // e.g. /static/css/main.css + content = data + ok = true + } + return nil + }) + return } // Requires at least one file under static/. Skips if none embedded. func TestStaticFileServer(t *testing.T) { - webPath, wantBytes, ok := findFirstStaticFile() - if !ok { - t.Skip("no embedded static files to test") - } + webPath, wantBytes, ok := findFirstStaticFile() + if !ok { + t.Skip("no embedded static files to test") + } - //----------------------------------------------------------------- - // build router with sqlmock DB (not used in this test) - //----------------------------------------------------------------- - app, _ := newTestApp(t) - ts := httptest.NewServer(BuildRouter(app)) - defer ts.Close() + //----------------------------------------------------------------- + // build router with SQLite Logger (real DB) for test + //----------------------------------------------------------------- + app := newTestAppWithSQLiteLogger(t) - res, err := http.Get(ts.URL + webPath) - if err != nil { - t.Fatalf("GET %s: %v", webPath, err) - } - defer res.Body.Close() + ts := httptest.NewServer(BuildRouter(app)) + defer ts.Close() - if res.StatusCode != http.StatusOK { - t.Fatalf("status = %d; want 200", res.StatusCode) - } + res, err := http.Get(ts.URL + webPath) + if err != nil { + t.Fatalf("GET %s: %v", webPath, err) + } + defer res.Body.Close() - gotBytes, _ := io.ReadAll(res.Body) - if len(gotBytes) == 0 || string(gotBytes) != string(wantBytes) { - t.Errorf("served file differs from embedded asset") - } + if res.StatusCode != http.StatusOK { + t.Fatalf("status = %d; want 200", res.StatusCode) + } + + gotBytes, _ := io.ReadAll(res.Body) + if len(gotBytes) == 0 || string(gotBytes) != string(wantBytes) { + t.Errorf("served file differs from embedded asset") + } } diff --git a/src/server/websocket/websocketServer.go b/src/server/websocket/websocketServer.go index f5d6e57..5a75b49 100644 --- a/src/server/websocket/websocketServer.go +++ b/src/server/websocket/websocketServer.go @@ -1,7 +1,6 @@ package websocketserver import ( - "database/sql" "encoding/json" "fmt" "gontrol/src/logger" @@ -22,12 +21,16 @@ import ( ) var responseChannels sync.Map // Key: agentName, Value: chan string -var db *sql.DB +var logService logger.LoggerInterface type webSocketHandler struct { upgrader websocket.Upgrader } +func SetLogger(logger logger.LoggerInterface) { + logService = logger +} + var agentSockets = make(map[string]*websocket.Conn) var agentSocketsMutex sync.Mutex @@ -62,22 +65,26 @@ func registerAgent(agentName, agentId, agentIp, agentType, addPort, hostname str resp, err := http.PostForm(registerURL, form) if err != nil { log.Printf("Error registering agent: %v", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Error registering agent: %v", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Error registering agent: %v", err)) + logService.Info(fmt.Sprintf("Error registering agent: %v", err)) return err } defer resp.Body.Close() if resp.StatusCode == http.StatusCreated { log.Printf("Agent %s successfully registered.", agentName) - logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s successfully registered.", agentName)) + // logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s successfully registered.", agentName)) + logService.Info(fmt.Sprintf("Agent %s successfully registered.", agentName)) return nil } else if resp.StatusCode == http.StatusOK { log.Printf("Agent %s already registered.", agentName) - logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s already registered.", agentName)) + // logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s already registered.", agentName)) + logService.Info(fmt.Sprintf("Agent %s already registered.", agentName)) return nil } else { log.Printf("Failed to register agent, status: %v", resp.Status) - logger.InsertLog(logger.Error, fmt.Sprintf("Failed to register agent, status: %v", resp.Status)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Failed to register agent, status: %v", resp.Status)) + logService.Error(fmt.Sprintf("Failer to register agent, status: %v", resp.Status)) return err } @@ -89,7 +96,8 @@ func getAgentDetails(agentId string) (*api.Agent, error) { resp, err := http.Get(agentURL) if err != nil { log.Printf("Failed to make GET request: %s", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err)) + logService.Error(fmt.Sprintf("Failed to make GET request: %s", err)) return nil, err } defer resp.Body.Close() @@ -97,7 +105,8 @@ func getAgentDetails(agentId string) (*api.Agent, error) { doc, err := goquery.NewDocumentFromReader(resp.Body) if err != nil { log.Printf("Failed to parse HTML: %s", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse HTML: %s", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse HTML: %s", err)) + logService.Error(fmt.Sprintf("Failed to parse HTML: %s", err)) return nil, err } @@ -108,7 +117,8 @@ func getAgentDetails(agentId string) (*api.Agent, error) { agent.AgentID, err = strconv.Atoi(strings.TrimSpace(strings.TrimPrefix(text, "ID:"))) if err != nil { log.Printf("Converting string to integer failed in getAgentDetails(): %s", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Converting string to integer failed in getAgentDetails(): %s", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Converting string to integer failed in getAgentDetails(): %s", err)) + logService.Error(fmt.Sprintf("Converting string to integer failed in getAgentDetails(): %s", err)) } } else if strings.HasPrefix(text, "Name:") { agent.AgentName = strings.TrimSpace(strings.TrimPrefix(text, "Name:")) @@ -134,28 +144,32 @@ func getAgentIds() ([]string, error) { resp, err := http.Get(idURL) if err != nil { log.Printf("Failed to make GET request: %s", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Failed to make GET request: %s", err)) + logService.Error(fmt.Sprintf("Failed to make GET request: %s", err)) return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { log.Printf("Unexpected status code: %d", resp.StatusCode) - logger.InsertLog(logger.Info, fmt.Sprintf("Unexpected status code: %d", resp.StatusCode)) + // logger.InsertLog(logger.Info, fmt.Sprintf("Unexpected status code: %d", resp.StatusCode)) + logService.Info(fmt.Sprintf("Unexpected status code: %d", resp.StatusCode)) return nil, nil } body, err := io.ReadAll(resp.Body) if err != nil { log.Printf("Failed to read response body: %s", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Failed to read response body: %s", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Failed to read response body: %s", err)) + logService.Error(fmt.Sprintf("Failed to read response body: %s", err)) return nil, err } var agentIds []string if err := json.Unmarshal(body, &agentIds); err != nil { log.Printf("Failed to parse JSON response: %s", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse JSON response: %s", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Failed to parse JSON response: %s", err)) + logService.Error(fmt.Sprintf("Failed to parse JSON response: %s", err)) return nil, err } @@ -167,7 +181,8 @@ func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){ c, err := wsh.upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("Error %s when upgrading connection to websocket", err) - logger.InsertLog(logger.Error, fmt.Sprintf("Error %s when upgrading connection to websocket", err)) + // logger.InsertLog(logger.Error, fmt.Sprintf("Error %s when upgrading connection to websocket", err)) + logService.Error(fmt.Sprintf("Error %s when upgrading connection to websocket", err)) return } @@ -200,13 +215,15 @@ func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){ if agentName == "" || agentIP == "" { log.Printf("Missing agentName or IPv4Address in query parameters") - logger.InsertLog(logger.Info, fmt.Sprintf("Missing agentName or IPv4Address in query parameters")) + // logger.InsertLog(logger.Info, fmt.Sprintf("Missing agentName or IPv4Address in query parameters")) + logService.Info(fmt.Sprintf("Missing agentName or IPv4Address in query parameters")) c.Close() return } log.Printf("Agent %s connected: %s (%s)", agentId, agentName, agentIP) - logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s connected: %s (%s)", agentId, agentName, agentIP)) + // logger.InsertLog(logger.Info, fmt.Sprintf("Agent %s connected: %s (%s)", agentId, agentName, agentIP)) + logService.Info(fmt.Sprintf("Agent %s connected: %s (%s)", agentId, agentName, agentIP)) agentSocketsMutex.Lock() agentSockets[agentName] = c @@ -218,19 +235,69 @@ func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){ agentSocketsMutex.Unlock() c.Close() log.Printf("Agent disconnected: %s (%s)", agentName, agentIP) - logger.InsertLog(logger.Info, fmt.Sprintf("Agent disconnected: %s (%s)", agentName, agentIP)) + // logger.InsertLog(logger.Info, fmt.Sprintf("Agent disconnected: %s (%s)", agentName, agentIP)) + logService.Info(fmt.Sprintf("Agent disconnected: %s (%s)", agentName, agentIP)) }() + // for { + // _, message, err := c.ReadMessage() + // if err != nil { + // log.Printf("Error reading from agent %s: %v", agentName, err) + // // logger.InsertLog(logger.Error, fmt.Sprintf("Error reading from agent %s: %v", agentName, err)) + // logService.Error(fmt.Sprintf("Error reading from agent %s: %v", agentName, err)) + // break + // } + // // log.Printf("Message from agent %s: %s", agentName, message) + // log.Printf("Message from agent %s received", agentName) + // // logger.InsertLog(logger.Debug, fmt.Sprintf("Message from agent %s: %s", agentName, message)) + // logService.Debug(fmt.Sprintf("Message from agent %s: %s", agentName, message)) + + // if ch, ok := responseChannels.Load(agentName); ok { + // responseChan := ch.(chan string) + // responseChan <- string(message) + // } + // } + for { _, message, err := c.ReadMessage() - if err != nil { + if err != nil { log.Printf("Error reading from agent %s: %v", agentName, err) - logger.InsertLog(logger.Error, fmt.Sprintf("Error reading from agent %s: %v", agentName, err)) + logService.Error(fmt.Sprintf("Agent disconnected: %s (%s)", agentName, agentIP)) break } - // log.Printf("Message from agent %s: %s", agentName, message) - log.Printf("Message from agent %s received", agentName) - logger.InsertLog(logger.Debug, fmt.Sprintf("Message from agent %s: %s", agentName, message)) + + var generic map [string]interface{} + if err := json.Unmarshal(message, &generic); err != nil { + logService.Error(fmt.Sprintf("Invalid JSON from %s: %v", agentName, err)) + continue + } + + // In case the message coming in is of type log, it gets switched into the corresponding level + if msgType, ok := generic["type"].(string); ok && msgType == "log" { + level := strings.ToLower(fmt.Sprintf("%v", generic["level"])) + content := fmt.Sprintf("%v", generic["payload"]) + + formatted := fmt.Sprintf("Log from %s: %s", agentName, content) + + switch level { + case "info": + logService.Info(formatted) + case "error": + logService.Error(formatted) + case "debug": + logService.Debug(formatted) + case "warn", "warning": + logService.Fatal(formatted) + case "fatal": + logService.Fatal(formatted) + default: + logService.Info(formatted) + } + continue + } + + // From here things get prepared for /executeCommand + logService.Debug(fmt.Sprintf("Message from agent %s: %s", agentName, message)) if ch, ok := responseChannels.Load(agentName); ok { responseChan := ch.(chan string) @@ -241,6 +308,7 @@ func (wsh webSocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request){ type Message struct { Type string `json:"type"` + Level string `json:"level"` Payload string `json:"payload"` } @@ -248,7 +316,8 @@ var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Reques err := r.ParseForm() if err != nil { http.Error(w, "Invalid form data", http.StatusBadRequest) - logger.InsertLog(logger.Info, "Invalid form data") + // logger.InsertLog(logger.Info, "Invalid form data") + logService.Info("Invalid form data") return } @@ -268,7 +337,8 @@ var executeCommand http.HandlerFunc = func(w http.ResponseWriter, r *http.Reques if len(agentNames) == 0 || command == "" { http.Error(w, "Missing agent or command", http.StatusBadRequest) - logger.InsertLog(logger.Error, "Missing agent or command") + // logger.InsertLog(logger.Error, "Missing agent or command") + logService.Error("Missing agent or command") return }