// Source: zlh-agent/internal/auth/auth.go (91 lines, 2.0 KiB, Go)

package auth
import (
	"crypto/sha256"
	"crypto/subtle"
	"log"
	"net/http"
	"os"
	"strings"
)
const (
	// HeaderToken is the custom request header from which Authorized reads a
	// candidate agent token.
	HeaderToken = "X-ZLH-Agent-Token"

	// QueryTokenParam is the URL query parameter from which Authorized reads a
	// candidate agent token.
	QueryTokenParam = "agent_token"

	// envToken names the environment variable holding the expected agent
	// token. An empty, whitespace-only, or unset value disables enforcement.
	envToken = "ZLH_AGENT_TOKEN"
)
// Policy lists the requests that may bypass token authentication.
//
// The zero value (nil Public map) marks nothing as public, so every request
// must be authorized.
type Policy struct {
	// Public maps an exact request path to the set of HTTP method names that
	// are public on that path. Build each method set with the Public function,
	// which stores upper-cased, whitespace-trimmed names.
	Public map[string]map[string]struct{}
}
// Public builds a method set for use as a value in Policy.Public.
//
// Each method name has surrounding whitespace removed and is upper-cased, so
// Public(" get ", "post") yields {"GET", "POST"}. Duplicates collapse into a
// single entry, and calling Public with no arguments yields an empty,
// non-nil set.
func Public(methods ...string) map[string]struct{} {
	set := make(map[string]struct{}, len(methods))
	for i := range methods {
		key := strings.TrimSpace(methods[i])
		key = strings.ToUpper(key)
		set[key] = struct{}{}
	}
	return set
}
// Wrap returns a handler that enforces agent-token authentication in front of
// next.
//
// If the ZLH_AGENT_TOKEN environment variable is empty, whitespace-only, or
// unset at the time Wrap is called, a warning is logged and next is returned
// unchanged, so no request is checked. Otherwise, for each request:
//   - if policy marks the request's method and path as public, next is called;
//   - else if Authorized accepts the request, next is called;
//   - else a 401 response with body "unauthorized" is written and next is not
//     called.
func Wrap(next http.Handler, policy Policy) http.Handler {
	configured := strings.TrimSpace(os.Getenv(envToken)) != ""
	if !configured {
		log.Printf("[auth] warning: %s not set; agent auth enforcement disabled", envToken)
		return next
	}

	guard := func(w http.ResponseWriter, r *http.Request) {
		// Public routes are checked first so they never pay for token parsing.
		allowed := policy.IsPublic(r.Method, r.URL.Path) || Authorized(r)
		if !allowed {
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(guard)
}
// IsPublic reports whether a request with the given method and path may
// bypass token authentication under p.
//
// The path must equal a key of p.Public exactly. There is no prefix matching
// and no trailing-slash normalization. A nil p.Public (the zero Policy) makes
// every path non-public.
//
// The method is compared case-sensitively against the set's keys. HTTP
// methods are case-sensitive (RFC 9110 §9.1), and net/http passes r.Method
// through unnormalized. Upper-casing the request method here would let a
// request spelled "get" skip authentication as the public "GET" route, while
// the downstream handler still sees the raw method "get". That handler might
// then treat the request as something other than a safe GET. Method sets built
// by Public hold upper-cased names, so standard "GET"/"POST"/... requests
// match as before.
func (p Policy) IsPublic(method, path string) bool {
	methods, ok := p.Public[path]
	if !ok {
		return false
	}
	_, ok = methods[method]
	return ok
}
// Authorized reports whether r carries the agent token configured in the
// ZLH_AGENT_TOKEN environment variable.
//
// The variable is re-read on every call. When it is empty, whitespace-only,
// or unset, every request is authorized. Otherwise the token is accepted from
// any one of these locations, tried in order until one matches:
//   - an "Authorization: Bearer <token>" header,
//   - the HeaderToken request header,
//   - the QueryTokenParam URL query parameter.
//
// NOTE(review): a token sent in the query string becomes part of the URL and
// may be recorded by access or proxy logs. Presumably this exists for clients
// that cannot set request headers; confirm it is still required.
func Authorized(r *http.Request) bool {
	want := strings.TrimSpace(os.Getenv(envToken))
	if want == "" {
		return true
	}

	switch {
	case constantTimeEqual(bearerToken(r.Header.Get("Authorization")), want):
		return true
	case constantTimeEqual(r.Header.Get(HeaderToken), want):
		return true
	}
	return constantTimeEqual(r.URL.Query().Get(QueryTokenParam), want)
}
// bearerToken extracts the credential from an Authorization header value of
// the form "Bearer <token>".
//
// Surrounding whitespace is ignored and the scheme name is matched
// case-insensitively. Everything after the first space, trimmed, is returned
// as the token. An empty header, a header with no space, or any scheme other
// than Bearer yields "".
func bearerToken(header string) string {
	trimmed := strings.TrimSpace(header)
	sep := strings.IndexByte(trimmed, ' ')
	if sep < 0 {
		return ""
	}
	if scheme := trimmed[:sep]; !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	return strings.TrimSpace(trimmed[sep+1:])
}
// constantTimeEqual reports whether got and expected are equal once
// surrounding whitespace is trimmed from both.
//
// An empty (or whitespace-only) value on either side never matches, so a
// missing credential cannot satisfy an unset expected token.
//
// Both values are reduced to fixed-size SHA-256 digests before the
// constant-time comparison. subtle.ConstantTimeCompare returns immediately
// when its inputs differ in length, and an explicit length check does the
// same. Either would let response timing reveal the length of the secret
// token. Comparing 32-byte digests keeps the comparison time independent of
// how got relates to expected.
func constantTimeEqual(got, expected string) bool {
	got = strings.TrimSpace(got)
	expected = strings.TrimSpace(expected)
	if got == "" || expected == "" {
		return false
	}
	gotSum := sha256.Sum256([]byte(got))
	expectedSum := sha256.Sum256([]byte(expected))
	return subtle.ConstantTimeCompare(gotSum[:], expectedSum[:]) == 1
}