236 lines
4.6 KiB
Go
236 lines
4.6 KiB
Go
package agenthttp
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
"zlh-agent/internal/runtime"
|
|
"zlh-agent/internal/state"
|
|
"zlh-agent/internal/system"
|
|
)
|
|
|
|
// sessionTTL is how long a session must stay without connections and without
// activity before scheduleCleanupIfIdle tears it down (one minute).
const sessionTTL = time.Minute
|
|
|
|
// consoleConn is the per-websocket outbound queue. The session pushes PTY
// output chunks onto send; whoever drains send presumably writes them to the
// websocket (the consumer is not in this file section — confirm).
type consoleConn struct {
	send chan []byte // buffered queue of output chunks for one connection
}
|
|
|
|
type consoleSession struct {
|
|
key string
|
|
cfg *state.Config
|
|
ptyFile *os.File
|
|
createdAt time.Time
|
|
lastActive time.Time
|
|
|
|
mu sync.Mutex
|
|
conns map[*websocket.Conn]*consoleConn
|
|
readerOnce sync.Once
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
var (
|
|
sessionMu sync.Mutex
|
|
sessions = make(map[string]*consoleSession)
|
|
)
|
|
|
|
func sessionKey(cfg *state.Config) string {
|
|
return fmt.Sprintf("%d:%s", cfg.VMID, cfg.ContainerType)
|
|
}
|
|
|
|
func getConsoleSession(cfg *state.Config) (*consoleSession, bool, error) {
|
|
key := sessionKey(cfg)
|
|
|
|
sessionMu.Lock()
|
|
if sess, ok := sessions[key]; ok {
|
|
sessionMu.Unlock()
|
|
|
|
currentPTY, err := system.GetConsolePTY(cfg)
|
|
if err != nil {
|
|
sess.destroy()
|
|
return nil, false, err
|
|
}
|
|
|
|
if sess.ptyFile != currentPTY {
|
|
log.Printf("[ws] pty changed, destroying stale session: vmid=%d type=%s", cfg.VMID, cfg.ContainerType)
|
|
sess.destroy()
|
|
} else {
|
|
sess.touch()
|
|
log.Printf("[ws] session reuse: vmid=%d type=%s", cfg.VMID, cfg.ContainerType)
|
|
return sess, true, nil
|
|
}
|
|
}
|
|
sessionMu.Unlock()
|
|
|
|
ptyFile, err := system.GetConsolePTY(cfg)
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
|
|
sess := &consoleSession{
|
|
key: key,
|
|
cfg: cfg,
|
|
ptyFile: ptyFile,
|
|
createdAt: time.Now(),
|
|
lastActive: time.Now(),
|
|
conns: make(map[*websocket.Conn]*consoleConn),
|
|
}
|
|
|
|
sess.startReader()
|
|
|
|
sessionMu.Lock()
|
|
sessions[key] = sess
|
|
sessionMu.Unlock()
|
|
|
|
log.Printf("[ws] session created: vmid=%d type=%s", cfg.VMID, cfg.ContainerType)
|
|
return sess, false, nil
|
|
}
|
|
|
|
func (s *consoleSession) touch() {
|
|
s.mu.Lock()
|
|
s.lastActive = time.Now()
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
func (s *consoleSession) addConn(conn *websocket.Conn, cc *consoleConn) *consoleConn {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if cc == nil {
|
|
cc = &consoleConn{send: make(chan []byte, 128)}
|
|
}
|
|
s.conns[conn] = cc
|
|
s.lastActive = time.Now()
|
|
log.Printf("[ws] conn add: vmid=%d type=%s conns=%d", s.cfg.VMID, s.cfg.ContainerType, len(s.conns))
|
|
return cc
|
|
}
|
|
|
|
func (s *consoleSession) removeConn(conn *websocket.Conn) int {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
cc, ok := s.conns[conn]
|
|
if ok {
|
|
delete(s.conns, conn)
|
|
safeCloseChan(cc.send)
|
|
}
|
|
s.lastActive = time.Now()
|
|
log.Printf("[ws] conn remove: vmid=%d type=%s conns=%d", s.cfg.VMID, s.cfg.ContainerType, len(s.conns))
|
|
return len(s.conns)
|
|
}
|
|
|
|
func (s *consoleSession) startReader() {
|
|
s.readerOnce.Do(func() {
|
|
go func() {
|
|
buf := make([]byte, 4096)
|
|
for {
|
|
n, err := s.ptyFile.Read(buf)
|
|
if n > 0 {
|
|
out := make([]byte, n)
|
|
copy(out, buf[:n])
|
|
log.Printf("[ws] pty read: vmid=%d bytes=%d", s.cfg.VMID, n)
|
|
s.broadcast(out)
|
|
}
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
log.Printf("[ws] pty read loop exit: vmid=%d err=EOF", s.cfg.VMID)
|
|
} else {
|
|
log.Printf("[ws] pty read loop exit: vmid=%d err=%v", s.cfg.VMID, err)
|
|
}
|
|
s.destroy()
|
|
return
|
|
}
|
|
if n == 0 && err == nil {
|
|
continue
|
|
}
|
|
}
|
|
}()
|
|
})
|
|
}
|
|
|
|
func (s *consoleSession) broadcast(data []byte) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.lastActive = time.Now()
|
|
for _, cc := range s.conns {
|
|
select {
|
|
case cc.send <- data:
|
|
default:
|
|
select {
|
|
case <-cc.send:
|
|
default:
|
|
}
|
|
select {
|
|
case cc.send <- data:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *consoleSession) writeInput(data []byte) error {
|
|
s.touch()
|
|
if s.ptyFile == nil {
|
|
return fmt.Errorf("pty unavailable")
|
|
}
|
|
return runtime.Write(s.ptyFile, data)
|
|
}
|
|
|
|
func (s *consoleSession) scheduleCleanupIfIdle() {
|
|
s.mu.Lock()
|
|
last := s.lastActive
|
|
s.mu.Unlock()
|
|
|
|
go func(ts time.Time) {
|
|
time.Sleep(sessionTTL)
|
|
s.mu.Lock()
|
|
conns := len(s.conns)
|
|
lastActive := s.lastActive
|
|
s.mu.Unlock()
|
|
|
|
if conns == 0 && lastActive.Equal(ts) {
|
|
log.Printf("[ws] session cleanup: vmid=%d type=%s", s.cfg.VMID, s.cfg.ContainerType)
|
|
if s.cfg.ContainerType == "dev" {
|
|
_ = system.StopDevShell()
|
|
}
|
|
s.destroy()
|
|
}
|
|
}(last)
|
|
}
|
|
|
|
func (s *consoleSession) destroy() {
|
|
s.closeOnce.Do(func() {
|
|
s.mu.Lock()
|
|
for conn, cc := range s.conns {
|
|
safeCloseChan(cc.send)
|
|
_ = conn.Close()
|
|
delete(s.conns, conn)
|
|
}
|
|
pty := s.ptyFile
|
|
s.ptyFile = nil
|
|
s.mu.Unlock()
|
|
|
|
if pty != nil {
|
|
_ = pty.Close()
|
|
}
|
|
|
|
sessionMu.Lock()
|
|
delete(sessions, s.key)
|
|
sessionMu.Unlock()
|
|
|
|
log.Printf("[ws] session destroyed: vmid=%d type=%s", s.cfg.VMID, s.cfg.ContainerType)
|
|
})
|
|
}
|
|
|
|
// safeCloseChan closes ch and swallows the panic that close raises when ch is
// nil or already closed, so callers can close defensively.
func safeCloseChan(ch chan []byte) {
	defer func() { recover() }()
	close(ch)
}
|