zlh-agent/internal/update/update.go
2026-02-21 21:54:48 +00:00

531 lines
12 KiB
Go

package update
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"time"
)
const (
	// artifactBaseURL is the HTTP root that serves manifest.json and every
	// artifact path listed in it.
	artifactBaseURL = "http://10.60.0.251:8080/agents/"

	// installRoot is the agent's installation directory; the on-disk paths
	// below are all derived from it.
	installRoot  = "/opt/zlh-agent"
	releasesDir  = installRoot + "/releases" // one subdirectory per installed version
	currentLink  = installRoot + "/current"  // symlink to the active release directory
	previousLink = installRoot + "/previous" // rollback symlink maintained by updateSymlinks
	binaryPath   = installRoot + "/zlh-agent"
	stateDir     = installRoot + "/state"
	statusFile   = stateDir + "/update.json" // last Result, written by writeStatus

	defaultUnit = "zlh-agent" // systemd unit restarted after an update
	defaultMode = "notify"    // used when ZLH_AGENT_UPDATE_MODE is unset or invalid
)
// Manifest is the decoded form of manifest.json, fetched from artifactBaseURL
// by fetchManifest. The callers in this file only accept SchemaVersion 1.
type Manifest struct {
	SchemaVersion int `json:"schema_version"`
	// Latest is decoded but not read by any code visible in this file.
	Latest string `json:"latest"`
	// MinSupported is the lowest version CheckAndUpdate will install.
	MinSupported string `json:"min_supported"`
	// Channels maps a channel name to a version; only "stable" is read here.
	Channels map[string]string `json:"channels"`
	// Artifacts holds per-platform artifact paths relative to
	// artifactBaseURL. CheckAndUpdate looks entries up by the normalized
	// (no leading "v") stable version.
	Artifacts map[string]struct {
		LinuxAMD64 struct {
			Binary string `json:"binary"` // path of the agent executable
			SHA256 string `json:"sha256"` // path of the checksum file for Binary
		} `json:"linux_amd64"`
	} `json:"artifacts"`
	// ReleasedAt is decoded but not read by any code visible in this file.
	ReleasedAt string `json:"released_at"`
}
// Result describes the outcome of an update check. It is returned to callers
// and persisted as JSON at statusFile by writeStatus.
type Result struct {
	// Status is "error", "noop", "available" (CheckAvailable only) or
	// "updated" (CheckAndUpdate only).
	Status string `json:"status"`
	// Current is the normalized version the check was run against.
	Current string `json:"current,omitempty"`
	// Target is the normalized stable-channel version, once it is known.
	Target string `json:"target,omitempty"`
	// Error explains the failure when Status is "error".
	Error string `json:"error,omitempty"`
	// CheckedAtUTC is the RFC 3339 UTC time at which the check started.
	CheckedAtUTC string `json:"checked_at_utc,omitempty"`
}
// CheckAvailable fetches the release manifest and reports whether the
// "stable" channel is newer than currentVersion, without downloading or
// installing anything. Unlike CheckAndUpdate it does not check min_supported
// or whether artifacts exist for the target.
//
// The returned Result has Status "available" when stable is newer, "noop"
// when currentVersion is equal or newer, and "error" (with Error set) when the
// manifest cannot be fetched, has an unsupported schema, or lacks a stable
// version. Every outcome is also persisted with writeStatus.
func CheckAvailable(currentVersion string) Result {
	current := normalizeVersion(currentVersion)
	res := Result{
		Status:       "error",
		Current:      current,
		CheckedAtUTC: time.Now().UTC().Format(time.RFC3339),
	}

	// finish stamps the outcome, persists it, and hands it back.
	finish := func(status, errMsg string) Result {
		res.Status = status
		res.Error = errMsg
		writeStatus(res)
		return res
	}

	m, err := fetchManifest()
	if err != nil {
		return finish("error", err.Error())
	}
	if m.SchemaVersion != 1 {
		return finish("error", fmt.Sprintf("unsupported manifest schema: %d", m.SchemaVersion))
	}

	stable := normalizeVersion(m.Channels["stable"])
	if stable == "" {
		return finish("error", "missing stable channel version")
	}
	res.Target = stable

	if compareVersions(current, stable) < 0 {
		return finish("available", "")
	}
	return finish("noop", "")
}
// StartPeriodic launches a background goroutine that runs an update check 10
// seconds after the call and then once per interval for the life of the
// process (there is no way to stop it).
//
// ZLH_AGENT_UPDATE_MODE selects the behaviour: "notify" runs CheckAvailable
// and only logs, "auto" runs CheckAndUpdate, and "off"/"disabled" starts no
// goroutine at all. An empty value means defaultMode; any other value is
// logged and replaced by defaultMode.
//
// ZLH_AGENT_UPDATE_INTERVAL is a time.ParseDuration string (default 30m).
// Unparseable values or values below one minute are logged and ignored.
func StartPeriodic(currentVersion string) {
	mode := strings.ToLower(strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_MODE")))
	if mode == "" {
		mode = defaultMode
	}
	if mode == "off" || mode == "disabled" {
		log.Printf("[update] periodic checks disabled (mode=%s)", mode)
		return
	}
	if mode != "notify" && mode != "auto" {
		log.Printf("[update] invalid mode %q, using %q", mode, defaultMode)
		mode = defaultMode
	}

	interval := 30 * time.Minute
	if raw := strings.TrimSpace(os.Getenv("ZLH_AGENT_UPDATE_INTERVAL")); raw != "" {
		switch parsed, err := time.ParseDuration(raw); {
		case err != nil:
			log.Printf("[update] invalid ZLH_AGENT_UPDATE_INTERVAL=%q: %v (using %s)", raw, err, interval)
		case parsed < time.Minute:
			log.Printf("[update] ZLH_AGENT_UPDATE_INTERVAL too small (%s), using %s", parsed, interval)
		default:
			interval = parsed
		}
	}
	log.Printf("[update] periodic checks enabled (mode=%s interval=%s)", mode, interval)

	// checkOnce performs a single check in the selected mode and logs the result.
	checkOnce := func() {
		if mode == "auto" {
			res := CheckAndUpdate(currentVersion)
			switch res.Status {
			case "updated":
				log.Printf("[update] applied update current=%s target=%s", res.Current, res.Target)
			case "noop":
				log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target)
			default:
				log.Printf("[update] auto check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error)
			}
			return
		}
		if mode == "notify" {
			res := CheckAvailable(currentVersion)
			switch res.Status {
			case "available":
				log.Printf("[update] update available current=%s target=%s", res.Current, res.Target)
			case "noop":
				log.Printf("[update] no update available current=%s target=%s", res.Current, res.Target)
			default:
				log.Printf("[update] notify check failed status=%s current=%s target=%s err=%s", res.Status, res.Current, res.Target, res.Error)
			}
		}
	}

	go func() {
		// Give the agent time to finish starting before the first check.
		time.Sleep(10 * time.Second)
		checkOnce()
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			<-ticker.C
			checkOnce()
		}
	}()
}
// CheckAndUpdate fetches the release manifest and, when the "stable" channel
// names a version different from currentVersion, downloads, verifies and
// installs it under releasesDir. It then repoints the current/previous
// symlinks and binaryPath, and schedules an asynchronous restart of the
// systemd unit.
//
// Any difference triggers an install, so a stable channel below
// currentVersion performs a downgrade, provided the target is not below
// min_supported.
//
// Only linux/amd64 is supported. Concurrent runs are excluded by a lock file
// in stateDir. The outcome is always persisted with writeStatus and returned:
//   - "noop" when already on the target,
//   - "updated" once the new binary is in place (restart failures are only
//     logged),
//   - otherwise "error" with Error set.
func CheckAndUpdate(currentVersion string) Result {
	currentVersion = normalizeVersion(currentVersion)
	result := Result{
		Status:       "error",
		Current:      currentVersion,
		CheckedAtUTC: time.Now().UTC().Format(time.RFC3339),
	}
	// fail records an error outcome. Every early return goes through it, so
	// the persisted status always matches what the caller receives.
	fail := func(msg string) Result {
		result.Error = msg
		writeStatus(result)
		return result
	}

	if runtime.GOOS != "linux" || runtime.GOARCH != "amd64" {
		return fail(fmt.Sprintf("unsupported platform: %s_%s", runtime.GOOS, runtime.GOARCH))
	}

	lockPath := filepath.Join(stateDir, "update.lock")
	if err := acquireLock(lockPath); err != nil {
		return fail(err.Error())
	}
	defer releaseLock(lockPath)

	manifest, err := fetchManifest()
	if err != nil {
		return fail(err.Error())
	}
	if manifest.SchemaVersion != 1 {
		return fail(fmt.Sprintf("unsupported manifest schema: %d", manifest.SchemaVersion))
	}
	target := normalizeVersion(manifest.Channels["stable"])
	if target == "" {
		return fail("missing stable channel version")
	}
	result.Target = target
	// target becomes a directory name under releasesDir. Refuse anything
	// (e.g. "../x") that would place files outside it.
	if target != filepath.Base(target) || target == "." || target == ".." {
		return fail(fmt.Sprintf("invalid target version %q", target))
	}

	if compareVersions(currentVersion, target) == 0 {
		result.Status = "noop"
		result.Error = ""
		writeStatus(result)
		return result
	}
	if compareVersions(target, normalizeVersion(manifest.MinSupported)) < 0 {
		return fail("target below min_supported")
	}

	artifact, ok := manifest.Artifacts[target]
	if !ok {
		return fail("missing artifacts for target version")
	}
	binPath := artifact.LinuxAMD64.Binary
	shaPath := artifact.LinuxAMD64.SHA256
	if binPath == "" || shaPath == "" {
		return fail("artifact paths missing")
	}

	if err := ensureCurrentSymlinks(currentVersion); err != nil {
		return fail(fmt.Sprintf("prepare current symlink: %v", err))
	}
	releaseDir := filepath.Join(releasesDir, target)
	if err := os.MkdirAll(releaseDir, 0o755); err != nil {
		return fail(err.Error())
	}
	tmpPath := filepath.Join(releaseDir, "zlh-agent.new")
	finalPath := filepath.Join(releaseDir, "zlh-agent")
	// Never leave a partial or unverified download behind. After the rename
	// below succeeds this is a harmless ErrNotExist.
	defer func() { _ = os.Remove(tmpPath) }()

	// NOTE(review): the binary and its checksum come from the same plain-HTTP
	// host, so the SHA-256 check detects corruption, not tampering.
	if err := downloadFile(artifactBaseURL+binPath, tmpPath); err != nil {
		return fail(fmt.Sprintf("download binary: %v", err))
	}
	expected, err := downloadSHA256(artifactBaseURL + shaPath)
	if err != nil {
		return fail(fmt.Sprintf("download sha256: %v", err))
	}
	if err := verifySHA256(tmpPath, expected); err != nil {
		return fail(fmt.Sprintf("sha256 verify failed: %v", err))
	}
	if err := os.Chmod(tmpPath, 0o755); err != nil {
		return fail(fmt.Sprintf("chmod: %v", err))
	}
	if err := os.Rename(tmpPath, finalPath); err != nil {
		return fail(fmt.Sprintf("install: %v", err))
	}
	if err := updateSymlinks(target); err != nil {
		return fail(fmt.Sprintf("update symlinks: %v", err))
	}

	// Point binaryPath at current/zlh-agent by renaming a freshly created
	// symlink over it. rename(2) replaces the entry atomically, so binaryPath
	// is never missing, even if we crash part-way.
	tmpLink := binaryPath + ".tmp-link"
	if err := os.Remove(tmpLink); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fail(fmt.Sprintf("remove binary path: %v", err))
	}
	if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), tmpLink); err != nil {
		return fail(fmt.Sprintf("update current symlink: %v", err))
	}
	if err := os.Rename(tmpLink, binaryPath); err != nil {
		_ = os.Remove(tmpLink)
		return fail(fmt.Sprintf("update current symlink: %v", err))
	}

	result.Status = "updated"
	result.Error = ""
	writeStatus(result)
	// Restart asynchronously so the caller still receives the result; the
	// restart is expected to replace this process.
	go func() {
		if err := restartService(); err != nil {
			log.Printf("[update] restart failed: %v", err)
		}
	}()
	return result
}
// fetchManifest downloads manifest.json from artifactBaseURL (10 second
// timeout) and decodes it. A transport error, a non-200 response, or invalid
// JSON is returned as an error together with a zero Manifest.
func fetchManifest() (Manifest, error) {
	resp, err := (&http.Client{Timeout: 10 * time.Second}).Get(artifactBaseURL + "manifest.json")
	if err != nil {
		return Manifest{}, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return Manifest{}, fmt.Errorf("manifest status %d", resp.StatusCode)
	}

	var decoded Manifest
	if err = json.NewDecoder(resp.Body).Decode(&decoded); err != nil {
		return Manifest{}, err
	}
	return decoded, nil
}
// downloadFile fetches url (30 second overall timeout) and writes the
// response body to dest, creating or truncating it.
//
// A non-200 response is an error and leaves dest untouched. If the transfer,
// a write, or the final close fails, the partially written dest is removed,
// so a truncated file is never left looking like a complete download.
func downloadFile(url, dest string) error {
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("status %d", resp.StatusCode)
	}

	f, err := os.Create(dest)
	if err != nil {
		return err
	}
	if _, err := io.Copy(f, resp.Body); err != nil {
		_ = f.Close()
		_ = os.Remove(dest)
		return err
	}
	// Close can surface deferred write errors (e.g. ENOSPC), so check it.
	if err := f.Close(); err != nil {
		_ = os.Remove(dest)
		return err
	}
	return nil
}
// downloadSHA256 fetches a checksum file such as sha256sum output
// ("<hex>  <filename>") with a 10 second timeout and returns its first
// whitespace-separated field.
//
// That field must be a 64-character hex SHA-256 digest. An empty file, a
// malformed digest, or a non-200 response is reported as an error here,
// instead of surfacing later as an unexplained checksum mismatch.
func downloadSHA256(url string) (string, error) {
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("status %d", resp.StatusCode)
	}
	// Only the first field matters and a checksum file is tiny. Cap the read
	// so a misconfigured URL cannot pull an arbitrarily large body into memory.
	b, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
	if err != nil {
		return "", err
	}
	fields := strings.Fields(string(b))
	if len(fields) == 0 {
		return "", errors.New("empty sha256 file")
	}
	sum := fields[0]
	if raw, err := hex.DecodeString(sum); err != nil || len(raw) != sha256.Size {
		return "", fmt.Errorf("malformed sha256 %q", sum)
	}
	return sum, nil
}
// verifySHA256 hashes the file at path and compares the hex digest with
// expected, ignoring case and surrounding whitespace.
//
// A mismatch error carries both digests, so a failed update can be diagnosed
// from the persisted status alone. Errors opening or reading the file are
// returned unchanged.
func verifySHA256(path, expected string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return err
	}
	got := hex.EncodeToString(h.Sum(nil))
	want := strings.TrimSpace(expected)
	if !strings.EqualFold(got, want) {
		return fmt.Errorf("checksum mismatch: got %s, want %s", got, want)
	}
	return nil
}
// restartService runs "systemctl restart <unit>" and waits for systemctl to
// exit, returning its error if any. The unit comes from ZLH_AGENT_UNIT and
// falls back to defaultUnit when that variable is unset or empty.
func restartService() error {
	unit := defaultUnit
	if fromEnv := os.Getenv("ZLH_AGENT_UNIT"); fromEnv != "" {
		unit = fromEnv
	}
	return exec.Command("systemctl", "restart", unit).Run()
}
// normalizeVersion trims surrounding whitespace and then at most one leading
// lowercase "v", so " v1.2.3 " and "1.2.3" normalize to the same string.
func normalizeVersion(v string) string {
	trimmed := strings.TrimSpace(v)
	if strings.HasPrefix(trimmed, "v") {
		return trimmed[1:]
	}
	return trimmed
}
// compareVersions compares dotted version strings numerically. It returns -1,
// 0 or 1 as a is lower than, equal to, or higher than b.
//
// Missing components count as 0, so "1.2" == "1.2.0". Every component is
// compared, so "1.2.3.1" > "1.2.3". Each component is read with parseInt,
// which keeps only its leading digits: a suffix such as "-rc1" is ignored
// ("1.2.3-rc1" == "1.2.3") and a non-numeric component counts as 0.
func compareVersions(a, b string) int {
	as := strings.Split(a, ".")
	bs := strings.Split(b, ".")
	n := len(as)
	if len(bs) > n {
		n = len(bs)
	}
	// component returns the i-th numeric component, or 0 past the end.
	component := func(parts []string, i int) int {
		if i < len(parts) {
			return parseInt(parts[i])
		}
		return 0
	}
	for i := 0; i < n; i++ {
		ai, bi := component(as, i), component(bs, i)
		switch {
		case ai < bi:
			return -1
		case ai > bi:
			return 1
		}
	}
	return 0
}
// parseInt returns the value of the leading run of ASCII decimal digits in s.
// It stops at the first non-digit and returns 0 when s does not start with a
// digit, so "3-rc1" yields 3 and "rc1" yields 0. There is no overflow check.
func parseInt(s string) int {
	value := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c < '0' || c > '9' {
			return value
		}
		value = value*10 + int(c-'0')
	}
	return value
}
// writeStatus persists res as indented JSON at statusFile, creating stateDir
// if needed.
//
// The JSON is written to a uniquely named temporary file in stateDir and
// renamed over statusFile. ReadStatus therefore never sees a half-written
// document, even when two checks finish at the same time. Persisting status
// is best effort: failures are logged, not returned.
func writeStatus(res Result) {
	logErr := func(err error) {
		log.Printf("[update] write status: %v", err)
	}
	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		logErr(err)
		return
	}
	b, err := json.MarshalIndent(res, "", " ")
	if err != nil {
		logErr(err)
		return
	}

	f, err := os.CreateTemp(stateDir, ".update-*.json.tmp")
	if err != nil {
		logErr(err)
		return
	}
	tmp := f.Name()
	_, err = f.Write(b)
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		// CreateTemp creates files as 0600; keep the status file at 0644.
		err = os.Chmod(tmp, 0o644)
	}
	if err == nil {
		err = os.Rename(tmp, statusFile)
	}
	if err != nil {
		_ = os.Remove(tmp)
		logErr(err)
	}
}
// ReadStatus returns the last Result persisted at statusFile. A missing or
// unreadable file yields the zero Result. Decoding is best effort: invalid
// JSON is ignored and whatever json.Unmarshal filled in is returned.
func ReadStatus() Result {
	var res Result
	raw, err := os.ReadFile(statusFile)
	if err == nil {
		_ = json.Unmarshal(raw, &res) // a corrupt status file is not fatal
	}
	return res
}
// acquireLock takes the update lock by creating path with O_EXCL, creating
// its parent directory first.
//
// It returns "update already in progress" only when the lock file already
// exists. Any other failure (e.g. permissions) is returned as a wrapped error
// rather than being misreported as contention.
//
// A lock file last modified more than 15 minutes ago is treated as left
// behind by an update that never released it (for example after a crash),
// and is taken over. A single crash therefore cannot block updates forever.
func acquireLock(path string) error {
	const staleAfter = 15 * time.Minute

	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return fmt.Errorf("create lock dir: %w", err)
	}
	create := func() error {
		f, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o644)
		if err != nil {
			return err
		}
		// The file's existence is the lock. Reporting a close error on this
		// empty file would leave the lock held with no caller to release it.
		_ = f.Close()
		return nil
	}

	err := create()
	if err == nil {
		return nil
	}
	if !errors.Is(err, os.ErrExist) {
		return fmt.Errorf("create lock file: %w", err)
	}
	// NOTE(review): two processes that observe the same stale lock at once
	// can both take it over. That is acceptable for crash recovery, but this
	// is not a substitute for flock.
	if info, statErr := os.Stat(path); statErr == nil && time.Since(info.ModTime()) > staleAfter {
		if rmErr := os.Remove(path); rmErr == nil || errors.Is(rmErr, os.ErrNotExist) {
			if create() == nil {
				return nil
			}
		}
	}
	return fmt.Errorf("update already in progress")
}
// releaseLock removes the lock file created by acquireLock. A lock that is
// already gone is not an error. Any other removal failure is logged, because
// a lock file left in place can make later updates report "update already in
// progress".
func releaseLock(path string) {
	if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
		log.Printf("[update] release lock %s: %v", path, err)
	}
}
// ensureCurrentSymlinks migrates a flat install (a regular binary at
// binaryPath) to the releases layout the updater maintains:
//
//	releases/<currentVersion>/zlh-agent   copy of the running binary
//	current -> releases/<currentVersion>
//	zlh-agent -> current/zlh-agent
//
// It does nothing when currentLink already exists, even as a dangling link.
// The copy is written under a temporary name and renamed into place, so an
// interrupted copy is never mistaken for a complete one. binaryPath is
// swapped for the symlink with a single rename, so it is never missing.
func ensureCurrentSymlinks(currentVersion string) error {
	if _, err := os.Lstat(currentLink); err == nil {
		return nil
	}
	// currentVersion becomes a directory under releasesDir. An empty or
	// path-like value would point current at releasesDir itself or outside it.
	if currentVersion == "" || currentVersion != filepath.Base(currentVersion) ||
		currentVersion == "." || currentVersion == ".." {
		return fmt.Errorf("invalid current version %q", currentVersion)
	}

	currentRelease := filepath.Join(releasesDir, currentVersion)
	if err := os.MkdirAll(currentRelease, 0o755); err != nil {
		return err
	}
	currentBinary := filepath.Join(currentRelease, "zlh-agent")
	if _, err := os.Stat(currentBinary); errors.Is(err, os.ErrNotExist) {
		tmpBinary := currentBinary + ".new"
		if err := copyFile(binaryPath, tmpBinary); err != nil {
			_ = os.Remove(tmpBinary)
			return err
		}
		if err := os.Chmod(tmpBinary, 0o755); err != nil {
			_ = os.Remove(tmpBinary)
			return err
		}
		if err := os.Rename(tmpBinary, currentBinary); err != nil {
			_ = os.Remove(tmpBinary)
			return err
		}
	}

	// Relative target: resolved against the directory holding currentLink.
	if err := os.Symlink(filepath.Join("releases", currentVersion), currentLink); err != nil {
		return err
	}

	// Replace binaryPath by renaming a freshly created symlink over it.
	tmpLink := binaryPath + ".tmp-link"
	if err := os.Remove(tmpLink); err != nil && !errors.Is(err, os.ErrNotExist) {
		return err
	}
	if err := os.Symlink(filepath.Join(currentLink, "zlh-agent"), tmpLink); err != nil {
		return err
	}
	if err := os.Rename(tmpLink, binaryPath); err != nil {
		_ = os.Remove(tmpLink)
		return err
	}
	return nil
}
func updateSymlinks(target string) error {
if err := os.RemoveAll(previousLink); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
if _, err := os.Lstat(currentLink); err == nil {
if err := os.Symlink("current", previousLink); err != nil {
return err
}
}
if err := os.RemoveAll(currentLink); err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
return os.Symlink(filepath.Join("releases", target), currentLink)
}
// copyFile copies the contents of src to dst, creating or truncating dst
// (created with os.Create's 0666-before-umask mode; callers chmod as needed).
// The data is fsynced before returning, so a caller that renames dst into
// place does not publish an empty file after a crash. On any failure after
// dst was created, dst is removed so a truncated copy is never left behind.
func copyFile(src, dst string) error {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	out, err := os.Create(dst)
	if err != nil {
		return err
	}
	_, err = io.Copy(out, in)
	if err == nil {
		err = out.Sync()
	}
	if closeErr := out.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		_ = os.Remove(dst)
		return err
	}
	return nil
}