zlh-agent/internal/mods/installer.go
2026-03-07 20:59:27 +00:00

436 lines
12 KiB
Go

package mods
import (
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
const (
	// maxDownloadSize caps the size of a single downloaded artifact (200 MiB).
	maxDownloadSize = 200 << 20
	// defaultTimeout bounds a whole artifact download; downloadArtifact lets
	// ZLH_MOD_DOWNLOAD_TIMEOUT override it.
	defaultTimeout = 2 * time.Minute
	// maxRedirects is the redirect limit used by downloadArtifact's CheckRedirect.
	maxRedirects = 3
	// tempModsDir is where downloads are staged until their hash is verified.
	tempModsDir = "/tmp/zlh-agent/mods"
)
// allowedHosts lists the only hosts artifacts may be fetched from. isAllowedHost
// matches against it case-insensitively; it is consulted for the initial
// download URL, for every redirect target, and for the final response URL.
var allowedHosts = []string{"artifacts.zerolaghub.com", "cdn.modrinth.com"}
// InstallCurated downloads a mod jar, verifies its hash and installs it into
// <serverRoot>/mods.
//
// req.Source selects the artifact origin (case-insensitive; empty means
// "curated"):
//   - "curated": req.ModID must be valid, req.ArtifactURL must pass
//     validateArtifactURL, req.ArtifactHash must be "sha256:<64 hex>", and the
//     file name is derived by safeInstallFilename.
//   - "modrinth": req.DownloadURL (validated like ArtifactURL), a safe ".jar"
//     req.Filename, and req.SHA512 or req.SHA1 are required. SHA-512 is used
//     when both are present.
//
// The artifact is staged in tempModsDir, verified, and only then moved into
// the mods directory, so an unverified download never appears there.
// Installing fails if the jar, or a disabled copy of it ("<name>.disabled"),
// already exists. On success InvalidateCache is called and the response
// reports that a restart is required.
func InstallCurated(serverRoot string, req InstallRequest) (ActionResponse, error) {
	source := strings.TrimSpace(strings.ToLower(req.Source))
	if source == "" {
		source = "curated"
	}
	if source != "curated" && source != "modrinth" {
		return ActionResponse{}, errors.New("unsupported source")
	}

	var downloadURL, filename string
	var verify func(path string) error
	if source == "modrinth" {
		downloadURL = strings.TrimSpace(req.DownloadURL)
		filename = strings.TrimSpace(req.Filename)
		if downloadURL == "" {
			return ActionResponse{}, errors.New("download_url required for modrinth source")
		}
		if filename == "" || !isSafeFilename(filename) || !strings.HasSuffix(strings.ToLower(filename), ".jar") {
			return ActionResponse{}, errors.New("invalid filename")
		}
		if err := validateArtifactURL(downloadURL); err != nil {
			return ActionResponse{}, err
		}
		switch {
		case strings.TrimSpace(req.SHA512) != "":
			want, err := normalizeExpectedHash(req.SHA512, 128, "sha512")
			if err != nil {
				return ActionResponse{}, err
			}
			verify = func(path string) error { return VerifyHashWithAlgorithm(path, want, "sha512") }
		case strings.TrimSpace(req.SHA1) != "":
			want, err := normalizeExpectedHash(req.SHA1, 40, "sha1")
			if err != nil {
				return ActionResponse{}, err
			}
			verify = func(path string) error { return VerifyHashWithAlgorithm(path, want, "sha1") }
		default:
			return ActionResponse{}, errors.New("sha512 or sha1 required for modrinth source")
		}
	} else {
		if !isValidModID(req.ModID) {
			return ActionResponse{}, errors.New("invalid mod_id")
		}
		if err := validateArtifactURL(req.ArtifactURL); err != nil {
			return ActionResponse{}, err
		}
		if err := validateExpectedHash(req.ArtifactHash); err != nil {
			return ActionResponse{}, err
		}
		downloadURL = req.ArtifactURL
		filename = safeInstallFilename(req)
		if !isSafeFilename(filename) || !strings.HasSuffix(strings.ToLower(filename), ".jar") {
			return ActionResponse{}, errors.New("invalid artifact filename")
		}
		expectedHash := req.ArtifactHash
		verify = func(path string) error { return VerifyHash(path, expectedHash) }
	}

	modsDir := filepath.Join(serverRoot, "mods")
	if err := os.MkdirAll(modsDir, 0o755); err != nil {
		return ActionResponse{}, fmt.Errorf("create mods dir: %w", err)
	}
	finalPath := filepath.Join(modsDir, filename)
	// Mods are disabled by appending ".disabled" to the file name, so a
	// disabled copy is the same mod. Installing next to it would leave two
	// copies that can no longer be toggled.
	for _, existing := range []string{finalPath, finalPath + ".disabled"} {
		if _, err := os.Stat(existing); err == nil {
			return ActionResponse{}, fmt.Errorf("mod already exists: %s", filename)
		}
	}

	if err := os.MkdirAll(tempModsDir, 0o755); err != nil {
		return ActionResponse{}, fmt.Errorf("create temp dir: %w", err)
	}
	tmpFile, err := os.CreateTemp(tempModsDir, "zlh-mod-*.jar")
	if err != nil {
		return ActionResponse{}, fmt.Errorf("create temp file: %w", err)
	}
	tmpPath := tmpFile.Name()
	// Best-effort cleanup; after a successful move the file is already gone.
	defer func() {
		_ = os.Remove(tmpPath)
	}()
	if err := tmpFile.Close(); err != nil {
		return ActionResponse{}, fmt.Errorf("close temp file: %w", err)
	}

	if err := downloadArtifact(downloadURL, tmpPath); err != nil {
		return ActionResponse{}, err
	}
	if err := verify(tmpPath); err != nil {
		return ActionResponse{}, err
	}
	// CreateTemp creates files with mode 0600; widen to 0644 before the file
	// becomes visible in the mods directory (rename preserves the mode).
	if err := os.Chmod(tmpPath, 0o644); err != nil {
		return ActionResponse{}, fmt.Errorf("chmod temp file: %w", err)
	}
	if err := moveVerifiedMod(tmpPath, finalPath); err != nil {
		return ActionResponse{}, fmt.Errorf("install mod: %w", err)
	}

	InvalidateCache(serverRoot)
	return ActionResponse{Success: true, Action: "installed", RestartRequired: true}, nil
}

// moveVerifiedMod moves the verified artifact at tmpPath to finalPath.
//
// os.Rename cannot move a file across filesystems, and tempModsDir (under
// /tmp) may be a different mount (for example a tmpfs) from the server root.
// If the direct rename fails, the file is copied into a hidden staging file in
// finalPath's directory and renamed from there. That final rename stays within
// one directory, so finalPath only ever appears fully written. If the fallback
// cannot even start, the original rename error is returned.
func moveVerifiedMod(tmpPath, finalPath string) error {
	renameErr := os.Rename(tmpPath, finalPath)
	if renameErr == nil {
		return nil
	}

	src, err := os.Open(tmpPath)
	if err != nil {
		return renameErr
	}
	defer src.Close()

	staged, err := os.CreateTemp(filepath.Dir(finalPath), ".zlh-install-*.tmp")
	if err != nil {
		return renameErr
	}
	stagedPath := staged.Name()
	// Best-effort cleanup; a no-op once the staged file has been renamed.
	defer func() {
		_ = os.Remove(stagedPath)
	}()

	if _, err := io.Copy(staged, src); err != nil {
		_ = staged.Close()
		return fmt.Errorf("copy to mods dir: %w", err)
	}
	if err := staged.Close(); err != nil {
		return fmt.Errorf("copy to mods dir: %w", err)
	}
	if err := os.Chmod(stagedPath, 0o644); err != nil {
		return fmt.Errorf("chmod staged file: %w", err)
	}
	return os.Rename(stagedPath, finalPath)
}
// SetEnabled enables or disables the mod identified by modID by renaming its
// file inside <serverRoot>/mods. Disabling appends ".disabled" to the file
// name; enabling strips that suffix.
//
// If the mod is already in the requested state, the call succeeds without
// touching the filesystem. It refuses to rename onto an existing file, since
// os.Rename would replace it. When ResolveByModID reports a not-found error,
// the bare os.ErrNotExist is returned. InvalidateCache runs as soon as the
// rename succeeds, so a later permission failure cannot leave a stale cache.
func SetEnabled(serverRoot, modID string, enabled bool) (ActionResponse, error) {
	enabledName, disabledName, err := ResolveByModID(serverRoot, modID)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return ActionResponse{}, os.ErrNotExist
		}
		return ActionResponse{}, err
	}

	action := "disabled"
	if enabled {
		action = "enabled"
	}
	result := ActionResponse{Success: true, Action: action, RestartRequired: true}

	modsDir := filepath.Join(serverRoot, "mods")
	var src, dst string
	if enabled {
		if enabledName != "" {
			return result, nil // already enabled
		}
		src = filepath.Join(modsDir, disabledName)
		dst = filepath.Join(modsDir, strings.TrimSuffix(disabledName, ".disabled"))
	} else {
		if disabledName != "" {
			return result, nil // already disabled
		}
		src = filepath.Join(modsDir, enabledName)
		dst = filepath.Join(modsDir, enabledName+".disabled")
	}

	if _, err := os.Stat(dst); err == nil {
		if enabled {
			return ActionResponse{}, errors.New("cannot enable: target file already exists")
		}
		return ActionResponse{}, errors.New("cannot disable: target file already exists")
	}
	if err := os.Rename(src, dst); err != nil {
		if enabled {
			return ActionResponse{}, fmt.Errorf("enable mod: %w", err)
		}
		return ActionResponse{}, fmt.Errorf("disable mod: %w", err)
	}
	// The directory contents have changed; invalidate before anything else
	// can fail.
	InvalidateCache(serverRoot)

	if err := os.Chmod(dst, 0o644); err != nil {
		return ActionResponse{}, fmt.Errorf("set permissions: %w", err)
	}
	return result, nil
}
// DeleteMod soft-deletes the mod identified by modID. It moves the mod's file
// out of <serverRoot>/mods into <serverRoot>/mods-removed (created if needed),
// under a collision-avoiding name chosen by uniqueRemovedPath.
//
// If ResolveByModID reports both an enabled and a disabled copy, only the
// enabled one is moved. When ResolveByModID reports a not-found error, the
// bare os.ErrNotExist is returned. InvalidateCache runs as soon as the move
// succeeds, so a later permission failure cannot leave a stale cache.
func DeleteMod(serverRoot, modID string) (ActionResponse, error) {
	enabledName, disabledName, err := ResolveByModID(serverRoot, modID)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return ActionResponse{}, os.ErrNotExist
		}
		return ActionResponse{}, err
	}

	sourceName := enabledName
	if sourceName == "" {
		sourceName = disabledName
	}
	src := filepath.Join(serverRoot, "mods", sourceName)

	removedDir := filepath.Join(serverRoot, "mods-removed")
	if err := os.MkdirAll(removedDir, 0o755); err != nil {
		return ActionResponse{}, fmt.Errorf("create removed dir: %w", err)
	}
	targetPath := uniqueRemovedPath(removedDir, sourceName)
	if err := os.Rename(src, targetPath); err != nil {
		return ActionResponse{}, fmt.Errorf("remove mod: %w", err)
	}
	// The mods directory has changed; invalidate before anything else can fail.
	InvalidateCache(serverRoot)

	if err := os.Chmod(targetPath, 0o644); err != nil {
		return ActionResponse{}, fmt.Errorf("set permissions: %w", err)
	}
	return ActionResponse{Success: true, Action: "deleted", RestartRequired: true}, nil
}
// VerifyHash checks that the SHA-256 digest of the file at path matches
// expected, which must have the form "sha256:<64 hex chars>". Hex comparison
// is case-insensitive. It returns the validation error for a malformed
// expected value, any open/read error, or "sha256 mismatch".
func VerifyHash(path string, expected string) error {
	if err := validateExpectedHash(expected); err != nil {
		return err
	}

	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()

	digest := sha256.New()
	if _, err = io.Copy(digest, file); err != nil {
		return err
	}

	// validateExpectedHash guarantees the "sha256:" prefix is present.
	got := hex.EncodeToString(digest.Sum(nil))
	if strings.EqualFold(got, expected[len("sha256:"):]) {
		return nil
	}
	return errors.New("sha256 mismatch")
}
// VerifyHashWithAlgorithm checks the file at path against expectedHex using
// algorithm ("sha512" or "sha1"). expectedHex is trimmed and lower-cased
// before comparison. It returns "missing expected hash" for a blank
// expectation, any open/read error, "unsupported hash algorithm", or
// "<algorithm> mismatch".
func VerifyHashWithAlgorithm(path, expectedHex, algorithm string) error {
	want := strings.ToLower(strings.TrimSpace(expectedHex))
	if want == "" {
		return errors.New("missing expected hash")
	}

	file, err := os.Open(path)
	if err != nil {
		return err
	}
	defer file.Close()

	var (
		digest interface {
			io.Writer
			Sum(b []byte) []byte
		}
		mismatch string
	)
	switch algorithm {
	case "sha512":
		digest, mismatch = sha512.New(), "sha512 mismatch"
	case "sha1":
		digest, mismatch = sha1.New(), "sha1 mismatch"
	default:
		return errors.New("unsupported hash algorithm")
	}

	if _, err := io.Copy(digest, file); err != nil {
		return err
	}
	if got := hex.EncodeToString(digest.Sum(nil)); got != want {
		return errors.New(mismatch)
	}
	return nil
}
// validateArtifactURL accepts raw only if it parses as a URL, uses the https
// scheme, and names a host in the allowlist checked by isAllowedHost.
func validateArtifactURL(raw string) error {
	parsed, parseErr := url.Parse(raw)
	switch {
	case parseErr != nil:
		return errors.New("invalid artifact_url")
	case parsed.Scheme != "https":
		return errors.New("artifact_url must use https")
	case !isAllowedHost(parsed.Hostname()):
		return errors.New("artifact_url host not allowed")
	default:
		return nil
	}
}
// validateExpectedHash checks that v has the form "sha256:" followed by
// exactly 64 hex digits (either case).
func validateExpectedHash(v string) error {
	const prefix = "sha256:"
	if !strings.HasPrefix(v, prefix) {
		return errors.New("artifact_hash must start with sha256:")
	}

	digest := v[len(prefix):]
	if len(digest) != 64 {
		return errors.New("artifact_hash must include 64 hex characters")
	}

	notHex := func(r rune) bool {
		switch {
		case '0' <= r && r <= '9', 'a' <= r && r <= 'f', 'A' <= r && r <= 'F':
			return false
		}
		return true
	}
	if strings.IndexFunc(digest, notHex) >= 0 {
		return errors.New("artifact_hash contains non-hex characters")
	}
	return nil
}
// normalizeExpectedHash lower-cases and trims raw, strips an optional
// "<prefix>:" label, and checks that what remains is exactly expectedLen
// lowercase hex characters. It returns the normalized digest.
func normalizeExpectedHash(raw string, expectedLen int, prefix string) (string, error) {
	digest := strings.TrimSpace(strings.ToLower(raw))
	digest = strings.TrimPrefix(digest, prefix+":")

	switch {
	case len(digest) != expectedLen:
		return "", fmt.Errorf("%s must include %d hex characters", prefix, expectedLen)
	case strings.TrimLeft(digest, "0123456789abcdef") != "":
		// Anything left after stripping hex digits is a non-hex character.
		return "", fmt.Errorf("%s contains non-hex characters", prefix)
	}
	return digest, nil
}
// safeInstallFilename picks the on-disk file name for a curated install.
//
// The last path segment of req.ArtifactURL is used when it passes
// isSafeFilename and ends in ".jar" (case-insensitive). Otherwise the name is
// sanitizeID(req.ModID), or sanitizeID(req.ModID+"-"+req.Version) when a
// version is set, plus ".jar". Falling back for non-jar URL segments (for
// example ".../download" or ".../1.2.3") yields a name the caller's ".jar"
// check accepts, instead of one it would reject.
func safeInstallFilename(req InstallRequest) string {
	if u, err := url.Parse(req.ArtifactURL); err == nil {
		base := filepath.Base(u.Path)
		if base != "." && base != "/" && isSafeFilename(base) && strings.HasSuffix(strings.ToLower(base), ".jar") {
			return base
		}
	}

	id := req.ModID
	if strings.TrimSpace(req.Version) != "" {
		id = req.ModID + "-" + req.Version
	}
	return sanitizeID(id) + ".jar"
}
// downloadArtifact fetches rawURL and writes the response body to dest,
// creating or truncating it.
//
// The overall request timeout is defaultTimeout unless the environment
// variable ZLH_MOD_DOWNLOAD_TIMEOUT holds a positive time.ParseDuration value.
// Redirects are followed only to https URLs on allowlisted hosts and are
// limited by maxRedirects. The URL that produced the final response is
// re-checked as well. A non-200 status is an error. Bodies larger than
// maxDownloadSize are rejected, either up front from Content-Length or while
// streaming. dest is not removed on failure; hash verification is left to the
// caller.
func downloadArtifact(rawURL, dest string) error {
	timeout := defaultTimeout
	if v := strings.TrimSpace(os.Getenv("ZLH_MOD_DOWNLOAD_TIMEOUT")); v != "" {
		if d, err := time.ParseDuration(v); err == nil && d > 0 {
			timeout = d
		}
	}

	client := &http.Client{
		Timeout: timeout,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			// Keep downloads pinned to allowlisted HTTPS hosts, even across redirects.
			if len(via) >= maxRedirects {
				return errors.New("too many redirects")
			}
			if req.URL.Scheme != "https" {
				return errors.New("redirected to non-https url")
			}
			if !isAllowedHost(req.URL.Hostname()) {
				return errors.New("redirected to disallowed host")
			}
			return nil
		},
	}

	resp, err := client.Get(rawURL)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Defence in depth: re-check the URL that actually produced this response.
	if resp.Request == nil || resp.Request.URL == nil {
		return errors.New("invalid download response")
	}
	if resp.Request.URL.Scheme != "https" || !isAllowedHost(resp.Request.URL.Hostname()) {
		return errors.New("final download url not allowed")
	}
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download failed: status %d", resp.StatusCode)
	}
	if cl := resp.Header.Get("Content-Length"); cl != "" {
		n, err := strconv.ParseInt(cl, 10, 64)
		if err == nil && n > maxDownloadSize {
			return errors.New("artifact too large")
		}
	}

	out, err := os.Create(dest)
	if err != nil {
		return err
	}
	// Read one byte past the limit so an oversized body is detectable.
	written, copyErr := io.Copy(out, io.LimitReader(resp.Body, maxDownloadSize+1))
	// Close explicitly: write errors on a file can surface only at Close,
	// and a deferred Close would silently drop them.
	closeErr := out.Close()
	if copyErr != nil {
		return copyErr
	}
	if written > maxDownloadSize {
		return errors.New("artifact exceeds 200MB limit")
	}
	if closeErr != nil {
		return fmt.Errorf("write artifact: %w", closeErr)
	}
	return nil
}
// uniqueRemovedPath returns a path in dir for archiving filename without
// clobbering an existing entry.
//
// dir/filename is used if it is free. Otherwise a UTC timestamp and a counter
// are inserted before the ".jar" / ".jar.disabled" extension, for example
// "foo-20260307T205927-1.jar.disabled". A candidate counts as free when Lstat
// fails for any reason, not only ErrNotExist. For other errors (e.g.
// ENAMETOOLONG once the suffix is added to a long name, or EACCES), the
// caller's rename then reports the real failure. Previously such errors made
// this function loop forever.
func uniqueRemovedPath(dir, filename string) string {
	isFree := func(p string) bool {
		// Lstat so that a dangling symlink counts as occupied.
		_, err := os.Lstat(p)
		return err != nil
	}

	candidate := filepath.Join(dir, filename)
	if isFree(candidate) {
		return candidate
	}

	base, ext := filename, ""
	if strings.HasSuffix(base, ".disabled") {
		base = strings.TrimSuffix(base, ".disabled")
		ext = ".disabled"
	}
	if strings.HasSuffix(base, ".jar") {
		base = strings.TrimSuffix(base, ".jar")
		ext = ".jar" + ext
	}

	ts := time.Now().UTC().Format("20060102T150405")
	// Terminates: only finitely many entries exist, so some candidate is free.
	for i := 1; ; i++ {
		candidate = filepath.Join(dir, fmt.Sprintf("%s-%s-%d%s", base, ts, i, ext))
		if isFree(candidate) {
			return candidate
		}
	}
}
// isAllowedHost reports whether host matches an entry in allowedHosts,
// ignoring ASCII/Unicode case.
func isAllowedHost(host string) bool {
	for i := range allowedHosts {
		if strings.EqualFold(allowedHosts[i], host) {
			return true
		}
	}
	return false
}