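// Package update implements the zlh-agent self-updater: it reads a release
// manifest, downloads and SHA-256-verifies the new binary, switches the
// release symlinks and restarts the service.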
package update
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"runtime"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Fixed locations and defaults used by the updater.
const (
	// artifactBaseURL is the prefix for manifest.json and for every artifact
	// path named in the manifest.
	// NOTE(review): plain HTTP to a fixed address; the checksum file is fetched
	// from the same server as the binary — confirm this is acceptable.
	artifactBaseURL = "http://10.60.0.251:8080/agents/"

	// Installation layout: per-version directories live under releasesDir,
	// currentLink and previousLink are symlinks to release directories, and
	// binaryPath is a symlink to the binary inside currentLink.
	releasesDir  = "/opt/zlh-agent/releases"
	currentLink  = "/opt/zlh-agent/current"
	previousLink = "/opt/zlh-agent/previous"
	binaryPath   = "/opt/zlh-agent/zlh-agent"
	stateDir     = "/opt/zlh-agent/state"
	statusFile   = stateDir + "/update.json"

	// Defaults applied when the matching ZLH_AGENT_* environment variable is
	// unset or unusable.
	defaultUnit         = "zlh-agent"
	defaultMode         = "notify"
	defaultKeepReleases = 3 // current + 2 previous
)
|
|
|
|
// Manifest is the JSON document served as manifest.json under the artifact
// base URL. It describes the published agent releases.
type Manifest struct {
	// SchemaVersion identifies the manifest layout; the updater accepts only 1.
	SchemaVersion int `json:"schema_version"`

	// Latest is decoded from the manifest but not read by this package.
	Latest string `json:"latest"`

	// MinSupported is the lowest version the updater is allowed to install.
	MinSupported string `json:"min_supported"`

	// Channels maps a channel name to a version; the updater reads "stable".
	Channels map[string]string `json:"channels"`

	// Artifacts is keyed by version. Binary and SHA256 are paths that are
	// appended to the artifact base URL to form the download URLs.
	Artifacts map[string]struct {
		LinuxAMD64 struct {
			Binary string `json:"binary"`
			SHA256 string `json:"sha256"`
		} `json:"linux_amd64"`
	} `json:"artifacts"`

	// ReleasedAt is decoded from the manifest but not read by this package.
	ReleasedAt string `json:"released_at"`
}
|
|
|
|
// Result is the outcome of an update check or update attempt. It is returned
// to callers and also persisted as JSON in the status file.
type Result struct {
	// Status is one of "error", "noop", "available" or "updated".
	Status string `json:"status"`

	// Current is the normalized version of the running agent.
	Current string `json:"current,omitempty"`

	// Target is the normalized stable-channel version, once it is known.
	Target string `json:"target,omitempty"`

	// Error carries the failure description when Status is "error".
	Error string `json:"error,omitempty"`

	// CheckedAtUTC is the RFC 3339 UTC timestamp taken when the check started.
	CheckedAtUTC string `json:"checked_at_utc,omitempty"`
}
|
|
|
|
func CheckAvailable(currentVersion string) Result {
|
|
currentVersion = normalizeVersion(currentVersion)
|
|
result := Result{
|
|
Status: "error",
|
|
Current: currentVersion,
|
|
CheckedAtUTC: time.Now().UTC().Format(time.RFC3339),
|
|
}
|
|
|
|
manifest, err := fetchManifest()
|
|
if err != nil {
|
|
result.Error = err.Error()
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
if manifest.SchemaVersion != 1 {
|
|
result.Error = fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
target := normalizeVersion(manifest.Channels["stable"])
|
|
if target == "" {
|
|
result.Error = "missing stable channel version"
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
result.Target = target
|
|
|
|
if compareVersions(currentVersion, target) >= 0 {
|
|
result.Status = "noop"
|
|
result.Error = ""
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
result.Status = "available"
|
|
result.Error = ""
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
func StartPeriodic(currentVersion string) {
|
|
mode := strings.ToLower(strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_MODE")))
|
|
if mode == "" {
|
|
mode = defaultMode
|
|
}
|
|
|
|
switch mode {
|
|
case "off", "disabled":
|
|
log.Printf("[update] periodic checks disabled (mode=%s)", mode)
|
|
return
|
|
case "notify", "auto":
|
|
default:
|
|
log.Printf("[update] invalid mode %q, using %q", mode, defaultMode)
|
|
mode = defaultMode
|
|
}
|
|
|
|
interval := 30 * time.Minute
|
|
if v := strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_INTERVAL")); v != "" {
|
|
d, err := time.ParseDuration(v)
|
|
if err != nil {
|
|
log.Printf("[update] invalid ZLH_AGENT_UPDATE_INTERVAL=%q: %v (using %s)", v, err, interval)
|
|
} else if d < time.Minute {
|
|
log.Printf("[update] ZLH_AGENT_UPDATE_INTERVAL too small (%s), using %s", d, interval)
|
|
} else {
|
|
interval = d
|
|
}
|
|
}
|
|
|
|
log.Printf("[update] periodic checks enabled (mode=%s interval=%s)", mode, interval)
|
|
|
|
run := func() {
|
|
switch mode {
|
|
case "auto":
|
|
res := CheckAndUpdate(currentVersion)
|
|
switch res.Status {
|
|
case "updated":
|
|
log.Printf("[update] applied update current=%s target=%s", res.Current, res.Target)
|
|
case "noop":
|
|
log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target)
|
|
default:
|
|
log.Printf("[update] auto check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error)
|
|
}
|
|
case "notify":
|
|
res := CheckAvailable(currentVersion)
|
|
switch res.Status {
|
|
case "available":
|
|
log.Printf("[update] update available current=%s target=%s", res.Current, res.Target)
|
|
case "noop":
|
|
log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target)
|
|
default:
|
|
log.Printf("[update] notify check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error)
|
|
}
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
time.Sleep(10 * time.Second)
|
|
run()
|
|
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
run()
|
|
}
|
|
}()
|
|
}
|
|
|
|
func CheckAndUpdate(currentVersion string) Result {
|
|
currentVersion = normalizeVersion(currentVersion)
|
|
result := Result{
|
|
Status: "error",
|
|
Current: currentVersion,
|
|
CheckedAtUTC: time.Now().UTC().Format(time.RFC3339),
|
|
}
|
|
|
|
if runtime.GOOS != "linux" || runtime.GOARCH != "amd64" {
|
|
result.Error = fmt.Sprintf("unsupported platform: %s_%s", runtime.GOOS, runtime.GOARCH)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
lockPath := filepath.Join(stateDir, "update.lock")
|
|
if err := acquireLock(lockPath); err != nil {
|
|
result.Error = err.Error()
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
defer releaseLock(lockPath)
|
|
|
|
manifest, err := fetchManifest()
|
|
if err != nil {
|
|
result.Error = err.Error()
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
if manifest.SchemaVersion != 1 {
|
|
result.Error = fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
target := normalizeVersion(manifest.Channels["stable"])
|
|
if target == "" {
|
|
result.Error = "missing stable channel version"
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
result.Target = target
|
|
|
|
if compareVersions(currentVersion, target) == 0 {
|
|
result.Status = "noop"
|
|
result.Error = ""
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if compareVersions(target, normalizeVersion(manifest.MinSupported)) < 0 {
|
|
result.Error = "target below min_supported"
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
artifact, ok := manifest.Artifacts[target]
|
|
if !ok {
|
|
result.Error = "missing artifacts for target version"
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
binPath := artifact.LinuxAMD64.Binary
|
|
shaPath := artifact.LinuxAMD64.SHA256
|
|
if binPath == "" || shaPath == "" {
|
|
result.Error = "artifact paths missing"
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := ensureCurrentSymlinks(currentVersion); err != nil {
|
|
result.Error = fmt.Sprintf("prepare current symlink: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := os.MkdirAll(filepath.Join(releasesDir, target), 0o755); err != nil {
|
|
result.Error = err.Error()
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
tmpPath := filepath.Join(releasesDir, target, "zlh-agent.new")
|
|
finalPath := filepath.Join(releasesDir, target, "zlh-agent")
|
|
|
|
binURL := artifactBaseURL + binPath
|
|
shaURL := artifactBaseURL + shaPath
|
|
|
|
if err := downloadFile(binURL, tmpPath); err != nil {
|
|
result.Error = fmt.Sprintf("download binary: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
expected, err := downloadSHA256(shaURL)
|
|
if err != nil {
|
|
result.Error = fmt.Sprintf("download sha256: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := verifySHA256(tmpPath, expected); err != nil {
|
|
result.Error = fmt.Sprintf("sha256 verify failed: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := os.Chmod(tmpPath, 0o755); err != nil {
|
|
result.Error = fmt.Sprintf("chmod: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := os.Rename(tmpPath, finalPath); err != nil {
|
|
result.Error = fmt.Sprintf("install: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := updateSymlinks(target); err != nil {
|
|
result.Error = fmt.Sprintf("update symlinks: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
if err := os.Remove(binaryPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
result.Error = fmt.Sprintf("remove binary path: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), binaryPath); err != nil {
|
|
result.Error = fmt.Sprintf("update current symlink: %v", err)
|
|
writeStatus(result)
|
|
return result
|
|
}
|
|
|
|
keep := defaultKeepReleases
|
|
if v := strings.TrimSpace(os.Getenv("ZLH_AGENT_KEEP_RELEASES")); v != "" {
|
|
if n, err := strconv.Atoi(v); err == nil && n >= 2 {
|
|
keep = n
|
|
}
|
|
}
|
|
if err := pruneOldReleases(keep); err != nil {
|
|
log.Printf("[update] prune warning: %v", err)
|
|
}
|
|
|
|
result.Status = "updated"
|
|
result.Error = ""
|
|
writeStatus(result)
|
|
|
|
if err := scheduleRollbackGuard(target); err != nil {
|
|
log.Printf("[update] rollback guard warning: %v", err)
|
|
}
|
|
|
|
go func() {
|
|
if err := restartService(); err != nil {
|
|
log.Printf("[update] restart failed: %v", err)
|
|
}
|
|
}()
|
|
return result
|
|
}
|
|
|
|
func fetchManifest() (Manifest, error) {
|
|
url := artifactBaseURL + "manifest.json"
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
resp, err := client.Get(url)
|
|
if err != nil {
|
|
return Manifest{}, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return Manifest{}, fmt.Errorf("manifest status %d", resp.StatusCode)
|
|
}
|
|
var m Manifest
|
|
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
|
return Manifest{}, err
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// downloadFile fetches url with a 30-second client timeout and stores the
// response body at dest, creating or truncating the file.
//
// It returns an error when the request fails, the status is not 200, or the
// body cannot be written completely. On a write or close failure the partial
// file is removed so callers never find a truncated download at dest.
func downloadFile(url, dest string) error {
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status %d", resp.StatusCode)
	}

	f, err := os.Create(dest)
	if err != nil {
		return err
	}
	if _, err := io.Copy(f, resp.Body); err != nil {
		_ = f.Close()
		_ = os.Remove(dest)
		return err
	}
	// Close can report a deferred write error, so it must be checked before
	// the file is treated as complete.
	if err := f.Close(); err != nil {
		_ = os.Remove(dest)
		return err
	}
	return nil
}
|
|
|
|
// downloadSHA256 fetches a checksum file from url (10-second client timeout)
// and returns its first whitespace-separated field, so both a bare digest and
// "digest  filename" layouts are accepted.
//
// It returns an error when the request fails, the status is not 200, the file
// is empty, or the first field is not a hex-encoded SHA-256 digest.
func downloadSHA256(url string) (string, error) {
	// The digest is the first field of the file; cap the read so a
	// misbehaving server cannot make the agent buffer an arbitrarily large
	// body.
	const maxChecksumFile = 4096

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("status %d", resp.StatusCode)
	}

	b, err := io.ReadAll(io.LimitReader(resp.Body, maxChecksumFile))
	if err != nil {
		return "", err
	}
	fields := strings.Fields(string(b))
	if len(fields) == 0 {
		return "", errors.New("empty sha256 file")
	}

	// Anything that is not 32 hex-encoded bytes can never match a computed
	// SHA-256, so reject it here with a precise message.
	digest := fields[0]
	if raw, err := hex.DecodeString(digest); err != nil || len(raw) != sha256.Size {
		return "", fmt.Errorf("malformed sha256 digest %q", digest)
	}
	return digest, nil
}
|
|
|
|
// verifySHA256 hashes the file at path and compares the hex digest with
// expected, ignoring letter case. It returns nil on a match, the underlying
// error when the file cannot be opened or read, and a "checksum mismatch"
// error otherwise.
func verifySHA256(path, expected string) error {
	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()

	hasher := sha256.New()
	if _, copyErr := io.Copy(hasher, file); copyErr != nil {
		return copyErr
	}

	if actual := hex.EncodeToString(hasher.Sum(nil)); strings.EqualFold(actual, expected) {
		return nil
	}
	return fmt.Errorf("checksum mismatch")
}
|
|
|
|
func restartService() error {
|
|
unit := os.Getenv("ZLH_AGENT_UNIT")
|
|
if unit == "" {
|
|
unit = defaultUnit
|
|
}
|
|
cmd := exec.Command("systemctl", "restart", "--no-block", unit)
|
|
return cmd.Run()
|
|
}
|
|
|
|
func scheduleRollbackGuard(target string) error {
|
|
unit := os.Getenv("ZLH_AGENT_UNIT")
|
|
if unit == "" {
|
|
unit = defaultUnit
|
|
}
|
|
|
|
port := strings.TrimSpace(os.Getenv("ZLH_AGENT_PORT"))
|
|
if port == "" {
|
|
port = "18888"
|
|
}
|
|
target = normalizeVersion(target)
|
|
if target == "" {
|
|
return nil
|
|
}
|
|
|
|
script := fmt.Sprintf(
|
|
"sleep 25; "+
|
|
"if ! curl -fsS http://127.0.0.1:%s/health >/dev/null 2>&1 || "+
|
|
"! curl -fsS http://127.0.0.1:%s/version 2>/dev/null | grep -q '\"version\":\"v%s\"'; then "+
|
|
"prev=$(readlink -f %s || true); "+
|
|
"if [ -n \"$prev\" ] && [ -d \"$prev\" ]; then "+
|
|
"b=$(basename \"$prev\"); "+
|
|
"ln -sfn \"releases/$b\" %s; "+
|
|
"ln -sfn %s/zlh-agent %s; "+
|
|
"systemctl restart --no-block %s; "+
|
|
"fi; "+
|
|
"fi",
|
|
port, port, target, previousLink, currentLink, currentLink, binaryPath, unit,
|
|
)
|
|
|
|
transientUnit := fmt.Sprintf("zlh-agent-update-verify-%d", time.Now().UnixNano())
|
|
cmd := exec.Command("systemd-run", "--unit", transientUnit, "--collect", "/bin/sh", "-c", script)
|
|
if err := cmd.Run(); err != nil {
|
|
return fmt.Errorf("schedule rollback guard: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// normalizeVersion strips surrounding whitespace and a single leading
// lowercase "v" from v, so "v1.2.3" and " 1.2.3 " both become "1.2.3".
func normalizeVersion(v string) string {
	trimmed := strings.TrimSpace(v)
	if strings.HasPrefix(trimmed, "v") {
		return trimmed[1:]
	}
	return trimmed
}
|
|
|
|
func compareVersions(a, b string) int {
|
|
as := strings.Split(a, ".")
|
|
bs := strings.Split(b, ".")
|
|
for len(as) < 3 {
|
|
as = append(as, "0")
|
|
}
|
|
for len(bs) < 3 {
|
|
bs = append(bs, "0")
|
|
}
|
|
for i := 0; i < 3; i++ {
|
|
ai := parseInt(as[i])
|
|
bi := parseInt(bs[i])
|
|
if ai < bi {
|
|
return -1
|
|
}
|
|
if ai > bi {
|
|
return 1
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// parseInt returns the value of the leading run of ASCII digits in s. Parsing
// stops at the first character that is not a digit, and a string without
// leading digits yields 0.
func parseInt(s string) int {
	total := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c < '0' || c > '9' {
			return total
		}
		total = total*10 + int(c-'0')
	}
	return total
}
|
|
|
|
func writeStatus(res Result) {
|
|
_ = os.MkdirAll(stateDir, 0o755)
|
|
b, _ := json.MarshalIndent(res, "", " ")
|
|
_ = os.WriteFile(statusFile, b, 0o644)
|
|
}
|
|
|
|
func ReadStatus() Result {
|
|
b, err := os.ReadFile(statusFile)
|
|
if err != nil {
|
|
return Result{}
|
|
}
|
|
var res Result
|
|
_ = json.Unmarshal(b, &res)
|
|
return res
|
|
}
|
|
|
|
// acquireLock takes the update lock by exclusively creating the file at path,
// creating its parent directory first when needed.
//
// It returns "update already in progress" only when the lock file already
// exists. Any other failure (permissions, unusable parent directory, ...) is
// reported with its real cause instead of being mistaken for a running update.
// The lock is released by removing the file (see releaseLock).
func acquireLock(path string) error {
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return fmt.Errorf("create lock directory: %w", err)
	}
	f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644)
	if errors.Is(err, os.ErrExist) {
		return errors.New("update already in progress")
	}
	if err != nil {
		return fmt.Errorf("acquire update lock: %w", err)
	}
	// Only the file's existence matters; nothing is written to it.
	_ = f.Close()
	return nil
}
|
|
|
|
// releaseLock removes the lock file at path. A missing file is fine; any other
// failure is logged, because a lock file that stays behind makes every later
// acquireLock report an update in progress.
func releaseLock(path string) {
	if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
		log.Printf("[update] release lock %s: %v", path, err)
	}
}
|
|
|
|
func ensureCurrentSymlinks(currentVersion string) error {
|
|
if _, err := os.Lstat(currentLink); err == nil {
|
|
return nil
|
|
}
|
|
|
|
if err := os.MkdirAll(releasesDir, 0o755); err != nil {
|
|
return err
|
|
}
|
|
|
|
currentRelease := filepath.Join(releasesDir, currentVersion)
|
|
if err := os.MkdirAll(currentRelease, 0o755); err != nil {
|
|
return err
|
|
}
|
|
|
|
currentBinary := filepath.Join(currentRelease, "zlh-agent")
|
|
if _, err := os.Stat(currentBinary); errors.Is(err, os.ErrNotExist) {
|
|
if err := copyFile(binaryPath, currentBinary); err != nil {
|
|
return err
|
|
}
|
|
if err := os.Chmod(currentBinary, 0o755); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := os.Symlink(filepath.Join("releases", currentVersion), currentLink); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := os.Remove(binaryPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return err
|
|
}
|
|
if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), binaryPath); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func updateSymlinks(target string) error {
|
|
if err := os.RemoveAll(previousLink); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return err
|
|
}
|
|
if _, err := os.Lstat(currentLink); err == nil {
|
|
if err := os.Symlink("current", previousLink); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := os.RemoveAll(currentLink); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return err
|
|
}
|
|
return os.Symlink(filepath.Join("releases", target), currentLink)
|
|
}
|
|
|
|
// copyFile copies the contents of src to dst, creating or truncating dst. The
// file mode of src is not carried over. The error from closing dst is
// returned, so a failure reported at close time is not lost.
func copyFile(src, dst string) error {
	source, err := os.Open(src)
	if err != nil {
		return err
	}
	defer source.Close()

	sink, err := os.Create(dst)
	if err != nil {
		return err
	}
	// Safety net for the early return below; the explicit Close at the end is
	// the one whose error matters.
	defer sink.Close()

	_, err = io.Copy(sink, source)
	if err != nil {
		return err
	}
	return sink.Close()
}
|
|
|
|
func pruneOldReleases(keep int) error {
|
|
entries, err := os.ReadDir(releasesDir)
|
|
if err != nil {
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
currentResolved, _ := filepath.EvalSymlinks(currentLink)
|
|
previousResolved, _ := filepath.EvalSymlinks(previousLink)
|
|
protected := map[string]struct{}{}
|
|
if currentResolved != "" {
|
|
protected[currentResolved] = struct{}{}
|
|
}
|
|
if previousResolved != "" {
|
|
protected[previousResolved] = struct{}{}
|
|
}
|
|
|
|
type rel struct {
|
|
name string
|
|
path string
|
|
}
|
|
rels := make([]rel, 0, len(entries))
|
|
for _, e := range entries {
|
|
if !e.IsDir() {
|
|
continue
|
|
}
|
|
name := e.Name()
|
|
if !isSemverLike(name) {
|
|
continue
|
|
}
|
|
rels = append(rels, rel{name: name, path: filepath.Join(releasesDir, name)})
|
|
}
|
|
|
|
sort.Slice(rels, func(i, j int) bool {
|
|
return compareVersions(rels[i].name, rels[j].name) > 0
|
|
})
|
|
|
|
for idx, r := range rels {
|
|
if idx < keep {
|
|
continue
|
|
}
|
|
if _, ok := protected[r.path]; ok {
|
|
continue
|
|
}
|
|
if err := os.RemoveAll(r.path); err != nil {
|
|
log.Printf("[update] prune failed for %s: %v", r.path, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isSemverLike reports whether v consists of exactly three dot-separated,
// non-empty runs of ASCII digits, such as "1.2.3". Prefixes ("v1.2.3") and
// suffixes ("1.2.3-rc1") are not accepted.
func isSemverLike(v string) bool {
	notDigit := func(r rune) bool { return r < '0' || r > '9' }

	segments := strings.Split(v, ".")
	if len(segments) != 3 {
		return false
	}
	for _, segment := range segments {
		if segment == "" || strings.IndexFunc(segment, notDigit) >= 0 {
			return false
		}
	}
	return true
}
|