227 lines
7.0 KiB
Go
227 lines
7.0 KiB
Go
package daemon
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
|
|
"git.jrop.me/jonathan/envvault/internal"
|
|
)
|
|
|
|
// Message is the JSON request a client sends to the daemon over the Unix
// socket. Command selects the operation to run; Password carries the secret
// for a store request and is left out of the encoding when empty.
type Message struct {
	Command  string `json:"command"`
	Password string `json:"password,omitempty"`
}
|
|
|
|
// Response is the JSON reply the daemon writes for every request.
// Success reports whether the command worked. Error holds a human-readable
// reason on failure and Data holds the payload (the cached password for a
// retrieve); both are left out of the encoding when empty.
type Response struct {
	Success bool   `json:"success"`
	Error   string `json:"error,omitempty"`
	Data    string `json:"data,omitempty"`
}
|
|
|
|
// PasswordCache manages the cached password with timeout
|
|
type PasswordCache struct {
|
|
password []byte
|
|
freshAccessTime time.Time
|
|
mutex sync.RWMutex
|
|
timeout time.Duration
|
|
}
|
|
|
|
// NewPasswordCache creates a new password cache with the specified timeout
|
|
func NewPasswordCache(timeout time.Duration) *PasswordCache {
|
|
return &PasswordCache{
|
|
timeout: timeout,
|
|
}
|
|
}
|
|
|
|
// Set stores a password in the cache
|
|
func (pc *PasswordCache) Set(password []byte) {
|
|
internal.DaemonLog.Println("Setting password in cache")
|
|
pc.mutex.Lock()
|
|
defer pc.mutex.Unlock()
|
|
|
|
pc.password = make([]byte, len(password))
|
|
copy(pc.password, password)
|
|
pc.freshAccessTime = time.Now()
|
|
internal.DaemonLog.Printf("Password set in cache, expires at: %s",
|
|
pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339))
|
|
}
|
|
|
|
// Get retrieves the password if it exists and hasn't expired
|
|
func (pc *PasswordCache) Get() ([]byte, bool) {
|
|
internal.DaemonLog.Println("Getting password from cache")
|
|
pc.mutex.Lock()
|
|
defer pc.mutex.Unlock()
|
|
|
|
if pc.password == nil {
|
|
internal.DaemonLog.Println("No password in cache")
|
|
return nil, false
|
|
}
|
|
|
|
timeSinceFreshAccess := time.Since(pc.freshAccessTime)
|
|
internal.DaemonLog.Printf("Time since fresh access: %s (timeout: %s)",
|
|
timeSinceFreshAccess, pc.timeout)
|
|
|
|
if timeSinceFreshAccess > pc.timeout {
|
|
internal.DaemonLog.Println("Password has expired, clearing")
|
|
if pc.password != nil {
|
|
// Securely clear the password by overwriting with zeros
|
|
for i := range pc.password {
|
|
pc.password[i] = 0
|
|
}
|
|
pc.password = nil
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
result := make([]byte, len(pc.password))
|
|
copy(result, pc.password)
|
|
|
|
internal.DaemonLog.Printf("Returning valid password, expires at: %s",
|
|
pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339))
|
|
return result, true
|
|
}
|
|
|
|
// StartDaemon starts the password caching daemon
|
|
func StartDaemon(socketPath string, timeout time.Duration) error {
|
|
internal.DaemonLog.Printf("Starting daemon with socket path: %s and timeout: %s",
|
|
socketPath, timeout)
|
|
|
|
// Ensure socket directory exists
|
|
socketDir := filepath.Dir(socketPath)
|
|
internal.DaemonLog.Printf("Creating socket directory: %s", socketDir)
|
|
if err := os.MkdirAll(socketDir, 0700); err != nil {
|
|
internal.DaemonLog.Printf("Failed to create socket directory: %v", err)
|
|
return fmt.Errorf("failed to create socket directory: %w", err)
|
|
}
|
|
|
|
// Remove existing socket if it exists
|
|
if _, err := os.Stat(socketPath); err == nil {
|
|
internal.DaemonLog.Printf("Removing existing socket: %s", socketPath)
|
|
if err := os.RemoveAll(socketPath); err != nil {
|
|
internal.DaemonLog.Printf("Failed to remove existing socket: %v", err)
|
|
return fmt.Errorf("failed to remove existing socket: %w", err)
|
|
}
|
|
}
|
|
|
|
// Create Unix domain socket
|
|
internal.DaemonLog.Println("Creating Unix domain socket")
|
|
listener, err := net.Listen("unix", socketPath)
|
|
if err != nil {
|
|
internal.DaemonLog.Printf("Failed to listen on socket: %v", err)
|
|
return fmt.Errorf("failed to listen on socket: %w", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
// Set socket permissions to only allow current user
|
|
internal.DaemonLog.Println("Setting socket permissions to 0600")
|
|
if err := os.Chmod(socketPath, 0600); err != nil {
|
|
internal.DaemonLog.Printf("Failed to set socket permissions: %v", err)
|
|
return fmt.Errorf("failed to set socket permissions: %w", err)
|
|
}
|
|
|
|
// Create password cache
|
|
internal.DaemonLog.Printf("Creating password cache with timeout: %s", timeout)
|
|
cache := NewPasswordCache(timeout)
|
|
|
|
// Handle signals for graceful shutdown
|
|
internal.DaemonLog.Println("Setting up signal handlers for graceful shutdown")
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
sig := <-sigChan
|
|
internal.DaemonLog.Printf("Received signal: %s, shutting down", sig)
|
|
os.RemoveAll(socketPath)
|
|
os.Exit(0)
|
|
}()
|
|
|
|
// Start cleanup goroutine to periodically check for expired passwords
|
|
internal.DaemonLog.Printf("Starting cleanup goroutine with interval: %s", timeout/2)
|
|
go func() {
|
|
ticker := time.NewTicker(timeout / 2)
|
|
defer ticker.Stop()
|
|
|
|
for t := range ticker.C {
|
|
internal.DaemonLog.Printf("Cleanup tick at %s", t.Format(time.RFC3339))
|
|
// This will trigger cleanup of expired password
|
|
_, ok := cache.Get()
|
|
internal.DaemonLog.Printf("Cleanup check result: password exists=%v", ok)
|
|
}
|
|
}()
|
|
|
|
fmt.Fprintf(os.Stderr, "Password daemon started (timeout: %s)\n", timeout)
|
|
internal.DaemonLog.Println("Password daemon started successfully")
|
|
|
|
// Accept connections
|
|
internal.DaemonLog.Println("Entering accept loop")
|
|
for {
|
|
internal.DaemonLog.Println("Waiting for connections")
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
internal.DaemonLog.Printf("Error accepting connection: %v", err)
|
|
log.Printf("Error accepting connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
internal.DaemonLog.Printf("Accepted connection from: %s", conn.RemoteAddr())
|
|
go handleConnection(conn, cache)
|
|
}
|
|
}
|
|
|
|
// handleConnection processes client requests
|
|
func handleConnection(conn net.Conn, cache *PasswordCache) {
|
|
defer conn.Close()
|
|
internal.DaemonLog.Println("Handling new connection")
|
|
|
|
decoder := json.NewDecoder(conn)
|
|
encoder := json.NewEncoder(conn)
|
|
|
|
var msg Message
|
|
if err := decoder.Decode(&msg); err != nil {
|
|
internal.DaemonLog.Printf("Failed to decode message: %v", err)
|
|
encoder.Encode(Response{Success: false, Error: "Invalid message format"})
|
|
return
|
|
}
|
|
|
|
internal.DaemonLog.Printf("Received command: %s", msg.Command)
|
|
switch msg.Command {
|
|
case "store":
|
|
if msg.Password == "" {
|
|
internal.DaemonLog.Println("Store command received with empty password")
|
|
encoder.Encode(Response{Success: false, Error: "No password provided"})
|
|
return
|
|
}
|
|
internal.DaemonLog.Println("Storing password in cache")
|
|
cache.Set([]byte(msg.Password))
|
|
internal.DaemonLog.Println("Password stored, sending success response")
|
|
encoder.Encode(Response{Success: true})
|
|
|
|
case "retrieve":
|
|
internal.DaemonLog.Println("Retrieve command received")
|
|
password, ok := cache.Get()
|
|
if !ok {
|
|
internal.DaemonLog.Println("No password available or password expired")
|
|
encoder.Encode(Response{Success: false, Error: "No password available or password expired"})
|
|
return
|
|
}
|
|
internal.DaemonLog.Println("Password retrieved, sending success response")
|
|
encoder.Encode(Response{Success: true, Data: string(password)})
|
|
|
|
default:
|
|
internal.DaemonLog.Printf("Unknown command received: %s", msg.Command)
|
|
encoder.Encode(Response{Success: false, Error: "Unknown command"})
|
|
}
|
|
|
|
internal.DaemonLog.Println("Connection handled successfully")
|
|
}
|