From e278c40de19e8cc066ff26905eb1e803c3f5c467 Mon Sep 17 00:00:00 2001 From: Jonathan Apodaca Date: Thu, 8 May 2025 21:09:32 -0600 Subject: [PATCH] initial commit --- .github/workflows/ci.yml | 23 ++ .gitignore | 3 + Makefile | 19 ++ cmd/envvault/main.go | 508 ++++++++++++++++++++++++++++++++++++++ go.mod | 11 + go.sum | 14 ++ internal/daemon/client.go | 125 ++++++++++ internal/daemon/daemon.go | 226 +++++++++++++++++ internal/log.go | 58 +++++ 9 files changed, 987 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 cmd/envvault/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/daemon/client.go create mode 100644 internal/daemon/daemon.go create mode 100644 internal/log.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..2061fc9 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,23 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Build and Test + run: make diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1ef4283 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/build +/envvault +.aider* diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..45d902a --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +envvault: .PHONY + go build -o envvault ./cmd/envvault/main.go + +.PHONY: + +fmt: + go fmt ./... + +local-install: envvault + ln -s `pwd`/envvault ~/.local/bin/envvault + +cross-compile: + mkdir -p build + GOOS=linux GOARCH=amd64 go build -o build/envvault-linux-amd64 ./... + GOOS=linux GOARCH=arm64 go build -o build/envvault-linux-arm64 ./... + GOOS=darwin GOARCH=amd64 go build -o build/envvault-darwin-amd64 ./... + GOOS=darwin GOARCH=arm64 go build -o build/envvault-darwin-arm64 ./... + GOOS=windows GOARCH=amd64 go build -o build/envvault-windows-amd64.exe ./... + GOOS=windows GOARCH=arm64 go build -o build/envvault-windows-arm64.exe ./... diff --git a/cmd/envvault/main.go b/cmd/envvault/main.go new file mode 100644 index 0000000..f92d1b9 --- /dev/null +++ b/cmd/envvault/main.go @@ -0,0 +1,508 @@ +package main + +import ( + "bytes" + "crypto/rand" + "encoding/json" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + "time" + + "github.com/alecthomas/kong" + "golang.org/x/crypto/nacl/secretbox" + "golang.org/x/crypto/scrypt" + "golang.org/x/term" + "slices" + + "git.jrop.me/jonathan/envvault/internal" + "git.jrop.me/jonathan/envvault/internal/daemon" +) + +var cachedPassword []byte + +var ( + keyFilePath string + dbFilePath string +) + +func init() { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Fatalf("Failed to get user home directory: %v", err) + } + keyFilePath = filepath.Join(homeDir, ".local/cache/envvault/key") + dbFilePath = filepath.Join(homeDir, ".local/cache/envvault/db.json.enc") +} + +type EnvStore struct { + Vars map[string]string `json:"vars"` +} + +type CLI struct { + Init InitCmd `cmd:"" help:"Initialize the vault."` + List ListCmd `cmd:"" help:"List all environment variables."` + Add AddCmd `cmd:"" help:"Add an environment variable."` + Rm RmCmd `cmd:"" help:"Remove an environment variable."` + Exec ExecCmd `cmd:"" help:"Execute a command with environment variables." alias:"x"` + Rekey RekeyCmd `cmd:"" help:"Change the master password."` + Daemon DaemonCmd `cmd:"" help:"Start the password caching daemon."` +} + +type InitCmd struct{} +type ListCmd struct { + Values bool `short:"v" help:"Show values of environment variables." default:"false"` +} +type AddCmd struct { + Name string `arg:"" help:"Name of the environment variable."` + Value string `arg:"" optional:"" help:"Value of the environment variable."` +} +type RmCmd struct { + Name string `arg:"" help:"Name of the environment variable to remove."` +} +type ExecCmd struct { + Env []string `short:"e" help:"Environment variables to set."` + Cmd string `arg:"" help:"Command to execute." passthrough:"all"` + Args []string `arg:"" optional:"" help:"Arguments for the command."` +} +type RekeyCmd struct{} +type DaemonCmd struct { + Timeout int `help:"Password cache timeout in minutes." default:"5"` +} + +func main() { + internal.Log.Println("Starting envvault") + + cli := CLI{} + ctx := kong.Parse(&cli) + + command := ctx.Command() + internal.Log.Printf("Executing command: %s", command) + + switch command { + case "init": + subcommandInitVault() + case "add ": + subcommandAddEnvVar(cli.Add.Name, cli.Add.Value) + case "add ": + subcommandAddEnvVar(cli.Add.Name, cli.Add.Value) + case "rm ": + subcommandRmEnvVar(cli.Rm.Name) + case "list": + subcommandListEnvVars(cli.List) + case "exec ": + subcommandExecCommand(cli.Exec) + case "exec ": + subcommandExecCommand(cli.Exec) + case "rekey": + subcommandRekeyVault() + case "daemon": + subcommandStartDaemon(cli.Daemon.Timeout) + default: + internal.Log.Printf("Unknown command: %s", command) + log.Fatal("Unknown command") + } + + internal.Log.Printf("Command completed: %s", command) +} + +func subcommandInitVault() { + internal.Log.Println("Initializing vault") + + if _, err := os.Stat(keyFilePath); os.IsNotExist(err) { + internal.Log.Println("Key file does not exist, creating new vault") + + fmt.Fprint(os.Stderr, "Enter a new master password: ") + password, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + internal.Log.Printf("Failed to read password: %v", err) + log.Fatal(err) + } + + fmt.Fprint(os.Stderr, "Enter a new master password (again): ") + password2, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + internal.Log.Printf("Failed to read confirmation password: %v", err) + log.Fatal(err) + } + + if !bytes.Equal(password, password2) { + internal.Log.Println("Passwords do not match") + log.Fatal("Passwords do not match") + } + internal.Log.Println("Passwords match, caching password") + cachedPassword = password + + // Generate a random key + internal.Log.Println("Generating random key") + var key [32]byte + if _, err := rand.Read(key[:]); err != nil { + internal.Log.Printf("Failed to generate random key: %v", err) + log.Fatal(err) + } + + // Encrypt the key with the password + internal.Log.Println("Encrypting key with password") + encryptedKey := encryptKeyWithPassword(key[:], password) + + internal.Log.Printf("Creating key directory: %s", filepath.Dir(keyFilePath)) + if err := os.MkdirAll(filepath.Dir(keyFilePath), 0700); err != nil { + internal.Log.Printf("Failed to create key directory: %v", err) + log.Fatal(err) + } + + internal.Log.Printf("Writing encrypted key to: %s", keyFilePath) + if err := os.WriteFile(keyFilePath, encryptedKey, 0600); err != nil { + internal.Log.Printf("Failed to write key file: %v", err) + log.Fatal(err) + } + + // Initialize the db file + internal.Log.Println("Initializing empty environment store") + emptyStore := &EnvStore{Vars: make(map[string]string)} + saveEnvStore(emptyStore) + + internal.Log.Println("Vault initialized successfully") + fmt.Fprintln(os.Stderr, "Vault initialized successfully.") + } else { + internal.Log.Println("Vault already initialized") + fmt.Fprintln(os.Stderr, "Vault already initialized.") + } +} + +func subcommandAddEnvVar(name, value string) { + store := loadEnvStore() + if value == "" { + fmt.Fprintf(os.Stderr, "Enter value for %s: ", name) + inputValue, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + log.Fatal(err) + } + value = string(inputValue) + } + store.Vars[name] = value + saveEnvStore(store) +} + +func subcommandRmEnvVar(name string) { + store := loadEnvStore() + if _, exists := store.Vars[name]; exists { + delete(store.Vars, name) + saveEnvStore(store) + fmt.Fprintf(os.Stderr, "Environment variable '%s' removed.\n", name) + } else { + fmt.Fprintf(os.Stderr, "Environment variable '%s' not found.\n", name) + } +} + +func subcommandListEnvVars(cmdArgs ListCmd) { + store := loadEnvStore() + for k, v := range store.Vars { + if cmdArgs.Values { + fmt.Printf("%s=%s\n", k, v) + } else { + fmt.Printf("%s\n", k) + } + } +} + +func subcommandExecCommand(cmdArgs ExecCmd) { + store := loadEnvStore() + envVars := os.Environ() + for candidateEnvName, candidateEnvValue := range store.Vars { + if len(cmdArgs.Env) == 0 { + // If no env vars are specified, add all env vars + envVars = append(envVars, fmt.Sprintf("%s=%s", candidateEnvName, candidateEnvValue)) + continue + } + if slices.Contains(cmdArgs.Env, candidateEnvName) { + envVars = append(envVars, fmt.Sprintf("%s=%s", candidateEnvName, candidateEnvValue)) + } + } + + cmd := exec.Command(cmdArgs.Cmd, cmdArgs.Args...) + cmd.Env = envVars + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + + if err := cmd.Run(); err != nil { + log.Fatal(err) + } +} + +func loadEnvStore() *EnvStore { + key := loadKey() + if err := os.MkdirAll(filepath.Dir(dbFilePath), 0700); err != nil { + log.Fatal(err) + } + data, err := os.ReadFile(dbFilePath) + if err != nil { + if os.IsNotExist(err) { + return &EnvStore{Vars: make(map[string]string)} + } + log.Fatal(err) + } + + var store EnvStore + if len(data) > 0 { + var nonce [24]byte + copy(nonce[:], data[:24]) + decrypted, ok := secretbox.Open(nil, data[24:], &nonce, &key) + if !ok { + log.Fatal("Decryption failed") + } + if err := json.Unmarshal(decrypted, &store); err != nil { + log.Fatal(err) + } + } else { + return &EnvStore{Vars: make(map[string]string)} + } + return &store +} + +func saveEnvStore(store *EnvStore) { + key := loadKey() + data, err := json.Marshal(store) + if err != nil { + log.Fatal(err) + } + + var nonce [24]byte + if _, err := rand.Read(nonce[:]); err != nil { + log.Fatal(err) + } + encrypted := secretbox.Seal(nonce[:], data, &nonce, &key) + + if err := os.WriteFile(dbFilePath, encrypted, 0600); err != nil { + log.Fatal(err) + } +} + +func loadPassword() []byte { + internal.Log.Println("Loading password") + + // 1. Check for globally cached password + if cachedPassword != nil { + internal.Log.Println("Using cached password from memory") + return cachedPassword + } + + internal.Log.Println("No cached password in memory, trying daemon") + + // 2. Try to get password from daemon + socketPath := daemon.GetSocketPath() + internal.Log.Printf("Using socket path: %s", socketPath) + client := daemon.NewClient(socketPath) + + // Check if daemon is running + if !client.IsRunning() { + internal.Log.Println("Daemon not running, attempting to start it") + // Spawn daemon in background + cmd := exec.Command(os.Args[0], "daemon") + cmd.Stdout = nil + cmd.Stderr = nil + if err := cmd.Start(); err != nil { + internal.Log.Printf("Failed to start daemon: %v", err) + log.Printf("Failed to start daemon: %v", err) + } else { + internal.Log.Printf("Started daemon process with PID: %d", cmd.Process.Pid) + // Detach the process + cmd.Process.Release() + + // Give daemon time to start + internal.Log.Println("Waiting for daemon to initialize") + time.Sleep(100 * time.Millisecond) + } + } else { + internal.Log.Println("Daemon is already running") + } + + // Try to retrieve password from daemon if it's running + if client.IsRunning() { + internal.Log.Println("Attempting to retrieve password from daemon") + password, err := client.RetrievePassword() + if err == nil { + internal.Log.Println("Successfully retrieved password from daemon") + // Cache the password + cachedPassword = password + return password + } + internal.Log.Printf("Failed to retrieve password from daemon: %v", err) + } else { + internal.Log.Println("Daemon still not running after attempt to start it") + } + + internal.Log.Println("Falling back to terminal input for password") + // 3. Fall back to terminal input + fmt.Fprint(os.Stderr, "Enter master password: ") + password, err := term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + internal.Log.Printf("Failed to read password from terminal: %v", err) + log.Fatal(err) + } + fmt.Fprintln(os.Stderr) // Ensure newline after password input + + internal.Log.Println("Password read from terminal, caching in memory") + // Cache the password + cachedPassword = password + + // Store in daemon if it's running + if client.IsRunning() { + internal.Log.Println("Storing password in daemon") + if err := client.StorePassword(password); err != nil { + internal.Log.Printf("Failed to store password in daemon: %v", err) + log.Printf("Failed to store password in daemon: %v", err) + } else { + internal.Log.Println("Successfully stored password in daemon") + } + } + + return password +} + +func loadKey() [32]byte { + password := loadPassword() + + encryptedKey, err := os.ReadFile(keyFilePath) + if err != nil { + log.Fatal(err) + } + + decryptedKey, err := decryptKeyWithPassword(encryptedKey, password) + if err != nil { + // Clear cached password on error + cachedPassword = nil + log.Fatal("Invalid password or corrupted key file") + } + + var key [32]byte + copy(key[:], decryptedKey) + return key +} + +func encryptKeyWithPassword(key, password []byte) []byte { + var nonce [24]byte + if _, err := rand.Read(nonce[:]); err != nil { + log.Fatal(err) + } + + passwordDerivedKey, err := scrypt.Key(password, nonce[:], 32768, 8, 1, 32) + if err != nil { + log.Fatal(err) + } + + encrypted := secretbox.Seal(nonce[:], key, &nonce, (*[32]byte)(passwordDerivedKey)) + return encrypted +} + +func decryptKeyWithPassword(encryptedKey, password []byte) ([]byte, error) { + var nonce [24]byte + copy(nonce[:], encryptedKey[:24]) + encrypted := encryptedKey[24:] + + passwordDerivedKey, err := scrypt.Key(password, nonce[:], 32768, 8, 1, 32) + if err != nil { + return nil, err + } + + decrypted, ok := secretbox.Open(nil, encrypted, &nonce, (*[32]byte)(passwordDerivedKey)) + if !ok { + return nil, fmt.Errorf("decryption failed") + } + + return decrypted, nil +} + +func subcommandStartDaemon(timeoutMinutes int) { + internal.Log.Printf("Starting daemon with timeout: %d minutes", timeoutMinutes) + + timeout := time.Duration(timeoutMinutes) * time.Minute + socketPath := daemon.GetSocketPath() + internal.Log.Printf("Using socket path: %s", socketPath) + + // Ensure socket directory exists + socketDir := filepath.Dir(socketPath) + internal.Log.Printf("Creating socket directory: %s", socketDir) + if err := os.MkdirAll(socketDir, 0700); err != nil { + internal.Log.Printf("Failed to create socket directory: %v", err) + log.Fatalf("Failed to create socket directory: %v", err) + } + + fmt.Printf("Starting password daemon (timeout: %s)\n", timeout) + fmt.Printf("Socket path: %s\n", socketPath) + + internal.Log.Println("Calling daemon.StartDaemon") + // Start the daemon + if err := daemon.StartDaemon(socketPath, timeout); err != nil { + internal.Log.Printf("Failed to start daemon: %v", err) + log.Fatalf("Failed to start daemon: %v", err) + } +} + +func subcommandRekeyVault() { + // Ask for current master password + fmt.Fprint(os.Stderr, "Enter current master password: ") + currentPassword, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + log.Fatal(err) + } + + // Load and decrypt the master key with current password + encryptedKey, err := os.ReadFile(keyFilePath) + if err != nil { + log.Fatal(err) + } + + masterKey, err := decryptKeyWithPassword(encryptedKey, currentPassword) + if err != nil { + log.Fatal("Invalid password or corrupted key file") + } + + // Ask for new master password + fmt.Fprint(os.Stderr, "Enter new master password: ") + newPassword, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + log.Fatal(err) + } + + // Ask for new master password confirmation + fmt.Fprint(os.Stderr, "Enter new master password (again): ") + newPassword2, err := term.ReadPassword(int(os.Stdin.Fd())) + fmt.Fprintln(os.Stderr) + if err != nil { + log.Fatal(err) + } + + // Verify passwords match + if !bytes.Equal(newPassword, newPassword2) { + log.Fatal("Passwords do not match") + } + + // Encrypt the master key with the new password + newEncryptedKey := encryptKeyWithPassword(masterKey, newPassword) + + // Backup the current key file + backupPath := keyFilePath + ".bak" + if err := os.Rename(keyFilePath, backupPath); err != nil { + log.Fatal(err) + } + fmt.Fprintf(os.Stderr, "Backed up key file to %s\n", backupPath) + + // Write the new encrypted key + if err := os.WriteFile(keyFilePath, newEncryptedKey, 0600); err != nil { + log.Fatal(err) + } + + // Update cached password + cachedPassword = newPassword + + fmt.Fprintln(os.Stderr, "Master password changed successfully.") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6b8c53e --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module git.jrop.me/jonathan/envvault + +go 1.24.3 + +require ( + github.com/alecthomas/kong v1.10.0 + golang.org/x/crypto v0.38.0 + golang.org/x/term v0.32.0 +) + +require golang.org/x/sys v0.33.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6f0caa4 --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/kong v1.10.0 h1:8K4rGDpT7Iu+jEXCIJUeKqvpwZHbsFRoebLbnzlmrpw= +github.com/alecthomas/kong v1.10.0/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= +golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= diff --git a/internal/daemon/client.go b/internal/daemon/client.go new file mode 100644 index 0000000..10762bb --- /dev/null +++ b/internal/daemon/client.go @@ -0,0 +1,125 @@ +package daemon + +import ( + "encoding/json" + "fmt" + "net" + "os" + "path/filepath" + + "git.jrop.me/jonathan/envvault/internal" +) + +// Client provides methods to interact with the password daemon +type Client struct { + socketPath string +} + +// NewClient creates a new client for the password daemon +func NewClient(socketPath string) *Client { + return &Client{ + socketPath: socketPath, + } +} + +// IsRunning checks if the daemon is running +func (c *Client) IsRunning() bool { + internal.DaemonLog.Printf("Checking if daemon is running at: %s", c.socketPath) + conn, err := net.Dial("unix", c.socketPath) + if err != nil { + internal.DaemonLog.Printf("Daemon is not running: %v", err) + return false + } + conn.Close() + internal.DaemonLog.Printf("Daemon is running") + return true +} + +// StorePassword sends the password to the daemon for caching +func (c *Client) StorePassword(password []byte) error { + internal.DaemonLog.Printf("Storing password in daemon") + msg := Message{ + Command: "store", + Password: string(password), + } + + internal.DaemonLog.Printf("Sending store message to daemon") + resp, err := c.sendMessage(msg) + if err != nil { + internal.DaemonLog.Printf("Failed to send store message: %v", err) + return err + } + + if !resp.Success { + internal.DaemonLog.Printf("Daemon reported error: %s", resp.Error) + return fmt.Errorf("failed to store password: %s", resp.Error) + } + + internal.DaemonLog.Printf("Password stored successfully") + return nil +} + +// RetrievePassword gets the cached password from the daemon +func (c *Client) RetrievePassword() ([]byte, error) { + internal.DaemonLog.Printf("Retrieving password from daemon") + msg := Message{ + Command: "retrieve", + } + + internal.DaemonLog.Printf("Sending retrieve message to daemon") + resp, err := c.sendMessage(msg) + if err != nil { + internal.DaemonLog.Printf("Failed to send retrieve message: %v", err) + return nil, err + } + + if !resp.Success { + internal.DaemonLog.Printf("Daemon reported error: %s", resp.Error) + return nil, fmt.Errorf("failed to retrieve password: %s", resp.Error) + } + + internal.DaemonLog.Printf("Password retrieved successfully") + return []byte(resp.Data), nil +} + +// sendMessage sends a message to the daemon and returns the response +func (c *Client) sendMessage(msg Message) (*Response, error) { + internal.DaemonLog.Printf("Connecting to daemon socket: %s", c.socketPath) + conn, err := net.Dial("unix", c.socketPath) + if err != nil { + internal.DaemonLog.Printf("Failed to connect to daemon: %v", err) + return nil, fmt.Errorf("failed to connect to daemon: %w", err) + } + defer conn.Close() + + internal.DaemonLog.Printf("Connected to daemon, encoding message") + encoder := json.NewEncoder(conn) + decoder := json.NewDecoder(conn) + + if err := encoder.Encode(msg); err != nil { + internal.DaemonLog.Printf("Failed to encode message: %v", err) + return nil, fmt.Errorf("failed to encode message: %w", err) + } + + internal.DaemonLog.Printf("Message sent, waiting for response") + var resp Response + if err := decoder.Decode(&resp); err != nil { + internal.DaemonLog.Printf("Failed to decode response: %v", err) + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + internal.DaemonLog.Printf("Response received: success=%v", resp.Success) + return &resp, nil +} + +// GetSocketPath returns the default socket path for the current user +func GetSocketPath() string { + homeDir, err := os.UserHomeDir() + if err != nil { + // Fall back to current directory if home dir can't be determined + internal.DaemonLog.Printf("Failed to get user home directory: %v", err) + return filepath.Join(".local/cache/envvault", "daemon.sock") + } + runtimeDir := filepath.Join(homeDir, ".local/cache/envvault") + return filepath.Join(runtimeDir, "daemon.sock") +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go new file mode 100644 index 0000000..83363eb --- /dev/null +++ b/internal/daemon/daemon.go @@ -0,0 +1,226 @@ +package daemon + +import ( + "encoding/json" + "fmt" + "log" + "net" + "os" + "os/signal" + "path/filepath" + "sync" + "syscall" + "time" + + "git.jrop.me/jonathan/envvault/internal" +) + +// Message represents the communication protocol between client and daemon +type Message struct { + Command string `json:"command"` + Password string `json:"password,omitempty"` +} + +// Response represents the daemon's response to client requests +type Response struct { + Success bool `json:"success"` + Error string `json:"error,omitempty"` + Data string `json:"data,omitempty"` +} + +// PasswordCache manages the cached password with timeout +type PasswordCache struct { + password []byte + freshAccessTime time.Time + mutex sync.RWMutex + timeout time.Duration +} + +// NewPasswordCache creates a new password cache with the specified timeout +func NewPasswordCache(timeout time.Duration) *PasswordCache { + return &PasswordCache{ + timeout: timeout, + } +} + +// Set stores a password in the cache +func (pc *PasswordCache) Set(password []byte) { + internal.DaemonLog.Println("Setting password in cache") + pc.mutex.Lock() + defer pc.mutex.Unlock() + + pc.password = make([]byte, len(password)) + copy(pc.password, password) + pc.freshAccessTime = time.Now() + internal.DaemonLog.Printf("Password set in cache, expires at: %s", + pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339)) +} + +// Get retrieves the password if it exists and hasn't expired +func (pc *PasswordCache) Get() ([]byte, bool) { + internal.DaemonLog.Println("Getting password from cache") + pc.mutex.Lock() + defer pc.mutex.Unlock() + + if pc.password == nil { + internal.DaemonLog.Println("No password in cache") + return nil, false + } + + timeSinceFreshAccess := time.Since(pc.freshAccessTime) + internal.DaemonLog.Printf("Time since fresh access: %s (timeout: %s)", + timeSinceFreshAccess, pc.timeout) + + if timeSinceFreshAccess > pc.timeout { + internal.DaemonLog.Println("Password has expired, clearing") + if pc.password != nil { + // Securely clear the password by overwriting with zeros + for i := range pc.password { + pc.password[i] = 0 + } + pc.password = nil + } + return nil, false + } + + result := make([]byte, len(pc.password)) + copy(result, pc.password) + + internal.DaemonLog.Println("Returning valid password, expires at: %s", + pc.freshAccessTime.Add(pc.timeout).Format(time.RFC3339)) + return result, true +} + +// StartDaemon starts the password caching daemon +func StartDaemon(socketPath string, timeout time.Duration) error { + internal.DaemonLog.Println("Starting daemon with socket path: %s and timeout: %s", + socketPath, timeout) + + // Ensure socket directory exists + socketDir := filepath.Dir(socketPath) + internal.DaemonLog.Println("Creating socket directory: %s", socketDir) + if err := os.MkdirAll(socketDir, 0700); err != nil { + internal.DaemonLog.Println("Failed to create socket directory: %v", err) + return fmt.Errorf("failed to create socket directory: %w", err) + } + + // Remove existing socket if it exists + if _, err := os.Stat(socketPath); err == nil { + internal.DaemonLog.Println("Removing existing socket: %s", socketPath) + if err := os.RemoveAll(socketPath); err != nil { + internal.DaemonLog.Println("Failed to remove existing socket: %v", err) + return fmt.Errorf("failed to remove existing socket: %w", err) + } + } + + // Create Unix domain socket + internal.DaemonLog.Println("Creating Unix domain socket") + listener, err := net.Listen("unix", socketPath) + if err != nil { + internal.DaemonLog.Println("Failed to listen on socket: %v", err) + return fmt.Errorf("failed to listen on socket: %w", err) + } + defer listener.Close() + + // Set socket permissions to only allow current user + internal.DaemonLog.Println("Setting socket permissions to 0600") + if err := os.Chmod(socketPath, 0600); err != nil { + internal.DaemonLog.Println("Failed to set socket permissions: %v", err) + return fmt.Errorf("failed to set socket permissions: %w", err) + } + + // Create password cache + internal.DaemonLog.Println("Creating password cache with timeout: %s", timeout) + cache := NewPasswordCache(timeout) + + // Handle signals for graceful shutdown + internal.DaemonLog.Println("Setting up signal handlers for graceful shutdown") + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigChan + internal.DaemonLog.Println("Received signal: %s, shutting down", sig) + os.RemoveAll(socketPath) + os.Exit(0) + }() + + // Start cleanup goroutine to periodically check for expired passwords + internal.DaemonLog.Println("Starting cleanup goroutine with interval: %s", timeout/2) + go func() { + ticker := time.NewTicker(timeout / 2) + defer ticker.Stop() + + for t := range ticker.C { + internal.DaemonLog.Println("Cleanup tick at %s", t.Format(time.RFC3339)) + // This will trigger cleanup of expired password + _, ok := cache.Get() + internal.DaemonLog.Println("Cleanup check result: password exists=%v", ok) + } + }() + + fmt.Fprintf(os.Stderr, "Password daemon started (timeout: %s)\n", timeout) + internal.DaemonLog.Println("Password daemon started successfully") + + // Accept connections + internal.DaemonLog.Println("Entering accept loop") + for { + internal.DaemonLog.Println("Waiting for connections") + conn, err := listener.Accept() + if err != nil { + internal.DaemonLog.Println("Error accepting connection: %v", err) + log.Printf("Error accepting connection: %v", err) + continue + } + + internal.DaemonLog.Println("Accepted connection from: %s", conn.RemoteAddr()) + go handleConnection(conn, cache) + } +} + +// handleConnection processes client requests +func handleConnection(conn net.Conn, cache *PasswordCache) { + defer conn.Close() + internal.DaemonLog.Println("Handling new connection") + + decoder := json.NewDecoder(conn) + encoder := json.NewEncoder(conn) + + var msg Message + if err := decoder.Decode(&msg); err != nil { + internal.DaemonLog.Println("Failed to decode message: %v", err) + encoder.Encode(Response{Success: false, Error: "Invalid message format"}) + return + } + + internal.DaemonLog.Println("Received command: %s", msg.Command) + switch msg.Command { + case "store": + if msg.Password == "" { + internal.DaemonLog.Println("Store command received with empty password") + encoder.Encode(Response{Success: false, Error: "No password provided"}) + return + } + internal.DaemonLog.Println("Storing password in cache") + cache.Set([]byte(msg.Password)) + internal.DaemonLog.Println("Password stored, sending success response") + encoder.Encode(Response{Success: true}) + + case "retrieve": + internal.DaemonLog.Println("Retrieve command received") + password, ok := cache.Get() + if !ok { + internal.DaemonLog.Println("No password available or password expired") + encoder.Encode(Response{Success: false, Error: "No password available or password expired"}) + return + } + internal.DaemonLog.Println("Password retrieved, sending success response") + encoder.Encode(Response{Success: true, Data: string(password)}) + + default: + internal.DaemonLog.Println("Unknown command received: %s", msg.Command) + encoder.Encode(Response{Success: false, Error: "Unknown command"}) + } + + internal.DaemonLog.Println("Connection handled successfully") +} diff --git a/internal/log.go b/internal/log.go new file mode 100644 index 0000000..23c17e2 --- /dev/null +++ b/internal/log.go @@ -0,0 +1,58 @@ +package internal + +import ( + "io" + "log" + "os" + "path/filepath" +) + +// Logger constants +var ( + Log *log.Logger + DaemonLog *log.Logger +) + +func init() { + // Check debug mode once at startup + debugMode := os.Getenv("DEBUG") == "1" + + // Initialize loggers + Log = createLogger("envvault.log", debugMode) + DaemonLog = createLogger("envvault.daemon.log", debugMode) +} + +// createLogger creates a logger with the specified log file +func createLogger(logFileName string, debugMode bool) *log.Logger { + // Ensure log directory exists + homeDir, err := os.UserHomeDir() + if err != nil { + log.Printf("Failed to get user home directory: %v", err) + return log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile) + } + + logDir := filepath.Join(homeDir, ".local/cache/envvault") + if err := os.MkdirAll(logDir, 0700); err != nil { + log.Printf("Failed to create log directory: %v", err) + return log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile) + } + + // Open log file + logFilePath := filepath.Join(logDir, logFileName) + f, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + log.Printf("Failed to open log file: %v", err) + return log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile) + } + + // If in debug mode, log to both file and stderr + var writer io.Writer + if debugMode { + writer = io.MultiWriter(f, os.Stderr) + } else { + writer = f + } + + // Create logger with timestamp, file name, and line number + return log.New(writer, "", log.LstdFlags|log.Lshortfile) +}