zlh-agent/internal/http/console_sessions.go
2026-03-15 11:06:08 +00:00

236 lines
4.6 KiB
Go

package agenthttp
import (
"fmt"
"io"
"log"
"os"
"sync"
"time"
"github.com/gorilla/websocket"
"zlh-agent/internal/runtime"
"zlh-agent/internal/state"
"zlh-agent/internal/system"
)
// sessionTTL is how long scheduleCleanupIfIdle waits before re-checking
// whether a session is still idle and can be destroyed.
const sessionTTL = 60 * time.Second
// consoleConn is the per-connection state a consoleSession keeps for one
// attached websocket.
type consoleConn struct {
// send is the outbound queue that broadcast pushes PTY output into.
// NOTE(review): presumably drained by a per-connection writer defined
// elsewhere; confirm against the websocket handler.
send chan []byte
}
// consoleSession is the shared state for one console PTY and the websocket
// connections attached to it. Instances are registered in sessions under key.
type consoleSession struct {
// key is the registry key produced by sessionKey for cfg.
key string
// cfg is the configuration the session was created for.
cfg *state.Config
// ptyFile is the PTY handle obtained from system.GetConsolePTY; destroy
// sets it to nil.
ptyFile *os.File
// createdAt is set once when the session is constructed.
createdAt time.Time
// lastActive is refreshed by touch, addConn, removeConn and broadcast.
lastActive time.Time
// mu guards conns and lastActive; destroy also swaps out ptyFile while
// holding it.
mu sync.Mutex
// conns maps each attached websocket to its per-connection state.
conns map[*websocket.Conn]*consoleConn
// readerOnce ensures startReader launches the PTY read loop at most once.
readerOnce sync.Once
// closeOnce ensures the teardown in destroy runs at most once.
closeOnce sync.Once
}
// Package-level registry of live console sessions.
var (
// sessionMu guards every access to sessions.
sessionMu sync.Mutex
// sessions maps a sessionKey string to the session registered for it.
sessions = make(map[string]*consoleSession)
)
// sessionKey builds the registry key for cfg in the form "<VMID>:<ContainerType>".
func sessionKey(cfg *state.Config) string {
	vmid, kind := cfg.VMID, cfg.ContainerType
	return fmt.Sprintf("%d:%s", vmid, kind)
}
// getConsoleSession returns the console session for cfg, creating one when no
// usable session is registered.
//
// The boolean result is true when an existing session was reused and false
// when a new one was created. A registered session is reused only while
// system.GetConsolePTY still returns the same *os.File the session holds;
// otherwise the stale session is destroyed and a replacement is built. If the
// PTY lookup fails, any session registered for the key is destroyed and the
// error is returned.
func getConsoleSession(cfg *state.Config) (*consoleSession, bool, error) {
	key := sessionKey(cfg)

	// Look up and release the registry lock in one place so it is unlocked
	// exactly once on every path. destroy() acquires sessionMu itself, so it
	// must never be called while the lock is held here.
	sessionMu.Lock()
	existing, found := sessions[key]
	sessionMu.Unlock()

	if found {
		currentPTY, err := system.GetConsolePTY(cfg)
		if err != nil {
			existing.destroy()
			return nil, false, err
		}

		// ptyFile is written by destroy() under existing.mu, so read it under
		// the same mutex.
		existing.mu.Lock()
		samePTY := existing.ptyFile == currentPTY
		existing.mu.Unlock()

		if samePTY {
			existing.touch()
			log.Printf("[console] vmid=%d type=%s session reuse", cfg.VMID, cfg.ContainerType)
			return existing, true, nil
		}

		log.Printf("[console] vmid=%d type=%s pty changed, destroying stale session", cfg.VMID, cfg.ContainerType)
		existing.destroy()
		// Fall through and build a replacement session.
	}

	ptyFile, err := system.GetConsolePTY(cfg)
	if err != nil {
		return nil, false, err
	}

	now := time.Now()
	sess := &consoleSession{
		key:        key,
		cfg:        cfg,
		ptyFile:    ptyFile,
		createdAt:  now,
		lastActive: now,
		conns:      make(map[*websocket.Conn]*consoleConn),
	}
	sess.startReader()

	// NOTE(review): sessionMu is not held between the lookup above and this
	// insert, so two concurrent callers for the same key can each build a
	// session and the later insert replaces the earlier one. Confirm whether
	// callers serialise requests per key.
	sessionMu.Lock()
	sessions[key] = sess
	sessionMu.Unlock()

	log.Printf("[console] vmid=%d type=%s session created", cfg.VMID, cfg.ContainerType)
	return sess, false, nil
}
// touch records the current time as the session's last activity.
func (s *consoleSession) touch() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.lastActive = time.Now()
}
// addConn attaches conn to the session and returns the consoleConn stored for
// it. When cc is nil a new consoleConn with a 128-slot send queue is created;
// otherwise the supplied cc is stored as-is. The session's activity timestamp
// is refreshed.
func (s *consoleSession) addConn(conn *websocket.Conn, cc *consoleConn) *consoleConn {
	if cc == nil {
		cc = &consoleConn{send: make(chan []byte, 128)}
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	s.conns[conn] = cc
	s.lastActive = time.Now()

	attached := len(s.conns)
	log.Printf("[console] vmid=%d type=%s conn add conns=%d", s.cfg.VMID, s.cfg.ContainerType, attached)
	return cc
}
// removeConn detaches conn from the session, closing its send queue if it was
// attached, and returns how many connections remain. The session's activity
// timestamp is refreshed whether or not conn was known.
func (s *consoleSession) removeConn(conn *websocket.Conn) int {
	s.mu.Lock()
	defer s.mu.Unlock()

	if entry, found := s.conns[conn]; found {
		delete(s.conns, conn)
		safeCloseChan(entry.send)
	}
	s.lastActive = time.Now()

	remaining := len(s.conns)
	log.Printf("[console] vmid=%d type=%s conn remove conns=%d", s.cfg.VMID, s.cfg.ContainerType, remaining)
	return remaining
}
// startReader launches, at most once per session, a goroutine that reads the
// PTY in 4 KiB chunks and broadcasts each chunk to the attached connections.
// The goroutine exits on the first read error (including io.EOF) and destroys
// the session on its way out.
func (s *consoleSession) startReader() {
	s.readerOnce.Do(func() {
		// Capture the handle once, before the loop starts. destroy() sets
		// s.ptyFile to nil under s.mu, so re-reading the field on every
		// iteration would race with it. After destroy() closes the file the
		// pending or next Read fails and the loop ends.
		pty := s.ptyFile

		go func() {
			buf := make([]byte, 4096)
			for {
				n, err := pty.Read(buf)
				if n > 0 {
					// Copy out of buf: it is reused by the next Read while
					// receivers may still hold the broadcast slice.
					out := make([]byte, n)
					copy(out, buf[:n])
					log.Printf("[console] vmid=%d pty read bytes=%d", s.cfg.VMID, n)
					s.broadcast(out)
				}
				if err != nil {
					if err == io.EOF {
						log.Printf("[console] vmid=%d pty read loop exit err=EOF", s.cfg.VMID)
					} else {
						log.Printf("[console] vmid=%d pty read loop exit err=%v", s.cfg.VMID, err)
					}
					s.destroy()
					return
				}
			}
		}()
	})
}
// broadcast queues data on every attached connection without blocking and
// refreshes the session's activity timestamp. When a connection's queue is
// full, one queued message is discarded and the send is retried once; if that
// retry also fails the data is dropped for that connection.
func (s *consoleSession) broadcast(data []byte) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.lastActive = time.Now()

	// offer attempts a non-blocking send and reports whether it succeeded.
	offer := func(queue chan []byte) bool {
		select {
		case queue <- data:
			return true
		default:
			return false
		}
	}

	for _, cc := range s.conns {
		if offer(cc.send) {
			continue
		}

		// Queue is full: make room by discarding one pending message, if any.
		select {
		case <-cc.send:
		default:
		}

		// Retry once; the result is intentionally ignored.
		offer(cc.send)
	}
}
// writeInput forwards data to the session's PTY via runtime.Write and
// refreshes the activity timestamp. It returns an error when the session no
// longer has a PTY (for example after destroy), or whatever runtime.Write
// returns.
func (s *consoleSession) writeInput(data []byte) error {
	// destroy() sets ptyFile to nil under s.mu, so take the snapshot under the
	// same lock. The nil check and the write then act on one consistent value.
	s.mu.Lock()
	s.lastActive = time.Now()
	pty := s.ptyFile
	s.mu.Unlock()

	if pty == nil {
		return fmt.Errorf("pty unavailable")
	}
	return runtime.Write(pty, data)
}
// scheduleCleanupIfIdle arranges a deferred idle check sessionTTL from now.
// When the check fires, the session is destroyed only if it has no attached
// connections and its activity timestamp is unchanged since this call. For
// "dev" containers system.StopDevShell is invoked first. The call itself
// returns immediately.
func (s *consoleSession) scheduleCleanupIfIdle() {
	s.mu.Lock()
	snapshot := s.lastActive
	s.mu.Unlock()

	time.AfterFunc(sessionTTL, func() {
		s.mu.Lock()
		attached := len(s.conns)
		latest := s.lastActive
		s.mu.Unlock()

		// Any connection or any activity since the snapshot keeps the session.
		if attached != 0 || !latest.Equal(snapshot) {
			return
		}

		log.Printf("[console] vmid=%d type=%s session cleanup", s.cfg.VMID, s.cfg.ContainerType)
		if s.cfg.ContainerType == "dev" {
			_ = system.StopDevShell()
		}
		s.destroy()
	})
}
// destroy tears the session down exactly once: it closes every attached
// connection and its send queue, closes the PTY, and removes the session from
// the package registry. Later calls are no-ops.
func (s *consoleSession) destroy() {
	s.closeOnce.Do(func() {
		s.mu.Lock()
		for conn, cc := range s.conns {
			safeCloseChan(cc.send)
			_ = conn.Close() // best effort; the connection is being dropped anyway
			delete(s.conns, conn)
		}
		pty := s.ptyFile
		s.ptyFile = nil
		s.mu.Unlock()

		// Close the PTY after releasing s.mu so other methods are not blocked
		// while the file is being closed.
		if pty != nil {
			_ = pty.Close()
		}

		// Only drop the registry entry if it still refers to this session. A
		// newer session may already be registered under the same key, and it
		// must not be evicted by the teardown of its predecessor.
		sessionMu.Lock()
		if sessions[s.key] == s {
			delete(sessions, s.key)
		}
		sessionMu.Unlock()

		log.Printf("[console] vmid=%d type=%s session destroyed", s.cfg.VMID, s.cfg.ContainerType)
	})
}
// safeCloseChan closes ch and swallows the panic that close raises when ch is
// nil or has already been closed.
func safeCloseChan(ch chan []byte) {
	swallowPanic := func() {
		// recover must be called directly by the deferred function.
		_ = recover()
	}
	defer swallowPanic()

	close(ch)
}