zlh-agent/internal/update/update.go
2026-03-07 20:59:27 +00:00

658 lines
15 KiB
Go

package update
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"time"
)
// Update server location.
const (
	// artifactBaseURL prefixes manifest.json and every manifest-relative
	// artifact path.
	// NOTE(review): plain HTTP, and the .sha256 file comes from the same
	// origin, so the checksum only detects corruption, not tampering.
	artifactBaseURL = "http://10.60.0.251:8080/agents/"
)

// On-disk layout of an installed agent.
const (
	releasesDir  = "/opt/zlh-agent/releases" // one sub-directory per installed version
	currentLink  = "/opt/zlh-agent/current"  // symlink to the active release
	previousLink = "/opt/zlh-agent/previous" // symlink used by the rollback guard
	binaryPath   = "/opt/zlh-agent/zlh-agent"
	stateDir     = "/opt/zlh-agent/state"
	statusFile   = stateDir + "/update.json"
)

// Defaults for settings that can be overridden through the environment
// (ZLH_AGENT_UNIT, ZLH_AGENT_UPDATE_MODE, ZLH_AGENT_KEEP_RELEASES).
const (
	defaultUnit         = "zlh-agent"
	defaultMode         = "notify"
	defaultKeepReleases = 3 // current + 2 previous
)
// Manifest is the decoded form of manifest.json served from artifactBaseURL.
// The update checks in this package only accept SchemaVersion 1.
type Manifest struct {
	// SchemaVersion must be 1; any other value is rejected as unsupported.
	SchemaVersion int `json:"schema_version"`
	// Latest is decoded but not read by the code in this file; the update
	// target comes from Channels["stable"].
	Latest string `json:"latest"`
	// MinSupported is the lowest target version CheckAndUpdate will install.
	MinSupported string `json:"min_supported"`
	// Channels maps a channel name to a version; only "stable" is read here.
	Channels map[string]string `json:"channels"`
	// Artifacts is looked up with the normalized target version (leading "v"
	// stripped). NOTE(review): confirm the manifest's keys omit the "v".
	Artifacts map[string]struct {
		LinuxAMD64 struct {
			// Binary is the artifact path relative to artifactBaseURL.
			Binary string `json:"binary"`
			// SHA256 is the path (relative to artifactBaseURL) of a checksum
			// file, not the checksum itself.
			SHA256 string `json:"sha256"`
		} `json:"linux_amd64"`
	} `json:"artifacts"`
	// ReleasedAt is decoded but not read by the code in this file.
	ReleasedAt string `json:"released_at"`
}
// Result describes the outcome of CheckAvailable or CheckAndUpdate. It is
// also the JSON document persisted to statusFile and returned by ReadStatus.
type Result struct {
	// Status is one of "error", "noop", "available" (CheckAvailable only) or
	// "updated" (CheckAndUpdate only).
	Status string `json:"status"`
	// Current is the running version with any leading "v" stripped.
	Current string `json:"current,omitempty"`
	// Target is the stable-channel version from the manifest, once known.
	Target string `json:"target,omitempty"`
	// Error holds the failure reason when Status is "error".
	Error string `json:"error,omitempty"`
	// CheckedAtUTC is when the check started, formatted as RFC 3339 in UTC.
	CheckedAtUTC string `json:"checked_at_utc,omitempty"`
}
// CheckAvailable fetches the release manifest and reports, without installing
// anything, whether the stable channel offers a newer version than
// currentVersion. Every outcome is also persisted with writeStatus.
//
// Status is "available" when stable is newer, "noop" when currentVersion is
// the same or newer, and "error" (with Error populated) otherwise.
func CheckAvailable(currentVersion string) Result {
	res := Result{
		Status:       "error",
		Current:      normalizeVersion(currentVersion),
		CheckedAtUTC: time.Now().UTC().Format(time.RFC3339),
	}
	// finish records the current state of res and hands it back.
	finish := func() Result {
		writeStatus(res)
		return res
	}

	m, err := fetchManifest()
	switch {
	case err != nil:
		res.Error = err.Error()
		return finish()
	case m.SchemaVersion != 1:
		res.Error = fmt.Sprintf("unsupported manifest schema: %d", m.SchemaVersion)
		return finish()
	}

	stable := normalizeVersion(m.Channels["stable"])
	if stable == "" {
		res.Error = "missing stable channel version"
		return finish()
	}
	res.Target = stable

	res.Status = "available"
	if compareVersions(res.Current, stable) >= 0 {
		res.Status = "noop"
	}
	res.Error = ""
	return finish()
}
// StartPeriodic starts a background goroutine that checks for agent updates
// 10 seconds after the call and then once per interval, for the lifetime of
// the process.
//
// ZLH_AGENT_UPDATE_MODE selects the behavior:
//   - "notify" (the default) reports availability via CheckAvailable;
//   - "auto" installs via CheckAndUpdate;
//   - "off"/"disabled" starts nothing.
//
// Unrecognized modes fall back to defaultMode. ZLH_AGENT_UPDATE_INTERVAL is a
// time.ParseDuration string; values that fail to parse or are below one
// minute are ignored in favor of 30 minutes.
func StartPeriodic(currentVersion string) {
	mode := strings.ToLower(strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_MODE")))
	if mode == "" {
		mode = defaultMode
	}
	if mode == "off" || mode == "disabled" {
		log.Printf("[update] periodic checks disabled (mode=%s)", mode)
		return
	}
	if mode != "notify" && mode != "auto" {
		log.Printf("[update] invalid mode %q, using %q", mode, defaultMode)
		mode = defaultMode
	}

	interval := 30 * time.Minute
	if raw := strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_INTERVAL")); raw != "" {
		switch d, err := time.ParseDuration(raw); {
		case err != nil:
			log.Printf("[update] invalid ZLH_AGENT_UPDATE_INTERVAL=%q: %v (using %s)", raw, err, interval)
		case d < time.Minute:
			log.Printf("[update] ZLH_AGENT_UPDATE_INTERVAL too small (%s), using %s", d, interval)
		default:
			interval = d
		}
	}
	log.Printf("[update] periodic checks enabled (mode=%s interval=%s)", mode, interval)

	// check performs one update cycle; mode is "auto" or "notify" here.
	check := func() {
		if mode == "auto" {
			r := CheckAndUpdate(currentVersion)
			switch r.Status {
			case "updated":
				log.Printf("[update] applied update current=%s target=%s", r.Current, r.Target)
			case "noop":
				log.Printf("[update] no update available current=%s target=%s", r.Current, r.Target)
			default:
				log.Printf("[update] auto check failed status=%s current=%s target=%s err=%s", r.Status, r.Current, r.Target, r.Error)
			}
			return
		}
		r := CheckAvailable(currentVersion)
		switch r.Status {
		case "available":
			log.Printf("[update] update available current=%s target=%s", r.Current, r.Target)
		case "noop":
			log.Printf("[update] no update available current=%s target=%s", r.Current, r.Target)
		default:
			log.Printf("[update] notify check failed status=%s current=%s target=%s err=%s", r.Status, r.Current, r.Target, r.Error)
		}
	}

	go func() {
		time.Sleep(10 * time.Second)
		check()
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for range ticker.C {
			check()
		}
	}()
}
// CheckAndUpdate installs the stable-channel release named by the manifest
// when it differs from currentVersion. The steps are:
//  1. download releases/<target>/zlh-agent and verify it against the
//     published SHA-256;
//  2. repoint the current/previous symlinks and binaryPath, then prune old
//     releases;
//  3. schedule a rollback guard and asynchronously ask systemd to restart
//     the unit.
//
// Only one update runs at a time (lock file in stateDir). Every outcome is
// persisted with writeStatus. Status is "updated" on success, "noop" when
// already on the target, and "error" (with Error populated) otherwise.
//
// NOTE(review): any target that differs from currentVersion is installed,
// including an older one, whereas CheckAvailable reports "noop" when the
// running version is newer. Confirm downgrades via the stable channel are
// intended.
func CheckAndUpdate(currentVersion string) Result {
	currentVersion = normalizeVersion(currentVersion)
	result := Result{
		Status:       "error",
		Current:      currentVersion,
		CheckedAtUTC: time.Now().UTC().Format(time.RFC3339),
	}
	// fail records msg as the failure reason, persists and returns the result.
	fail := func(msg string) Result {
		result.Error = msg
		writeStatus(result)
		return result
	}

	if runtime.GOOS != "linux" || runtime.GOARCH != "amd64" {
		return fail(fmt.Sprintf("unsupported platform: %s_%s", runtime.GOOS, runtime.GOARCH))
	}

	lockPath := filepath.Join(stateDir, "update.lock")
	if err := acquireLock(lockPath); err != nil {
		return fail(err.Error())
	}
	defer releaseLock(lockPath)

	manifest, err := fetchManifest()
	if err != nil {
		return fail(err.Error())
	}
	if manifest.SchemaVersion != 1 {
		return fail(fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion))
	}
	target := normalizeVersion(manifest.Channels["stable"])
	if target == "" {
		return fail("missing stable channel version")
	}
	result.Target = target

	if compareVersions(currentVersion, target) == 0 {
		result.Status = "noop"
		result.Error = ""
		writeStatus(result)
		return result
	}
	if compareVersions(target, normalizeVersion(manifest.MinSupported)) < 0 {
		return fail("target below min_supported")
	}

	artifact, ok := manifest.Artifacts[target]
	if !ok {
		return fail("missing artifacts for target version")
	}
	binPath := artifact.LinuxAMD64.Binary
	shaPath := artifact.LinuxAMD64.SHA256
	if binPath == "" || shaPath == "" {
		return fail("artifact paths missing")
	}

	if err := ensureCurrentSymlinks(currentVersion); err != nil {
		return fail(fmt.Sprintf("prepare current symlink: %v", err))
	}

	releaseDir := filepath.Join(releasesDir, target)
	if err := os.MkdirAll(releaseDir, 0o755); err != nil {
		return fail(err.Error())
	}
	tmpPath := filepath.Join(releaseDir, "zlh-agent.new")
	finalPath := filepath.Join(releaseDir, "zlh-agent")

	// Until it is renamed into place, tmpPath holds an unverified binary; it
	// must not be left behind when any step fails.
	discard := func(msg string) Result {
		_ = os.Remove(tmpPath)
		return fail(msg)
	}
	if err := downloadFile(artifactBaseURL+binPath, tmpPath); err != nil {
		return discard(fmt.Sprintf("download binary: %v", err))
	}
	expected, err := downloadSHA256(artifactBaseURL + shaPath)
	if err != nil {
		return discard(fmt.Sprintf("download sha256: %v", err))
	}
	if err := verifySHA256(tmpPath, expected); err != nil {
		return discard(fmt.Sprintf("sha256 verify failed: %v", err))
	}
	if err := os.Chmod(tmpPath, 0o755); err != nil {
		return discard(fmt.Sprintf("chmod: %v", err))
	}
	if err := os.Rename(tmpPath, finalPath); err != nil {
		return discard(fmt.Sprintf("install: %v", err))
	}

	if err := updateSymlinks(target); err != nil {
		return fail(fmt.Sprintf("update symlinks: %v", err))
	}

	// Repoint binaryPath through a temporary symlink plus rename, which
	// replaces it in one step on the same filesystem. Remove-then-create
	// would leave a window in which binaryPath does not exist.
	tmpLink := binaryPath + ".new"
	if err := os.Remove(tmpLink); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fail(fmt.Sprintf("update current symlink: %v", err))
	}
	if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), tmpLink); err != nil {
		return fail(fmt.Sprintf("update current symlink: %v", err))
	}
	if err := os.Rename(tmpLink, binaryPath); err != nil {
		_ = os.Remove(tmpLink)
		return fail(fmt.Sprintf("update current symlink: %v", err))
	}

	keep := defaultKeepReleases
	if v := strings.TrimSpace(os.Getenv("ZLH_AGENT_KEEP_RELEASES")); v != "" {
		// Values below 2 are ignored so current and previous always survive.
		if n, err := strconv.Atoi(v); err == nil && n >= 2 {
			keep = n
		}
	}
	if err := pruneOldReleases(keep); err != nil {
		log.Printf("[update] prune warning: %v", err)
	}

	result.Status = "updated"
	result.Error = ""
	writeStatus(result)

	if err := scheduleRollbackGuard(target); err != nil {
		log.Printf("[update] rollback guard warning: %v", err)
	}
	// Restart in the background; the deferred releaseLock runs as soon as
	// this function returns.
	go func() {
		if err := restartService(); err != nil {
			log.Printf("[update] restart failed: %v", err)
		}
	}()
	return result
}
// fetchManifest downloads and decodes manifest.json from artifactBaseURL with
// a 10-second timeout. A non-200 response or undecodable body is an error, in
// which case the zero Manifest is returned.
func fetchManifest() (Manifest, error) {
	client := http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(artifactBaseURL + "manifest.json")
	if err != nil {
		return Manifest{}, err
	}
	defer resp.Body.Close()

	if code := resp.StatusCode; code != http.StatusOK {
		return Manifest{}, fmt.Errorf("manifest status %d", code)
	}

	var decoded Manifest
	if err := json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return Manifest{}, err
	}
	return decoded, nil
}
// downloadFile streams the body of an HTTP GET of url into dest, creating or
// truncating it, with a 30-second overall timeout. A non-200 status is an
// error and dest is not created.
//
// The data is fsynced and the close error is checked before success is
// reported. This matters because callers rename the file into place, and a
// lost write must not surface later as a truncated binary. On any failure
// after dest is created, the partial file is removed.
func downloadFile(url, dest string) (err error) {
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status %d", resp.StatusCode)
	}

	f, err := os.Create(dest)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			_ = os.Remove(dest)
		}
	}()

	if _, err = io.Copy(f, resp.Body); err != nil {
		_ = f.Close()
		return err
	}
	if err = f.Sync(); err != nil {
		_ = f.Close()
		return err
	}
	return f.Close()
}
// downloadSHA256 fetches a checksum file from url (10-second timeout) and
// returns its first whitespace-separated field. Presumably the file is in
// sha256sum format, "<hex>  <name>" or just "<hex>"; the value is not
// validated as hex here. A non-200 status or a body with no fields is an
// error.
func downloadSHA256(url string) (string, error) {
	resp, err := (&http.Client{Timeout: 10 * time.Second}).Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	// Fields never yields whitespace, so no further trimming is required.
	if fields := strings.Fields(string(body)); len(fields) > 0 {
		return fields[0], nil
	}
	return "", errors.New("empty sha256 file")
}
// verifySHA256 hashes the file at path and compares the lowercase hex digest
// with expected, ignoring case. It returns the open/read error, or a
// "checksum mismatch" error when the digests differ.
func verifySHA256(path, expected string) error {
	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()

	digest := sha256.New()
	if _, err = io.Copy(digest, file); err != nil {
		return err
	}
	if actual := hex.EncodeToString(digest.Sum(nil)); strings.EqualFold(actual, expected) {
		return nil
	}
	return fmt.Errorf("checksum mismatch")
}
// restartService asks systemd to restart the agent unit (ZLH_AGENT_UNIT, or
// defaultUnit when unset). --no-block queues the job without waiting for it.
func restartService() error {
	unit := defaultUnit
	if fromEnv := os.Getenv("ZLH_AGENT_UNIT"); fromEnv != "" {
		unit = fromEnv
	}
	return exec.Command("systemctl", "restart", "--no-block", unit).Run()
}
// scheduleRollbackGuard launches a transient systemd unit (systemd-run
// --collect) running a shell script that sleeps 25 seconds and then probes
// the local agent on 127.0.0.1:<ZLH_AGENT_PORT, default 18888>. The probe
// fails if /health returns an error, or if the /version body does not
// contain the literal `"version":"v<target>"`. On failure the script:
//   - resolves previousLink;
//   - repoints currentLink at releases/<basename of that release>;
//   - repoints binaryPath at currentLink/zlh-agent;
//   - restarts the unit (ZLH_AGENT_UNIT, default defaultUnit).
//
// It returns nil without scheduling anything when target is empty after
// normalization, and an error only if systemd-run itself fails.
//
// The rollback is only meaningful when previousLink resolves to a different
// release than currentLink. The script does not update statusFile, so a
// rolled-back agent still reports the last "updated" status.
//
// NOTE(review): the grep assumes /version emits compact JSON with a "v"
// prefix — confirm against the handler. unit and port are interpolated into
// the shell script unquoted, so those environment values must be trusted.
func scheduleRollbackGuard(target string) error {
	unit := os.Getenv("ZLH_AGENT_UNIT")
	if unit == "" {
		unit = defaultUnit
	}
	port := strings.TrimSpace(os.Getenv("ZLH_AGENT_PORT"))
	if port == "" {
		port = "18888"
	}
	target = normalizeVersion(target)
	if target == "" {
		return nil
	}
	// The relative "releases/$b" target resolves against the directory that
	// contains currentLink.
	script := fmt.Sprintf(
		"sleep 25; "+
			"if ! curl -fsS http://127.0.0.1:%s/health >/dev/null 2>&1 || "+
			"! curl -fsS http://127.0.0.1:%s/version 2>/dev/null | grep -q '\"version\":\"v%s\"'; then "+
			"prev=$(readlink -f %s || true); "+
			"if [ -n \"$prev\" ] && [ -d \"$prev\" ]; then "+
			"b=$(basename \"$prev\"); "+
			"ln -sfn \"releases/$b\" %s; "+
			"ln -sfn %s/zlh-agent %s; "+
			"systemctl restart --no-block %s; "+
			"fi; "+
			"fi",
		port, port, target, previousLink, currentLink, currentLink, binaryPath, unit,
	)
	// A nanosecond suffix keeps successive guards from colliding on the unit name.
	transientUnit := fmt.Sprintf("zlh-agent-update-verify-%d", time.Now().UnixNano())
	cmd := exec.Command("systemd-run", "--unit", transientUnit, "--collect", "/bin/sh", "-c", script)
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("schedule rollback guard: %w", err)
	}
	return nil
}
// normalizeVersion trims surrounding whitespace and then at most one leading
// "v", so " v1.2.3 " and "1.2.3" normalize to the same string.
func normalizeVersion(v string) string {
	s := strings.TrimSpace(v)
	if strings.HasPrefix(s, "v") {
		s = s[1:]
	}
	return s
}
// compareVersions compares dotted versions a and b on their first three
// components and returns -1, 0 or 1. Missing components count as 0. Each
// component's value is its leading ASCII digits (so "3-rc1" is 3), and
// anything past the third component is ignored.
func compareVersions(a, b string) int {
	left, right := strings.Split(a, "."), strings.Split(b, ".")

	// component returns the numeric value of parts[i], or 0 when absent.
	component := func(parts []string, i int) int {
		if i >= len(parts) {
			return 0
		}
		n := 0
		for _, r := range parts[i] {
			if r < '0' || r > '9' {
				break
			}
			n = n*10 + int(r-'0')
		}
		return n
	}

	for i := 0; i < 3; i++ {
		x, y := component(left, i), component(right, i)
		switch {
		case x < y:
			return -1
		case x > y:
			return 1
		}
	}
	return 0
}
// parseInt returns the decimal value of the leading ASCII digits of s,
// stopping at the first non-digit byte, and 0 when s does not start with a
// digit. There is no overflow check.
func parseInt(s string) int {
	value := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c < '0' || c > '9' {
			return value
		}
		value = value*10 + int(c-'0')
	}
	return value
}
// writeStatus persists res as indented JSON to statusFile.
//
// It is best-effort: failures are logged, never returned, because recording
// status must not abort an update. The document is written to a temporary
// file in stateDir (same filesystem) and renamed over statusFile, so
// ReadStatus never sees a partially written file even while another check
// is writing.
func writeStatus(res Result) {
	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		log.Printf("[update] write status: %v", err)
		return
	}
	b, err := json.MarshalIndent(res, "", " ")
	if err != nil {
		log.Printf("[update] write status: %v", err)
		return
	}

	tmp, err := os.CreateTemp(stateDir, ".update-*.json")
	if err != nil {
		log.Printf("[update] write status: %v", err)
		return
	}
	err = func() error {
		if _, err := tmp.Write(b); err != nil {
			_ = tmp.Close()
			return err
		}
		if err := tmp.Close(); err != nil {
			return err
		}
		// CreateTemp uses 0600; keep the previous world-readable mode.
		if err := os.Chmod(tmp.Name(), 0o644); err != nil {
			return err
		}
		return os.Rename(tmp.Name(), statusFile)
	}()
	if err != nil {
		_ = os.Remove(tmp.Name())
		log.Printf("[update] write status: %v", err)
	}
}
// ReadStatus returns the Result last persisted by writeStatus. It returns the
// zero Result when statusFile cannot be read; JSON decode errors are ignored.
func ReadStatus() Result {
	var last Result
	if data, err := os.ReadFile(statusFile); err == nil {
		_ = json.Unmarshal(data, &last)
	}
	return last
}
// acquireLock takes the updater's lock by exclusively creating the file at
// path (O_CREATE|O_EXCL), creating stateDir first. While another holder's
// lock file exists it returns "update already in progress".
//
// A lock file last modified more than staleAfter ago is treated as left over
// from a process that died mid-update (its deferred releaseLock never ran).
// Such a file is removed and acquisition retried once; otherwise a single
// crash would block every future update. Filesystem errors other than
// "already exists" are returned with context rather than being reported as
// contention.
func acquireLock(path string) error {
	const staleAfter = 30 * time.Minute

	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		return fmt.Errorf("create state dir: %w", err)
	}
	for attempt := 0; attempt < 2; attempt++ {
		f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644)
		if err == nil {
			// Record the holder's PID to help diagnose leftover locks.
			_, _ = f.WriteString(strconv.Itoa(os.Getpid()) + "\n")
			if err := f.Close(); err != nil {
				_ = os.Remove(path)
				return fmt.Errorf("create lock: %w", err)
			}
			return nil
		}
		if !errors.Is(err, os.ErrExist) {
			return fmt.Errorf("create lock: %w", err)
		}

		info, statErr := os.Stat(path)
		if errors.Is(statErr, os.ErrNotExist) {
			// Released between our open and stat; simply retry.
			continue
		}
		if statErr != nil || time.Since(info.ModTime()) < staleAfter {
			return fmt.Errorf("update already in progress")
		}
		log.Printf("[update] removing stale lock %s (modified %s)", path, info.ModTime().UTC().Format(time.RFC3339))
		// NOTE: two callers breaking the same stale lock at once can race here;
		// accepted, since stale locks only follow a crash.
		if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
			return fmt.Errorf("remove stale lock: %w", err)
		}
	}
	return fmt.Errorf("update already in progress")
}

