diff --git a/go.mod b/go.mod index 7788eef..858c740 100755 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module zlh-agent go 1.21.6 + +require github.com/creack/pty v1.1.21 + +require github.com/gorilla/websocket v1.5.1 + +require golang.org/x/net v0.17.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..7b8f512 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= +github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= diff --git a/internal/http/agent.go b/internal/http/agent.go index 992bd62..17ededa 100755 --- a/internal/http/agent.go +++ b/internal/http/agent.go @@ -8,9 +8,11 @@ import ( "net/http" "os" "path/filepath" + "strconv" "strings" "time" + mcstatus "zlh-agent/internal/minecraft" "zlh-agent/internal/provision" "zlh-agent/internal/provision/devcontainer" "zlh-agent/internal/provision/devcontainer/go" @@ -20,11 +22,16 @@ import ( "zlh-agent/internal/provision/minecraft" "zlh-agent/internal/state" "zlh-agent/internal/system" + "zlh-agent/internal/update" + "zlh-agent/internal/version" ) -/* -------------------------------------------------------------------------- - Helpers -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + Helpers + +---------------------------------------------------------------------------- +*/ func fileExists(path string) bool { _, err := os.Stat(path) return err == nil @@ -35,9 +42,12 @@ func dirExists(path string) bool { return err == nil && s.IsDir() } -/* -------------------------------------------------------------------------- - Shared provision pipeline (installer + Minecraft verify) -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + Shared provision pipeline (installer + Minecraft verify) + +---------------------------------------------------------------------------- +*/ func runProvisionPipeline(cfg *state.Config) error { state.SetState(state.StateInstalling) state.SetInstallStep("provision_all") @@ -61,44 +71,45 @@ func runProvisionPipeline(cfg *state.Config) error { return nil } -/* -------------------------------------------------------------------------- - ensureProvisioned() — idempotent, unified -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + ensureProvisioned() — idempotent, unified + +---------------------------------------------------------------------------- +*/ func ensureProvisioned(cfg *state.Config) error { -if cfg.ContainerType == "dev" { + if cfg.ContainerType == "dev" { - if !devcontainer.IsProvisioned() { - if err := runProvisionPipeline(cfg); err != nil { - return err - } - } + if !devcontainer.IsProvisioned() { + if err := runProvisionPipeline(cfg); err != nil { + return err + } + } - var err error + var err error - switch strings.ToLower(cfg.Runtime) { - case "node": - err = node.Verify(*cfg) - case "python": - err = python.Verify(*cfg) - case "go": - err = goenv.Verify(*cfg) - case "java": - err = java.Verify(*cfg) - default: - return fmt.Errorf("unsupported devcontainer runtime: %s", cfg.Runtime) - } + switch strings.ToLower(cfg.Runtime) { + case "node": + err = node.Verify(*cfg) + case "python": + err = python.Verify(*cfg) + case "go": + err = goenv.Verify(*cfg) + case "java": + err = java.Verify(*cfg) + default: + return fmt.Errorf("unsupported devcontainer runtime: %s", cfg.Runtime) + } - if err != nil { - return err - } - - // ✅ DEV READY = RUNNING - state.SetState(state.StateRunning) - state.SetError(nil) - return nil -} + if err != nil { + return err + } + state.SetState(state.StateIdle) + state.SetError(nil) + return nil + } dir := provision.ServerDir(*cfg) game := strings.ToLower(cfg.Game) @@ -131,9 +142,12 @@ if cfg.ContainerType == "dev" { return runProvisionPipeline(cfg) } -/* -------------------------------------------------------------------------- - /config — the REAL provisioning trigger (async) -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + /config — the REAL provisioning trigger (async) + +---------------------------------------------------------------------------- +*/ func handleConfig(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "POST only", http.StatusMethodNotAllowed) @@ -193,6 +207,23 @@ func handleConfig(w http.ResponseWriter, r *http.Request) { time.Sleep(2 * time.Second) } + // Wait for server.properties to exist before enforcing + propsPath := filepath.Join(provision.ServerDir(c), "server.properties") + propsDeadline := time.Now().Add(2 * time.Minute) + for { + if _, err := os.Stat(propsPath); err == nil { + break + } + if time.Now().After(propsDeadline) { + err := fmt.Errorf("forge server.properties not found before timeout") + log.Println("[agent] forge post-start error:", err) + state.SetError(err) + state.SetState(state.StateError) + return + } + time.Sleep(2 * time.Second) + } + _ = system.StopServer() if err := minecraft.EnforceForgeServerProperties(c); err != nil { @@ -219,9 +250,12 @@ func handleConfig(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"ok": true, "state": "installing"}`)) } -/* -------------------------------------------------------------------------- - /start -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + /start + +---------------------------------------------------------------------------- +*/ func handleStart(w http.ResponseWriter, r *http.Request) { cfg, err := state.LoadConfig() if err != nil { @@ -243,9 +277,12 @@ func handleStart(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"ok": true, "state": "starting"}`)) } -/* -------------------------------------------------------------------------- - /stop -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + /stop + +---------------------------------------------------------------------------- +*/ func handleStop(w http.ResponseWriter, r *http.Request) { if err := system.StopServer(); err != nil { http.Error(w, "stop error: "+err.Error(), http.StatusInternalServerError) @@ -254,9 +291,12 @@ func handleStop(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } -/* -------------------------------------------------------------------------- - /restart -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + /restart + +---------------------------------------------------------------------------- +*/ func handleRestart(w http.ResponseWriter, r *http.Request) { cfg, err := state.LoadConfig() if err != nil { @@ -280,9 +320,12 @@ func handleRestart(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(`{"ok": true, "state": "starting"}`)) } -/* -------------------------------------------------------------------------- - /status -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + /status + +---------------------------------------------------------------------------- +*/ func handleStatus(w http.ResponseWriter, r *http.Request) { cfg, _ := state.LoadConfig() @@ -303,9 +346,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { _ = json.NewEncoder(w).Encode(resp) } -/* -------------------------------------------------------------------------- - /console/command -----------------------------------------------------------------------------*/ +/* + -------------------------------------------------------------------------- + /console/command + +---------------------------------------------------------------------------- +*/ func handleSendCommand(w http.ResponseWriter, r *http.Request) { cmd := r.URL.Query().Get("cmd") if cmd == "" { @@ -322,8 +368,158 @@ func handleSendCommand(w http.ResponseWriter, r *http.Request) { } /* -------------------------------------------------------------------------- - Router + /agent/update ----------------------------------------------------------------------------*/ + +func handleAgentUpdate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "POST only", http.StatusMethodNotAllowed) + return + } + + res := update.CheckAndUpdate(version.AgentVersion) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(res) +} + +/* -------------------------------------------------------------------------- + /agent/update/status +----------------------------------------------------------------------------*/ + +func handleAgentUpdateStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "GET only", http.StatusMethodNotAllowed) + return + } + + res := update.ReadStatus() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(res) +} + +/* -------------------------------------------------------------------------- + /version +----------------------------------------------------------------------------*/ + +func handleVersion(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "GET only", http.StatusMethodNotAllowed) + return + } + resp := map[string]any{ + "version": version.AgentVersion, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + +/* -------------------------------------------------------------------------- + /game/players +----------------------------------------------------------------------------*/ + +func handleGamePlayers(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "GET only", http.StatusMethodNotAllowed) + return + } + + cfg, err := state.LoadConfig() + if err != nil { + http.Error(w, "no config loaded", http.StatusBadRequest) + return + } + if strings.ToLower(cfg.ContainerType) != "game" { + http.Error(w, "not a game container", http.StatusBadRequest) + return + } + if strings.ToLower(cfg.Game) != "minecraft" { + http.Error(w, "unsupported game", http.StatusNotImplemented) + return + } + + ports := make([]int, 0, 3) + propsPath := filepath.Join(provision.ServerDir(*cfg), "server.properties") + if b, err := os.ReadFile(propsPath); err == nil { + lines := strings.Split(string(b), "\n") + for _, l := range lines { + if strings.HasPrefix(l, "server-port=") { + if p, err := strconv.Atoi(strings.TrimPrefix(l, "server-port=")); err == nil && p > 0 { + ports = append(ports, p) + } + break + } + } + } + if len(cfg.Ports) > 0 && cfg.Ports[0] > 0 { + ports = append(ports, cfg.Ports[0]) + } + ports = append(ports, 25565) + + seenPorts := make(map[int]struct{}, len(ports)) + uniqPorts := make([]int, 0, len(ports)) + for _, p := range ports { + if _, ok := seenPorts[p]; ok { + continue + } + seenPorts[p] = struct{}{} + uniqPorts = append(uniqPorts, p) + } + + protocols := []int{mcstatus.ProtocolForVersion(cfg.Version), 767, 765, 763, 762, 754} + seenProtocols := make(map[int]struct{}, len(protocols)) + uniqProtocols := make([]int, 0, len(protocols)) + for _, pr := range protocols { + if _, ok := seenProtocols[pr]; ok { + continue + } + seenProtocols[pr] = struct{}{} + uniqProtocols = append(uniqProtocols, pr) + } + + var status mcstatus.StatusResponse + var lastErr error + for _, port := range uniqPorts { + for _, protocol := range uniqProtocols { + s, err := mcstatus.QueryStatus("127.0.0.1", port, protocol) + if err != nil { + lastErr = err + continue + } + status = s + lastErr = nil + break + } + if lastErr == nil { + break + } + } + if lastErr != nil { + http.Error(w, "status query failed: "+lastErr.Error(), http.StatusBadGateway) + return + } + + players := make([]string, 0, len(status.Players.Sample)) + for _, p := range status.Players.Sample { + if p.Name != "" { + players = append(players, p.Name) + } + } + + resp := map[string]any{ + "online": status.Players.Online, + "max": status.Players.Max, + "players": players, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) +} + +/* + -------------------------------------------------------------------------- + Router + +---------------------------------------------------------------------------- +*/ func NewMux() *http.ServeMux { m := http.NewServeMux() @@ -333,6 +529,10 @@ func NewMux() *http.ServeMux { m.HandleFunc("/restart", handleRestart) m.HandleFunc("/status", handleStatus) m.HandleFunc("/console/command", handleSendCommand) + m.HandleFunc("/agent/update", handleAgentUpdate) + m.HandleFunc("/agent/update/status", handleAgentUpdateStatus) + m.HandleFunc("/version", handleVersion) + m.HandleFunc("/game/players", handleGamePlayers) registerWebSocket(m) diff --git a/internal/http/console_sessions.go b/internal/http/console_sessions.go new file mode 100644 index 0000000..adf8ced --- /dev/null +++ b/internal/http/console_sessions.go @@ -0,0 +1,189 @@ +package agenthttp + +import ( + "fmt" + "io" + "log" + "os" + "sync" + "time" + + "github.com/gorilla/websocket" + + "zlh-agent/internal/runtime" + "zlh-agent/internal/state" + "zlh-agent/internal/system" +) + +const sessionTTL = 60 * time.Second + +type consoleConn struct { + send chan []byte +} + +type consoleSession struct { + key string + cfg *state.Config + ptyFile *os.File + createdAt time.Time + lastActive time.Time + + mu sync.Mutex + conns map[*websocket.Conn]*consoleConn + readerOnce sync.Once +} + +var ( + sessionMu sync.Mutex + sessions = make(map[string]*consoleSession) +) + +func sessionKey(cfg *state.Config) string { + return fmt.Sprintf("%d:%s", cfg.VMID, cfg.ContainerType) +} + +func getConsoleSession(cfg *state.Config) (*consoleSession, bool, error) { + key := sessionKey(cfg) + + sessionMu.Lock() + if sess, ok := sessions[key]; ok { + sessionMu.Unlock() + sess.touch() + log.Printf("[ws] session reuse: vmid=%d type=%s", cfg.VMID, cfg.ContainerType) + return sess, true, nil + } + sessionMu.Unlock() + + ptyFile, err := system.GetConsolePTY(cfg) + if err != nil { + return nil, false, err + } + + sess := &consoleSession{ + key: key, + cfg: cfg, + ptyFile: ptyFile, + createdAt: time.Now(), + lastActive: time.Now(), + conns: make(map[*websocket.Conn]*consoleConn), + } + + sess.startReader() + + sessionMu.Lock() + sessions[key] = sess + sessionMu.Unlock() + + log.Printf("[ws] session created: vmid=%d type=%s", cfg.VMID, cfg.ContainerType) + return sess, false, nil +} + +func (s *consoleSession) touch() { + s.mu.Lock() + s.lastActive = time.Now() + s.mu.Unlock() +} + +func (s *consoleSession) addConn(conn *websocket.Conn, cc *consoleConn) *consoleConn { + s.mu.Lock() + defer s.mu.Unlock() + + if cc == nil { + cc = &consoleConn{send: make(chan []byte, 128)} + } + s.conns[conn] = cc + s.lastActive = time.Now() + log.Printf("[ws] conn add: vmid=%d type=%s conns=%d", s.cfg.VMID, s.cfg.ContainerType, len(s.conns)) + return cc +} + +func (s *consoleSession) removeConn(conn *websocket.Conn) int { + s.mu.Lock() + defer s.mu.Unlock() + + cc, ok := s.conns[conn] + if ok { + delete(s.conns, conn) + close(cc.send) + } + s.lastActive = time.Now() + log.Printf("[ws] conn remove: vmid=%d type=%s conns=%d", s.cfg.VMID, s.cfg.ContainerType, len(s.conns)) + return len(s.conns) +} + +func (s *consoleSession) startReader() { + s.readerOnce.Do(func() { + go func() { + buf := make([]byte, 4096) + for { + n, err := s.ptyFile.Read(buf) + if n > 0 { + out := make([]byte, n) + copy(out, buf[:n]) + log.Printf("[ws] pty read: vmid=%d bytes=%d", s.cfg.VMID, n) + s.broadcast(out) + } + if err != nil { + if err == io.EOF { + log.Printf("[ws] pty read loop exit: vmid=%d err=EOF", s.cfg.VMID) + } else { + log.Printf("[ws] pty read loop exit: vmid=%d err=%v", s.cfg.VMID, err) + } + return + } + if n == 0 && err == nil { + continue + } + } + }() + }) +} + +func (s *consoleSession) broadcast(data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + s.lastActive = time.Now() + for _, cc := range s.conns { + select { + case cc.send <- data: + default: + select { + case <-cc.send: + default: + } + select { + case cc.send <- data: + default: + } + } + } +} + +func (s *consoleSession) writeInput(data []byte) error { + s.touch() + return runtime.Write(s.ptyFile, data) +} + +func (s *consoleSession) scheduleCleanupIfIdle() { + s.mu.Lock() + last := s.lastActive + s.mu.Unlock() + + go func(ts time.Time) { + time.Sleep(sessionTTL) + s.mu.Lock() + conns := len(s.conns) + lastActive := s.lastActive + s.mu.Unlock() + + if conns == 0 && lastActive.Equal(ts) { + log.Printf("[ws] session cleanup: vmid=%d type=%s", s.cfg.VMID, s.cfg.ContainerType) + if s.cfg.ContainerType == "dev" { + _ = system.StopDevShell() + } + sessionMu.Lock() + delete(sessions, s.key) + sessionMu.Unlock() + } + }(last) +} diff --git a/internal/http/websocket.go b/internal/http/websocket.go index 7442f62..9bbcec5 100644 --- a/internal/http/websocket.go +++ b/internal/http/websocket.go @@ -1,19 +1,13 @@ package agenthttp import ( - "bufio" - "crypto/sha1" - "encoding/base64" "fmt" "log" - "net" "net/http" - "os" - "path/filepath" - "strings" "time" - "zlh-agent/internal/provision" + "github.com/gorilla/websocket" + "zlh-agent/internal/state" ) @@ -25,86 +19,14 @@ import ( GET /console/stream */ -const maxInitialTail = 4096 // 4 KB +const maxPayloadSize = 64 * 1024 -/* -------------------------------------------------------------------------- - Minimal WebSocket Upgrader (stdlib only) -----------------------------------------------------------------------------*/ - -type WebSocketConn struct { - Conn net.Conn - Rw *bufio.ReadWriter -} - -func upgradeToWebSocket(w http.ResponseWriter, r *http.Request) (*WebSocketConn, error) { - if !strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") || - strings.ToLower(r.Header.Get("Upgrade")) != "websocket" { - return nil, fmt.Errorf("invalid websocket upgrade request") - } - - key := r.Header.Get("Sec-WebSocket-Key") - if key == "" { - return nil, fmt.Errorf("missing Sec-WebSocket-Key") - } - - accept := computeAcceptKey(key) - - h := w.Header() - h.Set("Upgrade", "websocket") - h.Set("Connection", "Upgrade") - h.Set("Sec-WebSocket-Accept", accept) - h.Set("Sec-WebSocket-Version", "13") - - w.WriteHeader(http.StatusSwitchingProtocols) - - hj, ok := w.(http.Hijacker) - if !ok { - return nil, fmt.Errorf("websocket: hijacking not supported") - } - - conn, rw, err := hj.Hijack() - if err != nil { - return nil, fmt.Errorf("websocket hijack: %w", err) - } - - return &WebSocketConn{ - Conn: conn, - Rw: rw, - }, nil -} - -func (c *WebSocketConn) WriteText(msg string) error { - payload := []byte(msg) - - // FIN + opcode(1 = text) - header := []byte{0x81} - - // Length encoding - if len(payload) < 126 { - header = append(header, byte(len(payload))) - } else { - header = append(header, 126, byte(len(payload)>>8), byte(len(payload))) - } - - if _, err := c.Conn.Write(header); err != nil { - return err - } - _, err := c.Conn.Write(payload) - return err -} - -func (c *WebSocketConn) Close() error { - return c.Conn.Close() -} - -/* -------------------------------------------------------------------------- - SHA-1 + Base64 for Sec-WebSocket-Accept -----------------------------------------------------------------------------*/ - -func computeAcceptKey(key string) string { - const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - h := sha1.Sum([]byte(key + magic)) - return base64.StdEncoding.EncodeToString(h[:]) +var wsUpgrader = websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, } /* -------------------------------------------------------------------------- @@ -117,45 +39,131 @@ func handleConsoleStream(w http.ResponseWriter, r *http.Request) { http.Error(w, "no config loaded", http.StatusBadRequest) return } + log.Printf("[ws] console connect: vmid=%d type=%s runtime=%s game=%s", cfg.VMID, cfg.ContainerType, cfg.Runtime, cfg.Game) - ws, err := upgradeToWebSocket(w, r) + conn, err := wsUpgrader.Upgrade(w, r, nil) if err != nil { log.Println("[ws] upgrade failed:", err) http.Error(w, "websocket upgrade failed", http.StatusBadRequest) return } - defer ws.Close() + defer conn.Close() - dir := provision.ServerDir(*cfg) - logfile := filepath.Join(dir, "logs", "latest.log") + conn.SetReadLimit(maxPayloadSize) + conn.SetCloseHandler(func(code int, text string) error { + log.Printf("[ws] close frame: vmid=%d code=%d text=%q", cfg.VMID, code, text) + return nil + }) - f, err := os.Open(logfile) - if err != nil { - _ = ws.WriteText(fmt.Sprintf("[error] cannot open log: %v", err)) - return - } - defer f.Close() - - // 1) Send last 4 KB (initial tail) - stat, _ := f.Stat() - sz := stat.Size() - if sz > maxInitialTail { - _, _ = f.Seek(sz-maxInitialTail, 0) - } - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - _ = ws.WriteText(scanner.Text()) - } - - // 2) Follow the file — stream new log lines live - for { - time.Sleep(500 * time.Millisecond) - for scanner.Scan() { - line := scanner.Text() - _ = ws.WriteText(line) + sendCh := make(chan []byte, 64) + writeErrCh := make(chan error, 1) + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + for { + select { + case msg, ok := <-sendCh: + if !ok { + writeErrCh <- nil + return + } + if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil { + writeErrCh <- err + return + } + case <-ticker.C: + if err := conn.WriteMessage(websocket.TextMessage, []byte{}); err != nil { + writeErrCh <- err + return + } + } + } + }() + + inputCh := make(chan []byte, 32) + readErrCh := make(chan error, 1) + go func() { + for { + msgType, msg, err := conn.ReadMessage() + if err != nil { + log.Printf("[ws] read error: vmid=%d err=%v", cfg.VMID, err) + readErrCh <- err + close(inputCh) + return + } + if msgType == websocket.TextMessage || msgType == websocket.BinaryMessage { + log.Printf("[ws] input: vmid=%d bytes=%d type=%d", cfg.VMID, len(msg), msgType) + inputCh <- msg + } + } + }() + + waitingNotified := false + var sess *consoleSession + sessionBound := false + for { + sess, _, err = getConsoleSession(cfg) + if err == nil { + log.Printf("[ws] console attached: vmid=%d type=%s", cfg.VMID, cfg.ContainerType) + break + } + if cfg.ContainerType != "dev" { + if !waitingNotified { + sendCh <- []byte("[info] waiting for server console...") + log.Printf("[ws] waiting for server console: vmid=%d type=%s err=%v", cfg.VMID, cfg.ContainerType, err) + waitingNotified = true + } + } else { + log.Printf("[ws] dev console unavailable: vmid=%d err=%v", cfg.VMID, err) + sendCh <- []byte(fmt.Sprintf("[error] %v", err)) + if !sessionBound { + close(sendCh) + } + return + } + + select { + case <-time.After(500 * time.Millisecond): + case <-readErrCh: + if !sessionBound { + close(sendCh) + } + return + case <-writeErrCh: + if !sessionBound { + close(sendCh) + } + return + } + } + + sess.addConn(conn, &consoleConn{send: sendCh}) + sessionBound = true + defer func() { + if sess != nil { + if sess.removeConn(conn) == 0 { + sess.scheduleCleanupIfIdle() + } + } + }() + + for { + select { + case msg, ok := <-inputCh: + if !ok { + return + } + if err := sess.writeInput(msg); err != nil { + sendCh <- []byte(fmt.Sprintf("[error] %v", err)) + } + case <-readErrCh: + return + case err := <-writeErrCh: + if err != nil { + log.Printf("[ws] write error: vmid=%d err=%v", cfg.VMID, err) + } + return } - // on EOF, loop continues and scanner will pick up new lines } } diff --git a/internal/minecraft/status.go b/internal/minecraft/status.go new file mode 100644 index 0000000..5b13b73 --- /dev/null +++ b/internal/minecraft/status.go @@ -0,0 +1,205 @@ +package minecraft + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net" + "strings" + "time" +) + +type StatusResponse struct { + Version struct { + Name string `json:"name"` + Protocol int `json:"protocol"` + } `json:"version"` + Players struct { + Max int `json:"max"` + Online int `json:"online"` + Sample []struct { + Name string `json:"name"` + ID string `json:"id"` + } `json:"sample"` + } `json:"players"` +} + +func ProtocolForVersion(version string) int { + v := strings.TrimSpace(strings.TrimPrefix(version, "v")) + switch v { + case "1.21.1": + return 767 + case "1.21": + return 767 + case "1.20.4": + return 765 + case "1.20.1": + return 763 + case "1.19.4": + return 762 + default: + return 754 + } +} + +func QueryStatus(host string, port int, protocol int) (StatusResponse, error) { + addr := fmt.Sprintf("%s:%d", host, port) + conn, err := net.DialTimeout("tcp", addr, 3*time.Second) + if err != nil { + return StatusResponse{}, err + } + defer conn.Close() + + _ = conn.SetDeadline(time.Now().Add(5 * time.Second)) + + if err := writeHandshake(conn, host, port, protocol); err != nil { + return StatusResponse{}, err + } + if err := writeStatusRequest(conn); err != nil { + return StatusResponse{}, err + } + + payload, err := readPacket(conn) + if err != nil { + return StatusResponse{}, err + } + + r := bytes.NewReader(payload) + packetID, err := readVarInt(r) + if err != nil { + return StatusResponse{}, err + } + if packetID != 0x00 { + return StatusResponse{}, fmt.Errorf("unexpected packet id: %d", packetID) + } + + respStr, err := readString(r) + if err != nil { + return StatusResponse{}, err + } + + var status StatusResponse + if err := json.Unmarshal([]byte(respStr), &status); err != nil { + return StatusResponse{}, err + } + return status, nil +} + +func writeHandshake(w io.Writer, host string, port int, protocol int) error { + var payload bytes.Buffer + if err := writeVarInt(&payload, 0x00); err != nil { + return err + } + if err := writeVarInt(&payload, protocol); err != nil { + return err + } + if err := writeString(&payload, host); err != nil { + return err + } + if err := binary.Write(&payload, binary.BigEndian, uint16(port)); err != nil { + return err + } + if err := writeVarInt(&payload, 0x01); err != nil { + return err + } + + return writePacket(w, payload.Bytes()) +} + +func writeStatusRequest(w io.Writer) error { + var payload bytes.Buffer + if err := writeVarInt(&payload, 0x00); err != nil { + return err + } + return writePacket(w, payload.Bytes()) +} + +func writePacket(w io.Writer, payload []byte) error { + var buf bytes.Buffer + if err := writeVarInt(&buf, len(payload)); err != nil { + return err + } + if _, err := buf.Write(payload); err != nil { + return err + } + _, err := w.Write(buf.Bytes()) + return err +} + +func readPacket(r io.Reader) ([]byte, error) { + length, err := readVarInt(r) + if err != nil { + return nil, err + } + if length <= 0 || length > 1<<20 { + return nil, fmt.Errorf("invalid packet length: %d", length) + } + buf := make([]byte, length) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + return buf, nil +} + +func writeVarInt(w io.Writer, value int) error { + for { + temp := byte(value & 0x7F) + value >>= 7 + if value != 0 { + temp |= 0x80 + } + if _, err := w.Write([]byte{temp}); err != nil { + return err + } + if value == 0 { + break + } + } + return nil +} + +func readVarInt(r io.Reader) (int, error) { + numRead := 0 + result := 0 + for { + if numRead > 5 { + return 0, fmt.Errorf("varint too long") + } + b := make([]byte, 1) + if _, err := r.Read(b); err != nil { + return 0, err + } + value := int(b[0] & 0x7F) + result |= value << (7 * numRead) + numRead++ + if b[0]&0x80 == 0 { + break + } + } + return result, nil +} + +func writeString(w io.Writer, s string) error { + if err := writeVarInt(w, len(s)); err != nil { + return err + } + _, err := w.Write([]byte(s)) + return err +} + +func readString(r io.Reader) (string, error) { + length, err := readVarInt(r) + if err != nil { + return "", err + } + if length < 0 || length > 1<<20 { + return "", fmt.Errorf("invalid string length: %d", length) + } + buf := make([]byte, length) + if _, err := io.ReadFull(r, buf); err != nil { + return "", err + } + return string(buf), nil +} diff --git a/internal/runtime/pty.go b/internal/runtime/pty.go new file mode 100644 index 0000000..13f23e1 --- /dev/null +++ b/internal/runtime/pty.go @@ -0,0 +1,45 @@ +package runtime + +import ( + "os" + "os/exec" + + "github.com/creack/pty" +) + +// CreatePTY starts cmd attached to a PTY and returns the PTY file. +func CreatePTY(cmd *exec.Cmd) (*os.File, error) { + ptmx, err := pty.Start(cmd) + if err != nil { + return nil, err + } + return ptmx, nil +} + +// ReadLoop reads from a PTY in non-blocking mode and streams chunks to onData. +func ReadLoop(ptyFile *os.File, stop <-chan struct{}, onData func([]byte) error) error { + buf := make([]byte, 4096) + for { + n, err := ptyFile.Read(buf) + if n > 0 { + if writeErr := onData(buf[:n]); writeErr != nil { + return writeErr + } + } + if err != nil { + return err + } + + select { + case <-stop: + return nil + default: + } + } +} + +// Write sends data to the PTY. +func Write(ptyFile *os.File, data []byte) error { + _, err := ptyFile.Write(data) + return err +} diff --git a/internal/state/state.go b/internal/state/state.go index 29fcc1b..2fb6be5 100755 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -4,6 +4,7 @@ import ( "encoding/json" "log" "os" + "strconv" "sync" "time" ) @@ -42,8 +43,6 @@ type Config struct { AdminPass string `json:"admin_pass,omitempty"` } - - /* -------------------------------------------------------------------------- AGENT STATE ENUM ----------------------------------------------------------------------------*/ @@ -181,5 +180,24 @@ func LoadConfig() (*Config, error) { if err := json.Unmarshal(b, &cfg); err != nil { return nil, err } + + if cfg.ContainerType == "" { + vmidStr := strconv.Itoa(cfg.VMID) + if len(vmidStr) > 0 { + switch vmidStr[0] { + case '6': + cfg.ContainerType = "dev" + case '5': + cfg.ContainerType = "game" + } + } + if cfg.ContainerType == "" && cfg.Runtime != "" { + cfg.ContainerType = "dev" + } + if cfg.ContainerType == "" { + cfg.ContainerType = "game" + } + } + return &cfg, nil } diff --git a/internal/system/process.go b/internal/system/process.go index c0b229e..2c8ab2d 100755 --- a/internal/system/process.go +++ b/internal/system/process.go @@ -1,28 +1,30 @@ package system import ( - "fmt" - "io" - "os" - "os/exec" - "path/filepath" - "strings" // <-- ADD THIS + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" "sync" - "time" + "time" - "zlh-agent/internal/provision" - "zlh-agent/internal/state" + "zlh-agent/internal/provision" + "zlh-agent/internal/runtime" + "zlh-agent/internal/state" ) - /* -------------------------------------------------------------------------- GLOBAL PROCESS STATE ----------------------------------------------------------------------------*/ var ( - mu sync.Mutex - serverCmd *exec.Cmd - serverStdin io.WriteCloser + mu sync.Mutex + serverCmd *exec.Cmd + serverPTY *os.File + + devCmd *exec.Cmd + devPTY *os.File ) /* -------------------------------------------------------------------------- @@ -30,129 +32,54 @@ var ( ----------------------------------------------------------------------------*/ func StartServer(cfg *state.Config) error { - mu.Lock() - defer mu.Unlock() + mu.Lock() + defer mu.Unlock() - // Already running? - if serverCmd != nil { - return fmt.Errorf("server already running") - } + // Already running? + if serverCmd != nil { + return fmt.Errorf("server already running") + } - dir := provision.ServerDir(*cfg) - startScript := filepath.Join(dir, "start.sh") + dir := provision.ServerDir(*cfg) + startScript := filepath.Join(dir, "start.sh") - cmd := exec.Command("/bin/bash", startScript) - cmd.Dir = dir + cmd := exec.Command("/bin/bash", startScript) + cmd.Dir = dir - stdout, _ := cmd.StdoutPipe() - stderr, _ := cmd.StderrPipe() - stdin, _ := cmd.StdinPipe() + ptmx, err := runtime.CreatePTY(cmd) + if err != nil { + return fmt.Errorf("start server: %w", err) + } - serverStdin = stdin - serverCmd = cmd + serverCmd = cmd + serverPTY = ptmx - // Mark STARTING (not running) - state.SetState(state.StateStarting) + state.SetState(state.StateRunning) + state.SetError(nil) - if err := cmd.Start(); err != nil { - serverCmd = nil - return fmt.Errorf("start server: %w", err) - } + go func() { + err := cmd.Wait() - /* ------------------------- - Log pumps - --------------------------*/ - go pumpOutput(stdout, os.Stdout) - go pumpOutput(stderr, os.Stderr) + mu.Lock() + defer mu.Unlock() - /* ------------------------- - Detect "Done" → running - --------------------------*/ - go detectMinecraftReady(cfg) + if serverPTY != nil { + _ = serverPTY.Close() + } - /* ------------------------- - Crash watcher - --------------------------*/ - go func() { - err := cmd.Wait() + if err != nil { + state.RecordCrash(err) + } else { + state.SetState(state.StateIdle) + } - mu.Lock() - defer mu.Unlock() + serverCmd = nil + serverPTY = nil + }() - if err != nil { - state.RecordCrash(err) - serverCmd = nil - serverStdin = nil - return - } - - // Normal stop - state.SetState(state.StateIdle) - serverCmd = nil - serverStdin = nil - }() - - return nil + return nil } -/* helper to pump logs */ -func pumpOutput(r io.Reader, w *os.File) { - buf := make([]byte, 4096) - for { - n, err := r.Read(buf) - if n > 0 { - w.Write(buf[:n]) - } - if err != nil { - return - } - } -} - -/* Detects Minecraft "Done" and updates state */ -func detectMinecraftReady(cfg *state.Config) { - dir := provision.ServerDir(*cfg) - logPath := filepath.Join(dir, "logs", "latest.log") - - deadline := time.Now().Add(5 * time.Minute) // FORGE NEEDS MORE TIME - - lastSize := int64(0) - - for time.Now().Before(deadline) { - - // Wait for log file to appear - st, err := os.Stat(logPath) - if err == nil { - // ensure file is growing - if st.Size() != lastSize { - lastSize = st.Size() - - b, _ := os.ReadFile(logPath) - s := string(b) - - // UNIVERSAL READY MATCHES - if strings.Contains(s, "Done (") || - strings.Contains(s, "For help, type \"help\"") || - strings.Contains(s, "Successfully loaded forge") || - strings.Contains(s, "Preparing spawn area: 100%") { - - state.SetState(state.StateRunning) - state.SetError(nil) - return - } - } - } - - time.Sleep(2 * time.Second) - } - - state.SetState(state.StateError) - state.SetError(fmt.Errorf("server failed to reach running state before timeout")) -} - - - - /* -------------------------------------------------------------------------- StopServer ----------------------------------------------------------------------------*/ @@ -168,10 +95,10 @@ func StopServer() error { state.SetState(state.StateStopping) // Try graceful stop - if serverStdin != nil { - _, _ = serverStdin.Write([]byte("save-all\n")) + if serverPTY != nil { + _ = runtime.Write(serverPTY, []byte("save-all\n")) time.Sleep(2 * time.Second) - _, _ = serverStdin.Write([]byte("stop\n")) + _ = runtime.Write(serverPTY, []byte("stop\n")) } // Wait a moment @@ -182,9 +109,6 @@ func StopServer() error { _ = serverCmd.Process.Kill() } - state.SetState(state.StateIdle) - serverCmd = nil - serverStdin = nil return nil } @@ -208,10 +132,137 @@ func SendConsoleCommand(cmd string) error { mu.Lock() defer mu.Unlock() - if serverStdin == nil { + if serverPTY == nil { return fmt.Errorf("server console not available") } - _, err := serverStdin.Write([]byte(cmd + "\n")) - return err + return runtime.Write(serverPTY, []byte(cmd+"\n")) +} + +/* -------------------------------------------------------------------------- + Dev Shell PTY +----------------------------------------------------------------------------*/ + +func StartDevShell() (*os.File, error) { + mu.Lock() + defer mu.Unlock() + + if devPTY != nil && devCmd != nil { + return devPTY, nil + } + + shell := "/bin/bash" + if _, err := os.Stat(shell); err != nil { + shell = "/bin/sh" + } + + var cmd *exec.Cmd + if shell == "/bin/bash" { + cmd = exec.Command(shell, "-l", "-i") + } else { + cmd = exec.Command(shell, "-i") + } + cmd.Dir = "/opt" + + ptmx, err := runtime.CreatePTY(cmd) + if err != nil { + return nil, fmt.Errorf("start dev shell: %w", err) + } + + devCmd = cmd + devPTY = ptmx + + state.SetState(state.StateRunning) + state.SetError(nil) + + go func() { + err := cmd.Wait() + + mu.Lock() + defer mu.Unlock() + + if devPTY != nil { + _ = devPTY.Close() + } + + if err != nil { + state.RecordCrash(err) + } else { + state.SetState(state.StateIdle) + } + + devCmd = nil + devPTY = nil + }() + + return devPTY, nil +} + +func GetConsolePTY(cfg *state.Config) (*os.File, error) { + if cfg.ContainerType == "dev" { + return StartDevShell() + } + + mu.Lock() + defer mu.Unlock() + + if serverPTY == nil { + return nil, fmt.Errorf("server console not available") + } + return serverPTY, nil +} + +func WriteConsoleInput(cfg *state.Config, input string) error { + if strings.HasSuffix(input, "\n") { + input = strings.TrimSuffix(input, "\n") + } + payload := []byte(input + "\n") + + if cfg.ContainerType == "dev" { + mu.Lock() + defer mu.Unlock() + if devPTY == nil { + return fmt.Errorf("dev shell not available") + } + return runtime.Write(devPTY, payload) + } + + mu.Lock() + defer mu.Unlock() + if serverPTY == nil { + return fmt.Errorf("server console not available") + } + return runtime.Write(serverPTY, payload) +} + +/* -------------------------------------------------------------------------- + Stop Dev Shell +----------------------------------------------------------------------------*/ + +func StopDevShell() error { + mu.Lock() + defer mu.Unlock() + + if devCmd == nil { + return nil + } + + if devPTY != nil { + _ = runtime.Write(devPTY, []byte("exit\n")) + } + + time.Sleep(1 * time.Second) + + if devCmd.Process != nil { + _ = devCmd.Process.Kill() + } + + if devPTY != nil { + _ = devPTY.Close() + devPTY = nil + } + + devCmd = nil + state.SetState(state.StateIdle) + return nil } diff --git a/internal/update/update.go b/internal/update/update.go new file mode 100644 index 0000000..bf8f0d4 --- /dev/null +++ b/internal/update/update.go @@ -0,0 +1,530 @@ +package update + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" +) + +const ( + artifactBaseURL = "http://10.60.0.251:8080/agents/" + releasesDir = "/opt/zlh-agent/releases" + currentLink = "/opt/zlh-agent/current" + previousLink = "/opt/zlh-agent/previous" + binaryPath = "/opt/zlh-agent/zlh-agent" + stateDir = "/opt/zlh-agent/state" + statusFile = "/opt/zlh-agent/state/update.json" + defaultUnit = "zlh-agent" + defaultMode = "notify" +) + +type Manifest struct { + SchemaVersion int `json:"schema_version"` + Latest string `json:"latest"` + MinSupported string `json:"min_supported"` + Channels map[string]string `json:"channels"` + Artifacts map[string]struct { + LinuxAMD64 struct { + Binary string `json:"binary"` + SHA256 string `json:"sha256"` + } `json:"linux_amd64"` + } `json:"artifacts"` + ReleasedAt string `json:"released_at"` +} + +type Result struct { + Status string `json:"status"` + Current string `json:"current,omitempty"` + Target string `json:"target,omitempty"` + Error string `json:"error,omitempty"` + CheckedAtUTC string `json:"checked_at_utc,omitempty"` +} + +func CheckAvailable(currentVersion string) Result { + currentVersion = normalizeVersion(currentVersion) + result := Result{ + Status: "error", + Current: currentVersion, + CheckedAtUTC: time.Now().UTC().Format(time.RFC3339), + } + + manifest, err := fetchManifest() + if err != nil { + result.Error = err.Error() + writeStatus(result) + return result + } + if manifest.SchemaVersion != 1 { + result.Error = fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion) + writeStatus(result) + return result + } + + target := normalizeVersion(manifest.Channels["stable"]) + if target == "" { + result.Error = "missing stable channel version" + writeStatus(result) + return result + } + result.Target = target + + if compareVersions(currentVersion, target) >= 0 { + result.Status = "noop" + result.Error = "" + writeStatus(result) + return result + } + + result.Status = "available" + result.Error = "" + writeStatus(result) + return result +} + +func StartPeriodic(currentVersion string) { + mode := strings.ToLower(strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_MODE"))) + if mode == "" { + mode = defaultMode + } + + switch mode { + case "off", "disabled": + log.Printf("[update] periodic checks disabled (mode=%s)", mode) + return + case "notify", "auto": + default: + log.Printf("[update] invalid mode %q, using %q", mode, defaultMode) + mode = defaultMode + } + + interval := 30 * time.Minute + if v := strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_INTERVAL")); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + log.Printf("[update] invalid ZLH_AGENT_UPDATE_INTERVAL=%q: %v (using %s)", v, err, interval) + } else if d < time.Minute { + log.Printf("[update] ZLH_AGENT_UPDATE_INTERVAL too small (%s), using %s", d, interval) + } else { + interval = d + } + } + + log.Printf("[update] periodic checks enabled (mode=%s interval=%s)", mode, interval) + + run := func() { + switch mode { + case "auto": + res := CheckAndUpdate(currentVersion) + switch res.Status { + case "updated": + log.Printf("[update] applied update current=%s target=%s", res.Current, res.Target) + case "noop": + log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target) + default: + log.Printf("[update] auto check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error) + } + case "notify": + res := CheckAvailable(currentVersion) + switch res.Status { + case "available": + log.Printf("[update] update available current=%s target=%s", res.Current, res.Target) + case "noop": + log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target) + default: + log.Printf("[update] notify check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error) + } + } + } + + go func() { + time.Sleep(10 * time.Second) + run() + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for range ticker.C { + run() + } + }() +} + +func CheckAndUpdate(currentVersion string) Result { + currentVersion = normalizeVersion(currentVersion) + result := Result{ + Status: "error", + Current: currentVersion, + CheckedAtUTC: time.Now().UTC().Format(time.RFC3339), + } + + if runtime.GOOS != "linux" || runtime.GOARCH != "amd64" { + result.Error = fmt.Sprintf("unsupported platform: %s_%s", runtime.GOOS, runtime.GOARCH) + writeStatus(result) + return result + } + + lockPath := filepath.Join(stateDir, "update.lock") + if err := acquireLock(lockPath); err != nil { + result.Error = err.Error() + writeStatus(result) + return result + } + defer releaseLock(lockPath) + + manifest, err := fetchManifest() + if err != nil { + result.Error = err.Error() + writeStatus(result) + return result + } + if manifest.SchemaVersion != 1 { + result.Error = fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion) + writeStatus(result) + return result + } + + target := normalizeVersion(manifest.Channels["stable"]) + if target == "" { + result.Error = "missing stable channel version" + writeStatus(result) + return result + } + result.Target = target + + if compareVersions(currentVersion, target) == 0 { + result.Status = "noop" + result.Error = "" + writeStatus(result) + return result + } + + if compareVersions(target, normalizeVersion(manifest.MinSupported)) < 0 { + result.Error = "target below min_supported" + writeStatus(result) + return result + } + + artifact, ok := manifest.Artifacts[target] + if !ok { + result.Error = "missing artifacts for target version" + writeStatus(result) + return result + } + + binPath := artifact.LinuxAMD64.Binary + shaPath := artifact.LinuxAMD64.SHA256 + if binPath == "" || shaPath == "" { + result.Error = "artifact paths missing" + writeStatus(result) + return result + } + + if err := ensureCurrentSymlinks(currentVersion); err != nil { + result.Error = fmt.Sprintf("prepare current symlink: %v", err) + writeStatus(result) + return result + } + + if err := os.MkdirAll(filepath.Join(releasesDir, target), 0o755); err != nil { + result.Error = err.Error() + writeStatus(result) + return result + } + + tmpPath := filepath.Join(releasesDir, target, "zlh-agent.new") + finalPath := filepath.Join(releasesDir, target, "zlh-agent") + + binURL := artifactBaseURL + binPath + shaURL := artifactBaseURL + shaPath + + if err := downloadFile(binURL, tmpPath); err != nil { + result.Error = fmt.Sprintf("download binary: %v", err) + writeStatus(result) + return result + } + + expected, err := downloadSHA256(shaURL) + if err != nil { + result.Error = fmt.Sprintf("download sha256: %v", err) + writeStatus(result) + return result + } + + if err := verifySHA256(tmpPath, expected); err != nil { + result.Error = fmt.Sprintf("sha256 verify failed: %v", err) + writeStatus(result) + return result + } + + if err := os.Chmod(tmpPath, 0o755); err != nil { + result.Error = fmt.Sprintf("chmod: %v", err) + writeStatus(result) + return result + } + + if err := os.Rename(tmpPath, finalPath); err != nil { + result.Error = fmt.Sprintf("install: %v", err) + writeStatus(result) + return result + } + + if err := updateSymlinks(target); err != nil { + result.Error = fmt.Sprintf("update symlinks: %v", err) + writeStatus(result) + return result + } + + if err := os.Remove(binaryPath); err != nil && !errors.Is(err, os.ErrNotExist) { + result.Error = fmt.Sprintf("remove binary path: %v", err) + writeStatus(result) + return result + } + if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), binaryPath); err != nil { + result.Error = fmt.Sprintf("update current symlink: %v", err) + writeStatus(result) + return result + } + + result.Status = "updated" + result.Error = "" + writeStatus(result) + go func() { + if err := restartService(); err != nil { + log.Printf("[update] restart failed: %v", err) + } + }() + return result +} + +func fetchManifest() (Manifest, error) { + url := artifactBaseURL + "manifest.json" + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(url) + if err != nil { + return Manifest{}, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return Manifest{}, fmt.Errorf("manifest status %d", resp.StatusCode) + } + var m Manifest + if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { + return Manifest{}, err + } + return m, nil +} + +func downloadFile(url, dest string) error { + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status %d", resp.StatusCode) + } + f, err := os.Create(dest) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, resp.Body) + return err +} + +func downloadSHA256(url string) (string, error) { + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("status %d", resp.StatusCode) + } + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + fields := strings.Fields(string(b)) + if len(fields) == 0 { + return "", errors.New("empty sha256 file") + } + return strings.TrimSpace(fields[0]), nil +} + +func verifySHA256(path, expected string) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + sum := hex.EncodeToString(h.Sum(nil)) + if !strings.EqualFold(sum, expected) { + return fmt.Errorf("checksum mismatch") + } + return nil +} + +func restartService() error { + unit := os.Getenv("ZLH_AGENT_UNIT") + if unit == "" { + unit = defaultUnit + } + cmd := exec.Command("systemctl", "restart", unit) + return cmd.Run() +} + +func normalizeVersion(v string) string { + return strings.TrimPrefix(strings.TrimSpace(v), "v") +} + +func compareVersions(a, b string) int { + as := strings.Split(a, ".") + bs := strings.Split(b, ".") + for len(as) < 3 { + as = append(as, "0") + } + for len(bs) < 3 { + bs = append(bs, "0") + } + for i := 0; i < 3; i++ { + ai := parseInt(as[i]) + bi := parseInt(bs[i]) + if ai < bi { + return -1 + } + if ai > bi { + return 1 + } + } + return 0 +} + +func parseInt(s string) int { + n := 0 + for _, r := range s { + if r < '0' || r > '9' { + break + } + n = n*10 + int(r-'0') + } + return n +} + +func writeStatus(res Result) { + _ = os.MkdirAll(stateDir, 0o755) + b, _ := json.MarshalIndent(res, "", " ") + _ = os.WriteFile(statusFile, b, 0o644) +} + +func ReadStatus() Result { + b, err := os.ReadFile(statusFile) + if err != nil { + return Result{} + } + var res Result + _ = json.Unmarshal(b, &res) + return res +} + +func acquireLock(path string) error { + _ = os.MkdirAll(stateDir, 0o755) + f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644) + if err != nil { + return fmt.Errorf("update already in progress") + } + _ = f.Close() + return nil +} + +func releaseLock(path string) { + _ = os.Remove(path) +} + +func ensureCurrentSymlinks(currentVersion string) error { + if _, err := os.Lstat(currentLink); err == nil { + return nil + } + + if err := os.MkdirAll(releasesDir, 0o755); err != nil { + return err + } + + currentRelease := filepath.Join(releasesDir, currentVersion) + if err := os.MkdirAll(currentRelease, 0o755); err != nil { + return err + } + + currentBinary := filepath.Join(currentRelease, "zlh-agent") + if _, err := os.Stat(currentBinary); errors.Is(err, os.ErrNotExist) { + if err := copyFile(binaryPath, currentBinary); err != nil { + return err + } + if err := os.Chmod(currentBinary, 0o755); err != nil { + return err + } + } + + if err := os.Symlink(filepath.Join("releases", currentVersion), currentLink); err != nil { + return err + } + + if err := os.Remove(binaryPath); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), binaryPath); err != nil { + return err + } + + return nil +} + +func updateSymlinks(target string) error { + if err := os.RemoveAll(previousLink); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + if _, err := os.Lstat(currentLink); err == nil { + if err := os.Symlink("current", previousLink); err != nil { + return err + } + } + + if err := os.RemoveAll(currentLink); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } + return os.Symlink(filepath.Join("releases", target), currentLink) +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + if _, err := io.Copy(out, in); err != nil { + return err + } + return out.Close() +} diff --git a/internal/util/log.go b/internal/util/log.go index 582ed08..f690e94 100644 --- a/internal/util/log.go +++ b/internal/util/log.go @@ -2,6 +2,7 @@ package util import ( "fmt" + "io" "log" "os" "time" @@ -21,6 +22,7 @@ func InitLogFile(path string) error { } logFile = f logReady = true + log.SetOutput(io.MultiWriter(os.Stdout, logFile)) return nil } diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 0000000..5c8fd1a --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,3 @@ +package version + +const AgentVersion = "v1.0.0" diff --git a/main.go b/main.go index ccc920a..3fa4066 100755 --- a/main.go +++ b/main.go @@ -10,14 +10,14 @@ import ( "time" agenthttp "zlh-agent/internal/http" - "zlh-agent/internal/system" // <-- ADD THIS + "zlh-agent/internal/system" + "zlh-agent/internal/update" "zlh-agent/internal/util" + "zlh-agent/internal/version" ) -const AgentVersion = "v1.0.0" // Consolidated agent version tag - func main() { - log.Printf("[agent] starting ZeroLagHub Agent %s", AgentVersion) + log.Printf("[agent] starting ZeroLagHub Agent %s", version.AgentVersion) // ------------------------------------------------------------ // Optional: enable log file (safe if path doesn't exist yet) @@ -46,7 +46,8 @@ func main() { // Enable autostart subsystem // (does nothing unless AutoStartEnabled=true) // ------------------------------------------------------------ - system.InitAutoStart() // <-- ADD THIS + system.InitAutoStart() + update.StartPeriodic(version.AgentVersion) server := &http.Server{ Addr: addr, diff --git a/scripts/devcontainer/lib/common.sh b/scripts/devcontainer/lib/common.sh index ff8a132..ab351fe 100644 --- a/scripts/devcontainer/lib/common.sh +++ b/scripts/devcontainer/lib/common.sh @@ -1,20 +1,46 @@ #!/usr/bin/env bash set -euo pipefail +############################################ +# Required env (installer contract) +############################################ + : "${RUNTIME:?RUNTIME required}" : "${RUNTIME_VERSION:?RUNTIME_VERSION required}" : "${ARCHIVE_EXT:?ARCHIVE_EXT required}" +############################################ +# Optional env +############################################ + ZLH_ARTIFACT_BASE_URL="${ZLH_ARTIFACT_BASE_URL:-http://10.60.0.251:8080}" ZLH_RUNTIME_ROOT="${ZLH_RUNTIME_ROOT:-/opt/zlh/runtime}" ARCHIVE_PREFIX="${ARCHIVE_PREFIX:-${RUNTIME}}" +############################################ +# Derived paths +############################################ + RUNTIME_ROOT="${ZLH_RUNTIME_ROOT}/${RUNTIME}" DEST_DIR="${RUNTIME_ROOT}/${RUNTIME_VERSION}" CURRENT_LINK="${RUNTIME_ROOT}/current" -log() { echo "[zlh-installer:${RUNTIME}] $*"; } -fail() { echo "[zlh-installer:${RUNTIME}] ERROR: $*" >&2; exit 1; } +############################################ +# Logging helpers +############################################ + +log() { + echo "[zlh-installer:${RUNTIME}] $*" +} + +fail() { + echo "[zlh-installer:${RUNTIME}] ERROR: $*" >&2 + exit 1 +} + +############################################ +# Artifact helpers +############################################ artifact_name() { echo "${ARCHIVE_PREFIX}-${RUNTIME_VERSION}.${ARCHIVE_EXT}" @@ -24,6 +50,10 @@ artifact_url() { echo "${ZLH_ARTIFACT_BASE_URL%/}/devcontainer/${RUNTIME}/${RUNTIME_VERSION}/$(artifact_name)" } +############################################ +# Download / extract +############################################ + download_artifact() { local url out url="$(artifact_url)" @@ -56,6 +86,10 @@ extract_artifact() { esac } +############################################ +# Runtime wiring +############################################ + update_symlinks() { ln -sfn "${DEST_DIR}" "${CURRENT_LINK}" ln -sfn "${CURRENT_LINK}/bin" "${RUNTIME_ROOT}/bin" @@ -68,6 +102,35 @@ EOF chmod +x /etc/profile.d/zlh-${RUNTIME}.sh } +############################################ +# SSH host key initialization +############################################ + +ensure_ssh_host_keys() { + # SSH may not be installed in all templates — do not fail + if ! command -v ssh-keygen >/dev/null 2>&1; then + log "ssh-keygen not present, skipping SSH host key setup" + return 0 + fi + + if ls /etc/ssh/ssh_host_*_key >/dev/null 2>&1; then + log "SSH host keys already exist" + return 0 + fi + + log "Generating SSH host keys" + ssh-keygen -A + + # Best-effort service restart (container init systems vary) + systemctl enable ssh >/dev/null 2>&1 || true + systemctl restart ssh >/dev/null 2>&1 || \ + systemctl restart sshd >/dev/null 2>&1 || true +} + +############################################ +# Entry point +############################################ + install_runtime() { log "Installing ${RUNTIME} ${RUNTIME_VERSION}" @@ -84,5 +147,7 @@ install_runtime() { write_profile chmod -R 755 "${DEST_DIR}" + ensure_ssh_host_keys + log "Install complete" } diff --git a/zlh-agent b/zlh-agent index 07e39d3..9dea1f3 100755 Binary files a/zlh-agent and b/zlh-agent differ