Jonathan Apodaca 1922138133
All checks were successful
CI / build (push) Successful in 23s
add lock subcommand; add --local flag to rm and exec
2025-05-12 22:47:09 -06:00

244 lines
7.5 KiB
Go

package daemon
import (
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"os"
	"os/signal"
	"path/filepath"
	"sync"
	"syscall"
	"time"

	"git.jrop.me/jonathan/envvault/internal"
)
// Message is one client request, sent to the daemon as a single JSON object
// over the Unix domain socket.
type Message struct {
	// Command names the requested operation (handleConnection understands
	// "store", "retrieve" and "kill").
	Command string `json:"command"`

	// Password carries the secret for a "store" request; it is left out of
	// the encoded JSON when empty.
	Password string `json:"password,omitempty"`
}
// Response is the daemon's single JSON reply to a client Message.
type Response struct {
	// Success reports whether the requested command completed.
	Success bool `json:"success"`

	// Error is a human-readable failure reason; omitted when empty.
	Error string `json:"error,omitempty"`

	// Data is the command's payload (the cached password for "retrieve");
	// omitted when empty.
	Data string `json:"data,omitempty"`
}
// PasswordCache keeps at most one password in memory and considers it
// expired once timeout has elapsed since it was last stored.
type PasswordCache struct {
	// password is the cache's private copy of the secret; nil means nothing
	// is cached.
	password []byte

	// freshAccessTime is the moment Set last stored a password.
	freshAccessTime time.Time

	// mutex is held by Set and Get while they read or modify the fields.
	mutex sync.RWMutex

	// timeout is how long a stored password stays valid.
	timeout time.Duration
}
// NewPasswordCache returns an empty cache in which each stored password
// stays valid for the given timeout.
func NewPasswordCache(timeout time.Duration) *PasswordCache {
	pc := new(PasswordCache)
	pc.timeout = timeout
	return pc
}
// Set stores a private copy of password in the cache and restarts the expiry
// clock, so the password stays valid for pc.timeout from now.
//
// Any previously cached password is overwritten with zeros before it is
// dropped, so a replaced secret does not linger in memory until garbage
// collection. This matches the zeroing Get performs when a password expires.
func (pc *PasswordCache) Set(password []byte) {
	internal.DaemonLog.Println("Setting password in cache")
	pc.mutex.Lock()
	defer pc.mutex.Unlock()

	// Copy first, then wipe the old buffer. This ordering stays correct even if
	// the caller passed a slice aliasing the currently cached one.
	fresh := make([]byte, len(password))
	copy(fresh, password)
	for i := range pc.password {
		pc.password[i] = 0
	}

	pc.password = fresh
	pc.freshAccessTime = time.Now()
	internal.DaemonLog.Printf("Password set in cache, expires at: %s",
		pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339))
}
// Get returns a copy of the cached password and true while it is still valid.
//
// It returns (nil, false) when nothing is cached. If the password has
// outlived pc.timeout (measured from the last Set; Get never extends it), its
// bytes are overwritten with zeros, the cache is emptied and (nil, false) is
// returned. The exclusive lock is taken because expiry mutates the cache.
func (pc *PasswordCache) Get() ([]byte, bool) {
	internal.DaemonLog.Println("Getting password from cache")
	pc.mutex.Lock()
	defer pc.mutex.Unlock()

	if pc.password == nil {
		internal.DaemonLog.Println("No password in cache")
		return nil, false
	}

	age := time.Since(pc.freshAccessTime)
	internal.DaemonLog.Printf("Time since fresh access: %s (timeout: %s)",
		age, pc.timeout)

	if age > pc.timeout {
		internal.DaemonLog.Println("Password has expired, clearing")
		// pc.password was checked non-nil above under the same lock, so wipe
		// it unconditionally before releasing the reference.
		for idx := range pc.password {
			pc.password[idx] = 0
		}
		pc.password = nil
		return nil, false
	}

	// Hand out a separate copy so callers can never touch the cached bytes.
	snapshot := make([]byte, len(pc.password))
	copy(snapshot, pc.password)

	expiresAt := pc.freshAccessTime.Add(pc.timeout)
	internal.DaemonLog.Printf("Returning valid password, expires at: %s",
		expiresAt.Format(time.RFC3339))
	return snapshot, true
}
// StartDaemon runs the password-caching daemon on a Unix domain socket at
// socketPath. Each stored password stays valid for timeout.
//
// It blocks while serving connections, each in its own goroutine. It returns:
//   - an error if timeout is not positive or socket setup fails;
//   - nil if the listener is closed.
//
// SIGINT or SIGTERM removes the socket file and exits the process with
// status 0.
//
// A stale file at socketPath is removed before listening. A directory at
// that path is refused rather than deleted.
//
// NOTE(review): the socket is chmod'ed to 0600 only after net.Listen has
// created it, so for a brief window its mode follows the process umask. The
// 0700 directory created here narrows this, but MkdirAll does not change the
// mode of a directory that already exists.
func StartDaemon(socketPath string, timeout time.Duration) error {
	internal.DaemonLog.Printf("Starting daemon with socket path: %s and timeout: %s",
		socketPath, timeout)

	// time.NewTicker panics on a non-positive interval. Reject bad timeouts up
	// front instead of crashing after the socket already exists.
	if timeout <= 0 {
		internal.DaemonLog.Printf("Invalid timeout: %s", timeout)
		return fmt.Errorf("invalid timeout %s: must be positive", timeout)
	}

	// Ensure socket directory exists
	socketDir := filepath.Dir(socketPath)
	internal.DaemonLog.Printf("Creating socket directory: %s", socketDir)
	if err := os.MkdirAll(socketDir, 0700); err != nil {
		internal.DaemonLog.Printf("Failed to create socket directory: %v", err)
		return fmt.Errorf("failed to create socket directory: %w", err)
	}

	// Remove a stale socket left by a previous run. Use os.Remove (not
	// RemoveAll) and refuse directories, so a mistyped socketPath can never
	// recursively delete a directory tree. Lstat also catches dangling
	// symlinks, which would otherwise make Listen fail.
	if info, err := os.Lstat(socketPath); err == nil {
		if info.IsDir() {
			internal.DaemonLog.Printf("Socket path is a directory, refusing to remove: %s", socketPath)
			return fmt.Errorf("socket path %s is a directory", socketPath)
		}
		internal.DaemonLog.Printf("Removing existing socket: %s", socketPath)
		if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
			internal.DaemonLog.Printf("Failed to remove existing socket: %v", err)
			return fmt.Errorf("failed to remove existing socket: %w", err)
		}
	}

	// Create Unix domain socket
	internal.DaemonLog.Println("Creating Unix domain socket")
	listener, err := net.Listen("unix", socketPath)
	if err != nil {
		internal.DaemonLog.Printf("Failed to listen on socket: %v", err)
		return fmt.Errorf("failed to listen on socket: %w", err)
	}
	defer listener.Close()

	// Set socket permissions to only allow current user
	internal.DaemonLog.Println("Setting socket permissions to 0600")
	if err := os.Chmod(socketPath, 0600); err != nil {
		internal.DaemonLog.Printf("Failed to set socket permissions: %v", err)
		return fmt.Errorf("failed to set socket permissions: %w", err)
	}

	// Create password cache
	internal.DaemonLog.Printf("Creating password cache with timeout: %s", timeout)
	cache := NewPasswordCache(timeout)

	// Handle signals for graceful shutdown. os.Exit skips deferred calls, so
	// the socket file is removed explicitly here.
	internal.DaemonLog.Println("Setting up signal handlers for graceful shutdown")
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		sig := <-sigChan
		internal.DaemonLog.Printf("Received signal: %s, shutting down", sig)
		os.Remove(socketPath)
		os.Exit(0)
	}()

	// Periodically call Get so an expired password is wiped even if no client
	// asks for it. A 1ns timeout cannot be halved to a positive interval, so
	// fall back to timeout itself in that case.
	interval := timeout / 2
	if interval <= 0 {
		interval = timeout
	}
	// done stops the cleanup goroutine when StartDaemon returns.
	done := make(chan struct{})
	defer close(done)
	internal.DaemonLog.Printf("Starting cleanup goroutine with interval: %s", interval)
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case t := <-ticker.C:
				internal.DaemonLog.Printf("Cleanup tick at %s", t.Format(time.RFC3339))
				// This will trigger cleanup of expired password
				_, ok := cache.Get()
				internal.DaemonLog.Printf("Cleanup check result: password exists=%v", ok)
			}
		}
	}()

	fmt.Fprintf(os.Stderr, "Password daemon started (timeout: %s)\n", timeout)
	internal.DaemonLog.Println("Password daemon started successfully")

	// Pause between failed Accepts so a persistent error (for example,
	// exhausted file descriptors) does not become a busy spin.
	const acceptRetryDelay = 50 * time.Millisecond

	// Accept connections
	internal.DaemonLog.Println("Entering accept loop")
	for {
		internal.DaemonLog.Println("Waiting for connections")
		conn, err := listener.Accept()
		if err != nil {
			// A closed listener will never accept again; stop instead of
			// looping forever.
			if errors.Is(err, net.ErrClosed) {
				internal.DaemonLog.Println("Listener closed, stopping accept loop")
				return nil
			}
			internal.DaemonLog.Printf("Error accepting connection: %v", err)
			log.Printf("Error accepting connection: %v", err)
			time.Sleep(acceptRetryDelay)
			continue
		}
		internal.DaemonLog.Printf("Accepted connection from: %s", conn.RemoteAddr())
		go handleConnection(conn, cache)
	}
}
// handleConnection serves exactly one request on conn. It decodes a single
// JSON Message, runs the command against cache, and writes a single JSON
// Response. The connection is always closed before returning.
//
// Commands:
//   - "store": cache msg.Password (an empty password is rejected).
//   - "retrieve": return the cached password in Response.Data.
//   - "kill": acknowledge, remove the socket file, and exit the process
//     shortly afterwards.
//
// Any other command gets an "Unknown command" error response.
func handleConnection(conn net.Conn, cache *PasswordCache) {
	defer conn.Close()
	internal.DaemonLog.Println("Handling new connection")

	// Bound the whole exchange so a client that connects but never sends (or
	// never reads) cannot pin this goroutine and its file descriptor forever.
	const requestTimeout = 30 * time.Second
	if err := conn.SetDeadline(time.Now().Add(requestTimeout)); err != nil {
		internal.DaemonLog.Printf("Failed to set connection deadline: %v", err)
	}

	decoder := json.NewDecoder(conn)
	encoder := json.NewEncoder(conn)
	// reply writes resp to the client and logs, rather than silently drops, a
	// failed write.
	reply := func(resp Response) {
		if err := encoder.Encode(resp); err != nil {
			internal.DaemonLog.Printf("Failed to send response: %v", err)
		}
	}

	var msg Message
	if err := decoder.Decode(&msg); err != nil {
		internal.DaemonLog.Printf("Failed to decode message: %v", err)
		reply(Response{Success: false, Error: "Invalid message format"})
		return
	}
	internal.DaemonLog.Printf("Received command: %s", msg.Command)

	switch msg.Command {
	case "store":
		if msg.Password == "" {
			internal.DaemonLog.Println("Store command received with empty password")
			reply(Response{Success: false, Error: "No password provided"})
			return
		}
		internal.DaemonLog.Println("Storing password in cache")
		pw := []byte(msg.Password)
		cache.Set(pw)
		// Set keeps its own copy, so wipe this temporary one. msg.Password is
		// an immutable string and cannot be zeroed.
		for i := range pw {
			pw[i] = 0
		}
		internal.DaemonLog.Println("Password stored, sending success response")
		reply(Response{Success: true})

	case "retrieve":
		internal.DaemonLog.Println("Retrieve command received")
		password, ok := cache.Get()
		if !ok {
			internal.DaemonLog.Println("No password available or password expired")
			reply(Response{Success: false, Error: "No password available or password expired"})
			return
		}
		internal.DaemonLog.Println("Password retrieved, sending success response")
		reply(Response{Success: true, Data: string(password)})
		// Get returned a private copy; wipe it once it has been sent.
		for i := range password {
			password[i] = 0
		}

	case "kill":
		internal.DaemonLog.Println("Kill command received, shutting down daemon")
		reply(Response{Success: true})

		// Capture the socket path before closing. Only a *net.UnixAddr with a
		// non-empty Name is trusted. LocalAddr may be nil, or (on some
		// platforms) unnamed, and a non-Unix conn would otherwise yield an
		// unrelated relative path such as "pipe".
		socketPath := ""
		if addr, ok := conn.LocalAddr().(*net.UnixAddr); ok {
			socketPath = addr.Name
		}

		// Close the connection before exiting
		conn.Close()

		switch err := removeIfNamed(socketPath); {
		case socketPath == "":
			internal.DaemonLog.Println("Could not determine socket path; leaving socket file in place")
		case err != nil:
			internal.DaemonLog.Printf("Failed to remove socket file %s: %v", socketPath, err)
		default:
			internal.DaemonLog.Printf("Removed socket file: %s", socketPath)
		}

		internal.DaemonLog.Println("Daemon shutting down gracefully")
		// The response has already been written and the connection closed;
		// the short delay lets the process finish logging before it exits.
		go func() {
			time.Sleep(100 * time.Millisecond)
			os.Exit(0)
		}()
		return

	default:
		internal.DaemonLog.Printf("Unknown command received: %s", msg.Command)
		reply(Response{Success: false, Error: "Unknown command"})
	}
	internal.DaemonLog.Println("Connection handled successfully")
}

// removeIfNamed removes the file at path, doing nothing (and returning nil)
// when path is empty.
func removeIfNamed(path string) error {
	if path == "" {
		return nil
	}
	return os.Remove(path)
}