91 lines
2.1 KiB
Go
91 lines
2.1 KiB
Go
package auth
|
|
|
|
import (
|
|
"crypto/subtle"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
)
|
|
|
|
const (
	// envToken is the name of the environment variable that holds the shared
	// agent token. It is read on each check rather than cached at startup.
	envToken = "ZLH_AGENT_TOKEN"

	// HeaderToken is the request header that may carry the raw agent token.
	HeaderToken = "X-ZLH-Agent-Token"

	// QueryTokenParam is the URL query parameter that may carry the agent token.
	QueryTokenParam = "agent_token"
)
|
|
|
|
// Policy describes which routes may be served without an agent token.
type Policy struct {
	// Public maps an exact request path to the set of HTTP methods allowed on
	// that path without authentication. Method keys are looked up upper-cased,
	// so build the inner sets with Public to get matching normalization.
	Public map[string]map[string]struct{}
}
|
|
|
|
// Public builds a method set suitable as a value in Policy.Public.
//
// Each method name is trimmed of surrounding whitespace and upper-cased, so
// Public(" get", "Post") yields {"GET", "POST"}. Duplicates collapse into a
// single entry. Calling Public with no arguments returns an empty, non-nil set.
func Public(methods ...string) map[string]struct{} {
	set := make(map[string]struct{}, len(methods))
	for i := range methods {
		normalized := strings.ToUpper(strings.TrimSpace(methods[i]))
		set[normalized] = struct{}{}
	}
	return set
}
|
|
|
|
func Wrap(next http.Handler, policy Policy) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if policy.IsPublic(r.Method, r.URL.Path) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
if strings.TrimSpace(os.Getenv(envToken)) == "" {
|
|
log.Printf("[auth] warning: %s not set; rejecting protected request path=%s", envToken, r.URL.Path)
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
if !Authorized(r) {
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (p Policy) IsPublic(method, path string) bool {
|
|
methods, ok := p.Public[path]
|
|
if !ok {
|
|
return false
|
|
}
|
|
_, ok = methods[strings.ToUpper(method)]
|
|
return ok
|
|
}
|
|
|
|
func Authorized(r *http.Request) bool {
|
|
expected := strings.TrimSpace(os.Getenv(envToken))
|
|
if expected == "" {
|
|
return true
|
|
}
|
|
|
|
return constantTimeEqual(bearerToken(r.Header.Get("Authorization")), expected) ||
|
|
constantTimeEqual(r.Header.Get(HeaderToken), expected) ||
|
|
constantTimeEqual(r.URL.Query().Get(QueryTokenParam), expected)
|
|
}
|
|
|
|
// bearerToken extracts the credentials from an Authorization header value of
// the form "Bearer <token>".
//
// The whole header is trimmed first. It is then split at its first space
// character: the part before must equal "Bearer" case-insensitively, and the
// part after is returned with surrounding whitespace removed. Any other shape,
// including an empty header or a tab separator, yields "".
func bearerToken(header string) string {
	scheme, credentials, found := strings.Cut(strings.TrimSpace(header), " ")
	if found && strings.EqualFold(scheme, "Bearer") {
		return strings.TrimSpace(credentials)
	}
	return ""
}
|
|
|
|
// constantTimeEqual reports whether got and expected match after trimming
// surrounding whitespace from both.
//
// Empty values never match, even against each other, so a missing credential
// cannot satisfy an unconfigured token. Content is compared with
// subtle.ConstantTimeCompare; note that a length mismatch returns early, so
// only the token's length (not its content) is observable via timing.
func constantTimeEqual(got, expected string) bool {
	a, b := strings.TrimSpace(got), strings.TrimSpace(expected)
	if a == "" || b == "" || len(a) != len(b) {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
|