package daemon import ( "encoding/json" "fmt" "log" "net" "os" "os/signal" "path/filepath" "sync" "syscall" "time" "git.jrop.me/jonathan/envvault/internal" ) // Message represents the communication protocol between client and daemon type Message struct { Command string `json:"command"` Password string `json:"password,omitempty"` } // Response represents the daemon's response to client requests type Response struct { Success bool `json:"success"` Error string `json:"error,omitempty"` Data string `json:"data,omitempty"` } // PasswordCache manages the cached password with timeout type PasswordCache struct { password []byte freshAccessTime time.Time mutex sync.RWMutex timeout time.Duration } // NewPasswordCache creates a new password cache with the specified timeout func NewPasswordCache(timeout time.Duration) *PasswordCache { return &PasswordCache{ timeout: timeout, } } // Set stores a password in the cache func (pc *PasswordCache) Set(password []byte) { internal.DaemonLog.Println("Setting password in cache") pc.mutex.Lock() defer pc.mutex.Unlock() pc.password = make([]byte, len(password)) copy(pc.password, password) pc.freshAccessTime = time.Now() internal.DaemonLog.Printf("Password set in cache, expires at: %s", pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339)) } // Get retrieves the password if it exists and hasn't expired func (pc *PasswordCache) Get() ([]byte, bool) { internal.DaemonLog.Println("Getting password from cache") pc.mutex.Lock() defer pc.mutex.Unlock() if pc.password == nil { internal.DaemonLog.Println("No password in cache") return nil, false } timeSinceFreshAccess := time.Since(pc.freshAccessTime) internal.DaemonLog.Printf("Time since fresh access: %s (timeout: %s)", timeSinceFreshAccess, pc.timeout) if timeSinceFreshAccess > pc.timeout { internal.DaemonLog.Println("Password has expired, clearing") if pc.password != nil { // Securely clear the password by overwriting with zeros for i := range pc.password { pc.password[i] = 0 } pc.password = nil } return nil, false } result := make([]byte, len(pc.password)) copy(result, pc.password) internal.DaemonLog.Println("Returning valid password, expires at: %s", pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339)) return result, true } // StartDaemon starts the password caching daemon func StartDaemon(socketPath string, timeout time.Duration) error { internal.DaemonLog.Println("Starting daemon with socket path: %s and timeout: %s", socketPath, timeout) // Ensure socket directory exists socketDir := filepath.Dir(socketPath) internal.DaemonLog.Println("Creating socket directory: %s", socketDir) if err := os.MkdirAll(socketDir, 0700); err != nil { internal.DaemonLog.Println("Failed to create socket directory: %v", err) return fmt.Errorf("failed to create socket directory: %w", err) } // Remove existing socket if it exists if _, err := os.Stat(socketPath); err == nil { internal.DaemonLog.Println("Removing existing socket: %s", socketPath) if err := os.RemoveAll(socketPath); err != nil { internal.DaemonLog.Println("Failed to remove existing socket: %v", err) return fmt.Errorf("failed to remove existing socket: %w", err) } } // Create Unix domain socket internal.DaemonLog.Println("Creating Unix domain socket") listener, err := net.Listen("unix", socketPath) if err != nil { internal.DaemonLog.Println("Failed to listen on socket: %v", err) return fmt.Errorf("failed to listen on socket: %w", err) } defer listener.Close() // Set socket permissions to only allow current user internal.DaemonLog.Println("Setting socket permissions to 0600") if err := os.Chmod(socketPath, 0600); err != nil { internal.DaemonLog.Println("Failed to set socket permissions: %v", err) return fmt.Errorf("failed to set socket permissions: %w", err) } // Create password cache internal.DaemonLog.Println("Creating password cache with timeout: %s", timeout) cache := NewPasswordCache(timeout) // Handle signals for graceful shutdown internal.DaemonLog.Println("Setting up signal handlers for graceful shutdown") sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go func() { sig := <-sigChan internal.DaemonLog.Println("Received signal: %s, shutting down", sig) os.RemoveAll(socketPath) os.Exit(0) }() // Start cleanup goroutine to periodically check for expired passwords internal.DaemonLog.Println("Starting cleanup goroutine with interval: %s", timeout/2) go func() { ticker := time.NewTicker(timeout / 2) defer ticker.Stop() for t := range ticker.C { internal.DaemonLog.Println("Cleanup tick at %s", t.Format(time.RFC3339)) // This will trigger cleanup of expired password _, ok := cache.Get() internal.DaemonLog.Println("Cleanup check result: password exists=%v", ok) } }() fmt.Fprintf(os.Stderr, "Password daemon started (timeout: %s)\n", timeout) internal.DaemonLog.Println("Password daemon started successfully") // Accept connections internal.DaemonLog.Println("Entering accept loop") for { internal.DaemonLog.Println("Waiting for connections") conn, err := listener.Accept() if err != nil { internal.DaemonLog.Println("Error accepting connection: %v", err) log.Printf("Error accepting connection: %v", err) continue } internal.DaemonLog.Println("Accepted connection from: %s", conn.RemoteAddr()) go handleConnection(conn, cache) } } // handleConnection processes client requests func handleConnection(conn net.Conn, cache *PasswordCache) { defer conn.Close() internal.DaemonLog.Println("Handling new connection") decoder := json.NewDecoder(conn) encoder := json.NewEncoder(conn) var msg Message if err := decoder.Decode(&msg); err != nil { internal.DaemonLog.Println("Failed to decode message: %v", err) encoder.Encode(Response{Success: false, Error: "Invalid message format"}) return } internal.DaemonLog.Println("Received command: %s", msg.Command) switch msg.Command { case "store": if msg.Password == "" { internal.DaemonLog.Println("Store command received with empty password") encoder.Encode(Response{Success: false, Error: "No password provided"}) return } internal.DaemonLog.Println("Storing password in cache") cache.Set([]byte(msg.Password)) internal.DaemonLog.Println("Password stored, sending success response") encoder.Encode(Response{Success: true}) case "retrieve": internal.DaemonLog.Println("Retrieve command received") password, ok := cache.Get() if !ok { internal.DaemonLog.Println("No password available or password expired") encoder.Encode(Response{Success: false, Error: "No password available or password expired"}) return } internal.DaemonLog.Println("Password retrieved, sending success response") encoder.Encode(Response{Success: true, Data: string(password)}) default: internal.DaemonLog.Println("Unknown command received: %s", msg.Command) encoder.Encode(Response{Success: false, Error: "Unknown command"}) } internal.DaemonLog.Println("Connection handled successfully") }