package update import ( "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "log" "net/http" "os" "os/exec" "path/filepath" "runtime" "sort" "strconv" "strings" "time" ) const ( artifactBaseURL = "http://10.60.0.251:8080/agents/" releasesDir = "/opt/zlh-agent/releases" currentLink = "/opt/zlh-agent/current" previousLink = "/opt/zlh-agent/previous" binaryPath = "/opt/zlh-agent/zlh-agent" stateDir = "/opt/zlh-agent/state" statusFile = "/opt/zlh-agent/state/update.json" defaultUnit = "zlh-agent" defaultMode = "notify" defaultKeepReleases = 3 // current + 2 previous ) type Manifest struct { SchemaVersion int `json:"schema_version"` Latest string `json:"latest"` MinSupported string `json:"min_supported"` Channels map[string]string `json:"channels"` Artifacts map[string]struct { LinuxAMD64 struct { Binary string `json:"binary"` SHA256 string `json:"sha256"` } `json:"linux_amd64"` } `json:"artifacts"` ReleasedAt string `json:"released_at"` } type Result struct { Status string `json:"status"` Current string `json:"current,omitempty"` Target string `json:"target,omitempty"` Error string `json:"error,omitempty"` CheckedAtUTC string `json:"checked_at_utc,omitempty"` } func CheckAvailable(currentVersion string) Result { currentVersion = normalizeVersion(currentVersion) result := Result{ Status: "error", Current: currentVersion, CheckedAtUTC: time.Now().UTC().Format(time.RFC3339), } manifest, err := fetchManifest() if err != nil { result.Error = err.Error() writeStatus(result) return result } if manifest.SchemaVersion != 1 { result.Error = fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion) writeStatus(result) return result } target := normalizeVersion(manifest.Channels["stable"]) if target == "" { result.Error = "missing stable channel version" writeStatus(result) return result } result.Target = target if compareVersions(currentVersion, target) >= 0 { result.Status = "noop" result.Error = "" writeStatus(result) return result } result.Status = "available" result.Error = "" writeStatus(result) return result } func StartPeriodic(currentVersion string) { mode := strings.ToLower(strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_MODE"))) if mode == "" { mode = defaultMode } switch mode { case "off", "disabled": log.Printf("[update] periodic checks disabled (mode=%s)", mode) return case "notify", "auto": default: log.Printf("[update] invalid mode %q, using %q", mode, defaultMode) mode = defaultMode } interval := 30 * time.Minute if v := strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_INTERVAL")); v != "" { d, err := time.ParseDuration(v) if err != nil { log.Printf("[update] invalid ZLH_AGENT_UPDATE_INTERVAL=%q: %v (using %s)", v, err, interval) } else if d < time.Minute { log.Printf("[update] ZLH_AGENT_UPDATE_INTERVAL too small (%s), using %s", d, interval) } else { interval = d } } log.Printf("[update] periodic checks enabled (mode=%s interval=%s)", mode, interval) run := func() { switch mode { case "auto": res := CheckAndUpdate(currentVersion) switch res.Status { case "updated": log.Printf("[update] applied update current=%s target=%s", res.Current, res.Target) case "noop": log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target) default: log.Printf("[update] auto check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error) } case "notify": res := CheckAvailable(currentVersion) switch res.Status { case "available": log.Printf("[update] update available current=%s target=%s", res.Current, res.Target) case "noop": log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target) default: log.Printf("[update] notify check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error) } } } go func() { time.Sleep(10 * time.Second) run() ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { run() } }() } func CheckAndUpdate(currentVersion string) Result { currentVersion = normalizeVersion(currentVersion) result := Result{ Status: "error", Current: currentVersion, CheckedAtUTC: time.Now().UTC().Format(time.RFC3339), } if runtime.GOOS != "linux" || runtime.GOARCH != "amd64" { result.Error = fmt.Sprintf("unsupported platform: %s_%s", runtime.GOOS, runtime.GOARCH) writeStatus(result) return result } lockPath := filepath.Join(stateDir, "update.lock") if err := acquireLock(lockPath); err != nil { result.Error = err.Error() writeStatus(result) return result } defer releaseLock(lockPath) manifest, err := fetchManifest() if err != nil { result.Error = err.Error() writeStatus(result) return result } if manifest.SchemaVersion != 1 { result.Error = fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion) writeStatus(result) return result } target := normalizeVersion(manifest.Channels["stable"]) if target == "" { result.Error = "missing stable channel version" writeStatus(result) return result } result.Target = target if compareVersions(currentVersion, target) == 0 { result.Status = "noop" result.Error = "" writeStatus(result) return result } if compareVersions(target, normalizeVersion(manifest.MinSupported)) < 0 { result.Error = "target below min_supported" writeStatus(result) return result } artifact, ok := manifest.Artifacts[target] if !ok { result.Error = "missing artifacts for target version" writeStatus(result) return result } binPath := artifact.LinuxAMD64.Binary shaPath := artifact.LinuxAMD64.SHA256 if binPath == "" || shaPath == "" { result.Error = "artifact paths missing" writeStatus(result) return result } if err := ensureCurrentSymlinks(currentVersion); err != nil { result.Error = fmt.Sprintf("prepare current symlink: %v", err) writeStatus(result) return result } if err := os.MkdirAll(filepath.Join(releasesDir, target), 0o755); err != nil { result.Error = err.Error() writeStatus(result) return result } tmpPath := filepath.Join(releasesDir, target, "zlh-agent.new") finalPath := filepath.Join(releasesDir, target, "zlh-agent") binURL := artifactBaseURL + binPath shaURL := artifactBaseURL + shaPath if err := downloadFile(binURL, tmpPath); err != nil { result.Error = fmt.Sprintf("download binary: %v", err) writeStatus(result) return result } expected, err := downloadSHA256(shaURL) if err != nil { result.Error = fmt.Sprintf("download sha256: %v", err) writeStatus(result) return result } if err := verifySHA256(tmpPath, expected); err != nil { result.Error = fmt.Sprintf("sha256 verify failed: %v", err) writeStatus(result) return result } if err := os.Chmod(tmpPath, 0o755); err != nil { result.Error = fmt.Sprintf("chmod: %v", err) writeStatus(result) return result } if err := os.Rename(tmpPath, finalPath); err != nil { result.Error = fmt.Sprintf("install: %v", err) writeStatus(result) return result } if err := updateSymlinks(target); err != nil { result.Error = fmt.Sprintf("update symlinks: %v", err) writeStatus(result) return result } if err := os.Remove(binaryPath); err != nil && !errors.Is(err, os.ErrNotExist) { result.Error = fmt.Sprintf("remove binary path: %v", err) writeStatus(result) return result } if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), binaryPath); err != nil { result.Error = fmt.Sprintf("update current symlink: %v", err) writeStatus(result) return result } keep := defaultKeepReleases if v := strings.TrimSpace(os.Getenv("ZLH_AGENT_KEEP_RELEASES")); v != "" { if n, err := strconv.Atoi(v); err == nil && n >= 2 { keep = n } } if err := pruneOldReleases(keep); err != nil { log.Printf("[update] prune warning: %v", err) } result.Status = "updated" result.Error = "" writeStatus(result) if err := scheduleRollbackGuard(target); err != nil { log.Printf("[update] rollback guard warning: %v", err) } go func() { if err := restartService(); err != nil { log.Printf("[update] restart failed: %v", err) } }() return result } func fetchManifest() (Manifest, error) { url := artifactBaseURL + "manifest.json" client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Get(url) if err != nil { return Manifest{}, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return Manifest{}, fmt.Errorf("manifest status %d", resp.StatusCode) } var m Manifest if err := json.NewDecoder(resp.Body).Decode(&m); err != nil { return Manifest{}, err } return m, nil } func downloadFile(url, dest string) error { client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Get(url) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("status %d", resp.StatusCode) } f, err := os.Create(dest) if err != nil { return err } defer f.Close() _, err = io.Copy(f, resp.Body) return err } func downloadSHA256(url string) (string, error) { client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Get(url) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("status %d", resp.StatusCode) } b, err := io.ReadAll(resp.Body) if err != nil { return "", err } fields := strings.Fields(string(b)) if len(fields) == 0 { return "", errors.New("empty sha256 file") } return strings.TrimSpace(fields[0]), nil } func verifySHA256(path, expected string) error { f, err := os.Open(path) if err != nil { return err } defer f.Close() h := sha256.New() if _, err := io.Copy(h, f); err != nil { return err } sum := hex.EncodeToString(h.Sum(nil)) if !strings.EqualFold(sum, expected) { return fmt.Errorf("checksum mismatch") } return nil } func restartService() error { unit := os.Getenv("ZLH_AGENT_UNIT") if unit == "" { unit = defaultUnit } cmd := exec.Command("systemctl", "restart", "--no-block", unit) return cmd.Run() } func scheduleRollbackGuard(target string) error { unit := os.Getenv("ZLH_AGENT_UNIT") if unit == "" { unit = defaultUnit } port := strings.TrimSpace(os.Getenv("ZLH_AGENT_PORT")) if port == "" { port = "18888" } target = normalizeVersion(target) if target == "" { return nil } script := fmt.Sprintf( "sleep 25; "+ "if ! curl -fsS http://127.0.0.1:%s/health >/dev/null 2>&1 || "+ "! curl -fsS http://127.0.0.1:%s/version 2>/dev/null | grep -q '\"version\":\"v%s\"'; then "+ "prev=$(readlink -f %s || true); "+ "if [ -n \"$prev\" ] && [ -d \"$prev\" ]; then "+ "b=$(basename \"$prev\"); "+ "ln -sfn \"releases/$b\" %s; "+ "ln -sfn %s/zlh-agent %s; "+ "systemctl restart --no-block %s; "+ "fi; "+ "fi", port, port, target, previousLink, currentLink, currentLink, binaryPath, unit, ) transientUnit := fmt.Sprintf("zlh-agent-update-verify-%d", time.Now().UnixNano()) cmd := exec.Command("systemd-run", "--unit", transientUnit, "--collect", "/bin/sh", "-c", script) if err := cmd.Run(); err != nil { return fmt.Errorf("schedule rollback guard: %w", err) } return nil } func normalizeVersion(v string) string { return strings.TrimPrefix(strings.TrimSpace(v), "v") } func compareVersions(a, b string) int { as := strings.Split(a, ".") bs := strings.Split(b, ".") for len(as) < 3 { as = append(as, "0") } for len(bs) < 3 { bs = append(bs, "0") } for i := 0; i < 3; i++ { ai := parseInt(as[i]) bi := parseInt(bs[i]) if ai < bi { return -1 } if ai > bi { return 1 } } return 0 } func parseInt(s string) int { n := 0 for _, r := range s { if r < '0' || r > '9' { break } n = n*10 + int(r-'0') } return n } func writeStatus(res Result) { _ = os.MkdirAll(stateDir, 0o755) b, _ := json.MarshalIndent(res, "", " ") _ = os.WriteFile(statusFile, b, 0o644) } func ReadStatus() Result { b, err := os.ReadFile(statusFile) if err != nil { return Result{} } var res Result _ = json.Unmarshal(b, &res) return res } func acquireLock(path string) error { _ = os.MkdirAll(stateDir, 0o755) f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644) if err != nil { return fmt.Errorf("update already in progress") } _ = f.Close() return nil } func releaseLock(path string) { _ = os.Remove(path) } func ensureCurrentSymlinks(currentVersion string) error { if _, err := os.Lstat(currentLink); err == nil { return nil } if err := os.MkdirAll(releasesDir, 0o755); err != nil { return err } currentRelease := filepath.Join(releasesDir, currentVersion) if err := os.MkdirAll(currentRelease, 0o755); err != nil { return err } currentBinary := filepath.Join(currentRelease, "zlh-agent") if _, err := os.Stat(currentBinary); errors.Is(err, os.ErrNotExist) { if err := copyFile(binaryPath, currentBinary); err != nil { return err } if err := os.Chmod(currentBinary, 0o755); err != nil { return err } } if err := os.Symlink(filepath.Join("releases", currentVersion), currentLink); err != nil { return err } if err := os.Remove(binaryPath); err != nil && !errors.Is(err, os.ErrNotExist) { return err } if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), binaryPath); err != nil { return err } return nil } func updateSymlinks(target string) error { if err := os.RemoveAll(previousLink); err != nil && !errors.Is(err, os.ErrNotExist) { return err } if _, err := os.Lstat(currentLink); err == nil { if err := os.Symlink("current", previousLink); err != nil { return err } } if err := os.RemoveAll(currentLink); err != nil && !errors.Is(err, os.ErrNotExist) { return err } return os.Symlink(filepath.Join("releases", target), currentLink) } func copyFile(src, dst string) error { in, err := os.Open(src) if err != nil { return err } defer in.Close() out, err := os.Create(dst) if err != nil { return err } defer out.Close() if _, err := io.Copy(out, in); err != nil { return err } return out.Close() } func pruneOldReleases(keep int) error { entries, err := os.ReadDir(releasesDir) if err != nil { if errors.Is(err, os.ErrNotExist) { return nil } return err } currentResolved, _ := filepath.EvalSymlinks(currentLink) previousResolved, _ := filepath.EvalSymlinks(previousLink) protected := map[string]struct{}{} if currentResolved != "" { protected[currentResolved] = struct{}{} } if previousResolved != "" { protected[previousResolved] = struct{}{} } type rel struct { name string path string } rels := make([]rel, 0, len(entries)) for _, e := range entries { if !e.IsDir() { continue } name := e.Name() if !isSemverLike(name) { continue } rels = append(rels, rel{name: name, path: filepath.Join(releasesDir, name)}) } sort.Slice(rels, func(i, j int) bool { return compareVersions(rels[i].name, rels[j].name) > 0 }) for idx, r := range rels { if idx < keep { continue } if _, ok := protected[r.path]; ok { continue } if err := os.RemoveAll(r.path); err != nil { log.Printf("[update] prune failed for %s: %v", r.path, err) } } return nil } func isSemverLike(v string) bool { parts := strings.Split(v, ".") if len(parts) != 3 { return false } for _, p := range parts { if p == "" { return false } for _, r := range p { if r < '0' || r > '9' { return false } } } return true }