// releaseLock removes the lock file created by acquireLock. Removal is
// best-effort. A failure is logged, and a file left behind is eventually
// treated as stale by acquireLock.
func releaseLock(path string) {
	if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
		log.Printf("[update] release lock: %v", err)
	}
}
// ensureCurrentSymlinks migrates an install that has no currentLink yet into
// the releases/current layout. It does nothing if currentLink already exists.
// Otherwise it:
//   - copies the running binary at binaryPath into
//     releases/<currentVersion>/zlh-agent (unless that file already exists);
//   - points currentLink at that release;
//   - replaces binaryPath with a symlink to currentLink/zlh-agent.
//
// It returns an error when currentVersion is empty, or when a copy, chmod,
// symlink or rename fails.
func ensureCurrentSymlinks(currentVersion string) error {
	if _, err := os.Lstat(currentLink); err == nil {
		return nil
	}
	if currentVersion == "" {
		// filepath.Join(releasesDir, "") is releasesDir itself, so currentLink
		// would point at the releases directory rather than at a release.
		return errors.New("current version unknown; cannot adopt running binary into releases")
	}

	currentRelease := filepath.Join(releasesDir, currentVersion)
	if err := os.MkdirAll(currentRelease, 0o755); err != nil {
		return err
	}

	currentBinary := filepath.Join(currentRelease, "zlh-agent")
	if _, err := os.Stat(currentBinary); errors.Is(err, os.ErrNotExist) {
		// Copy under a temporary name and rename into place. Otherwise an
		// interrupted copy leaves a truncated binary that the existence check
		// above would accept on the next run.
		partial := currentBinary + ".partial"
		if err := copyFile(binaryPath, partial); err != nil {
			_ = os.Remove(partial)
			return err
		}
		if err := os.Chmod(partial, 0o755); err != nil {
			_ = os.Remove(partial)
			return err
		}
		if err := os.Rename(partial, currentBinary); err != nil {
			_ = os.Remove(partial)
			return err
		}
	}

	if err := os.Symlink(filepath.Join("releases", currentVersion), currentLink); err != nil {
		return err
	}

	// Swap binaryPath via a temporary symlink plus rename so the path never
	// disappears, as it would with remove-then-create.
	tmpLink := binaryPath + ".new"
	if err := os.Remove(tmpLink); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), tmpLink); err != nil {
		return err
	}
	if err := os.Rename(tmpLink, binaryPath); err != nil {
		_ = os.Remove(tmpLink)
		return err
	}
	return nil
}
func updateSymlinks(target string) error {
if err := os.RemoveAll(previousLink); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if _, err := os.Lstat(currentLink); err == nil {
if err := os.Symlink("current", previousLink); err != nil {
return err
}
}
if err := os.RemoveAll(currentLink); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
return os.Symlink(filepath.Join("releases", target), currentLink)
}
// copyFile copies the contents of src into dst, creating or truncating dst
// with os.Create's default mode. src's permissions are not copied, and a
// failed copy may leave a partial dst behind.
func copyFile(src, dst string) error {
	from, err := os.Open(src)
	if err != nil {
		return err
	}
	defer from.Close()

	to, err := os.Create(dst)
	if err != nil {
		return err
	}
	if _, err := io.Copy(to, from); err != nil {
		_ = to.Close()
		return err
	}
	return to.Close()
}
// pruneOldReleases deletes release directories under releasesDir beyond the
// newest keep versions, ordered by compareVersions.
//
// Only directories whose names are plain MAJOR.MINOR.PATCH (isSemverLike) are
// considered. Whatever currentLink and previousLink resolve to is never
// deleted. A missing releasesDir is not an error. Individual removal
// failures are logged and skipped; only a failure to list releasesDir is
// returned.
func pruneOldReleases(keep int) error {
	entries, err := os.ReadDir(releasesDir)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil
		}
		return err
	}

	protected := make(map[string]struct{}, 2)
	for _, link := range []string{currentLink, previousLink} {
		if resolved, err := filepath.EvalSymlinks(link); err == nil && resolved != "" {
			protected[resolved] = struct{}{}
		}
	}
	// isProtected resolves dir the same way as the links above. EvalSymlinks
	// also resolves symlinked ancestors (e.g. /opt), so comparing its output
	// against an unresolved path would miss and delete a protected release.
	isProtected := func(dir string) bool {
		resolved, err := filepath.EvalSymlinks(dir)
		if err != nil {
			resolved = dir
		}
		_, ok := protected[resolved]
		return ok
	}

	type rel struct {
		name string
		path string
	}
	rels := make([]rel, 0, len(entries))
	for _, e := range entries {
		if !e.IsDir() || !isSemverLike(e.Name()) {
			continue
		}
		rels = append(rels, rel{name: e.Name(), path: filepath.Join(releasesDir, e.Name())})
	}
	// Newest first, so the first keep entries are the ones retained.
	sort.Slice(rels, func(i, j int) bool {
		return compareVersions(rels[i].name, rels[j].name) > 0
	})

	for idx, r := range rels {
		if idx < keep || isProtected(r.path) {
			continue
		}
		if err := os.RemoveAll(r.path); err != nil {
			log.Printf("[update] prune failed for %s: %v", r.path, err)
		}
	}
	return nil
}
// isSemverLike reports whether v is exactly three non-empty, dot-separated
// runs of ASCII digits, such as "1.20.3". It rejects a "v" prefix and any
// pre-release or build suffix.
func isSemverLike(v string) bool {
	parts := strings.Split(v, ".")
	if len(parts) != 3 {
		return false
	}
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	for _, p := range parts {
		if p == "" || strings.IndexFunc(p, notDigit) >= 0 {
			return false
		}
	}
	return true
}