#!/usr/bin/env python3
"""
Secure Password Vault for Python Applications
Uses industry-standard encryption (AES-256-GCM) with PBKDF2 key derivation.

Supports encrypted master password storage:
  - VAULT_KEY  : Fernet key stored in project .env file (read from the VAULT_KEY
                 environment variable or the --vault-key option)
  - ID_ENC     : Master password encrypted with VAULT_KEY, stored in OS env var
  - VPATH      : Path to vault.enc, stored in OS env var
  - VAULT_DIR  : Directory containing this file, stored in OS env var

Setup (run once per machine):
  python vault.py setup              # generates VAULT_KEY, sets OS env vars
  python vault.py create             # initializes vault.enc

Credential management:
  python vault.py store <SERVICE> <USERNAME>
  python vault.py get   <SERVICE> [USERNAME]
  python vault.py list
  python vault.py delete <SERVICE> <USERNAME>
  python vault.py change-master

Master password utilities:
  python vault.py generate-key       # generate a new VAULT_KEY
  python vault.py encrypt-master     # encrypt master password with VAULT_KEY
  python vault.py decrypt-master     # verify/test decryption
"""

import os
import json
import base64
import getpass
import hashlib
from pathlib import Path
from typing import Dict, Optional, Any, Tuple
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.fernet import Fernet
from cryptography.exceptions import InvalidSignature, InvalidTag
import secrets
import argparse
import sys
from datetime import datetime


# =============================================================================
# Master password encryption utilities
# =============================================================================

def generate_vault_key() -> str:
    """Create a fresh random Fernet key to use as VAULT_KEY.

    Returns:
        The URL-safe base64-encoded key, decoded to text.
    """
    raw_key = Fernet.generate_key()
    return raw_key.decode()


def encrypt_master_password(master_password: str, vault_key: str) -> str:
    """
    Seal the master password with the Fernet key VAULT_KEY.

    Args:
        master_password: Plaintext vault master password.
        vault_key: Fernet key (text form) used for encryption.

    Returns:
        The Fernet token as text, intended to be stored as the ID_ENC OS env var.
    """
    cipher = Fernet(vault_key.encode())
    token = cipher.encrypt(master_password.encode())
    return token.decode()


def decrypt_master_password(id_enc: str, vault_key: str) -> str:
    """
    Recover the master password from ID_ENC using VAULT_KEY.

    Args:
        id_enc: Fernet token previously produced by ``encrypt_master_password``.
        vault_key: The Fernet key (text form) that produced the token.

    Returns:
        The plaintext master password.

    Raises:
        InvalidToken: if the token or key is wrong or the token was tampered with.
    """
    cipher = Fernet(vault_key.encode())
    plaintext = cipher.decrypt(id_enc.encode())
    return plaintext.decode()


def resolve_master_password(vault_key: Optional[str] = None) -> str:
    """
    Resolve the master password, trying each source in order:
      1. ID_ENC (OS env var) decrypted with VAULT_KEY, where VAULT_KEY comes
         from the ``vault_key`` argument or the VAULT_KEY env var. This path is
         used only when BOTH ID_ENC and a key are available.
      2. ID (plaintext OS env var), for backward compatibility.
      3. An interactive getpass prompt.

    Args:
        vault_key: Optional Fernet key that overrides the VAULT_KEY env var.

    Returns:
        The master password.

    Raises:
        ValueError: if ID_ENC and a key are both present but decryption fails.
            The underlying error is chained as ``__cause__``. In that case there
            is no fallback to ID or to the prompt.
    """
    # Try the encrypted scheme first.
    id_enc = os.getenv("ID_ENC")
    key = vault_key or os.getenv("VAULT_KEY")
    if id_enc and key:
        try:
            return decrypt_master_password(id_enc, key)
        except Exception as err:
            # Deliberately broad: a malformed key and a wrong key/tampered token
            # surface as different exception types but mean the same thing to
            # the user. Chain the original so the real cause stays visible.
            raise ValueError(
                "Failed to decrypt master password. "
                "Check that VAULT_KEY in .env matches the key used when ID_ENC was set."
            ) from err

    # Backward compatibility: plaintext ID env var.
    plain_id = os.getenv("ID")
    if plain_id:
        return plain_id

    # Interactive fallback.
    return getpass.getpass("Enter vault master password: ")


# =============================================================================
# SecureVault class
# =============================================================================

class SecureVault:
    """
    Secure password vault using AES-256-GCM encryption with PBKDF2 key derivation.

    Security Features:
    - AES-256-GCM encryption (authenticated encryption)
    - PBKDF2-SHA256 key derivation (600,000 iterations)
    - Cryptographically secure random salts and nonces
    - Master password verification without storing password
    - Encrypted master password support (ID_ENC + VAULT_KEY)
    - Atomic vault-file writes (temp file + os.replace)

    File layout: the vault file is UTF-8 JSON holding base64 ``salt``
    (32 bytes), ``nonce`` (12 bytes) and ``ciphertext``. The decrypted
    payload is JSON with ``version``, ``created``, ``password_hash`` and
    ``entries``, where ``entries`` maps service -> username -> entry dict.

    The vault counts as unlocked while the private ``_vault_data`` attribute
    exists; ``lock_vault`` deletes it (and ``_master_password``) again.
    """

    # Used when neither ``vault_path`` nor the VPATH env var is given.
    # NOTE(review): Windows-style path; on POSIX, pathlib treats it as relative
    # to the current directory.
    DEFAULT_VAULT_PATH = "C:/code/vault/vault.enc"
    # PBKDF2-HMAC-SHA256 work factor. It is not recorded in the vault file, so
    # changing it makes existing vaults unreadable.
    KDF_ITERATIONS = 600000
    # Minimum accepted master password length.
    MIN_MASTER_PASSWORD_LENGTH = 12

    def __init__(self, vault_path: Optional[str] = None):
        """
        Prepare a vault handle (does not read or create the vault file).

        Args:
            vault_path: Vault file location. Falls back to the VPATH env var,
                then DEFAULT_VAULT_PATH.

        Creates the vault's parent directory if needed. On non-Windows systems
        it also restricts that directory to the owner (0o700).
        """
        resolved = vault_path or os.getenv('VPATH')
        self.vault_path = Path(resolved) if resolved else Path(self.DEFAULT_VAULT_PATH)
        self.vault_path.parent.mkdir(parents=True, exist_ok=True)
        if os.name != 'nt':
            self.vault_path.parent.chmod(0o700)
        # Not used by this class; kept for backward compatibility.
        self._master_key = None

    # -- crypto helpers --------------------------------------------------------

    def _derive_key(self, password: str, salt: bytes) -> bytes:
        """Derive the 32-byte AES-256 key from *password* and *salt* via PBKDF2-HMAC-SHA256."""
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=self.KDF_ITERATIONS,
        )
        return kdf.derive(password.encode('utf-8'))

    def _create_password_hash(self, password: str, salt: bytes) -> str:
        """
        Return the hex PBKDF2-HMAC-SHA256 digest of *password* with *salt*.

        For the same salt this is identical to ``_derive_key(password, salt).hex()``:
        same algorithm, iteration count and 32-byte output. The encrypt and
        decrypt paths therefore reuse the derived key rather than running the
        expensive KDF twice.
        """
        return hashlib.pbkdf2_hmac(
            'sha256', password.encode('utf-8'), salt, self.KDF_ITERATIONS
        ).hex()

    def _encrypt_data(self, data: Dict[str, Any], master_password: str) -> bytes:
        """
        Serialize and encrypt *data* (the entries dict) under a fresh salt and nonce.

        Returns:
            The UTF-8 JSON package bytes to be written to the vault file.
        """
        salt = secrets.token_bytes(32)
        nonce = secrets.token_bytes(12)  # fresh 96-bit GCM nonce for every write
        key = self._derive_key(master_password, salt)
        vault_data = {
            'version': '1.0',
            # Timestamp of this write; it is rewritten on every save.
            'created': datetime.now().isoformat(),
            # Same value _create_password_hash(master_password, salt) would return.
            'password_hash': key.hex(),
            'entries': data,
        }
        ciphertext = AESGCM(key).encrypt(nonce, json.dumps(vault_data).encode('utf-8'), None)
        package = {
            'salt': base64.b64encode(salt).decode('utf-8'),
            'nonce': base64.b64encode(nonce).decode('utf-8'),
            'ciphertext': base64.b64encode(ciphertext).decode('utf-8'),
        }
        return json.dumps(package).encode('utf-8')

    def _decrypt_data(self, encrypted_data: bytes, master_password: str) -> Dict[str, Any]:
        """
        Decrypt a package produced by ``_encrypt_data`` and return its entries.

        Raises:
            ValueError: "Vault file is corrupt or unreadable." for malformed
                input; "Incorrect vault master password." when GCM
                authentication fails (wrong password or modified file);
                "Invalid master password" when the stored hash does not match.
        """
        corrupt = "Vault file is corrupt or unreadable."
        try:
            package = json.loads(encrypted_data.decode('utf-8'))
            salt = base64.b64decode(package['salt'])
            nonce = base64.b64decode(package['nonce'])
            ciphertext = base64.b64decode(package['ciphertext'])
        except (ValueError, KeyError, TypeError) as e:
            # ValueError covers JSONDecodeError, UnicodeDecodeError and
            # binascii.Error. TypeError covers non-object JSON or non-str fields.
            raise ValueError(corrupt) from e

        key = self._derive_key(master_password, salt)
        try:
            plaintext = AESGCM(key).decrypt(nonce, ciphertext, None)
        except InvalidTag:
            raise ValueError("Incorrect vault master password.") from None
        except ValueError as e:
            # e.g. a nonce of invalid length in a damaged file
            raise ValueError(corrupt) from e

        try:
            vault_data = json.loads(plaintext.decode('utf-8'))
            stored_hash = vault_data['password_hash']
            entries = vault_data['entries']
        except (ValueError, KeyError, TypeError) as e:
            raise ValueError(corrupt) from e
        # Redundant with the GCM tag check; kept for file-format compatibility.
        if stored_hash != key.hex():
            raise ValueError("Invalid master password")
        return entries

    def _migrate_flat_vault(self, entries: Dict[str, Any]) -> Tuple[Dict[str, Any], bool]:
        """
        Convert legacy flat entries to the multi-user layout, in place.

        A legacy value ``{service: {'username': <str>, ...}}`` becomes
        ``{service: {<username>: {...}}}``.

        Returns:
            ``(entries, migrated)``: the same dict object, and whether any
            service was converted.
        """
        migrated = False
        for service, value in list(entries.items()):
            if isinstance(value, dict) and isinstance(value.get('username'), str):
                entries[service] = {value['username']: value}
                migrated = True
        if migrated:
            print("INFO: Vault migrated to multi-user format. Saving...")
        return entries, migrated

    # -- state / persistence ---------------------------------------------------

    def _is_unlocked(self) -> bool:
        """Return True while decrypted vault data is held in memory."""
        return hasattr(self, '_vault_data')

    def _write_vault_file(self, data: bytes) -> None:
        """
        Atomically replace the vault file with *data*.

        The data is written to a sibling ``.tmp`` file created with owner-only
        permissions and fsync'd. That file is then ``os.replace``d over the
        vault, so an interrupted write never leaves a truncated vault behind.
        """
        tmp_path = self.vault_path.with_name(self.vault_path.name + '.tmp')
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, 'O_BINARY', 0)
        fd = os.open(str(tmp_path), flags, 0o600)
        try:
            with os.fdopen(fd, 'wb') as fh:
                fh.write(data)
                fh.flush()
                os.fsync(fh.fileno())
            if os.name != 'nt':
                # The O_CREAT mode is ignored if a stale .tmp file already existed.
                tmp_path.chmod(0o600)
            os.replace(str(tmp_path), str(self.vault_path))
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _save_vault(self):
        """
        Re-encrypt the in-memory entries and write them to the vault file.

        Raises:
            RuntimeError: if the vault is not unlocked.
        """
        if not self._is_unlocked() or not getattr(self, '_master_password', None):
            raise RuntimeError("Vault is not unlocked")
        self._write_vault_file(self._encrypt_data(self._vault_data, self._master_password))

    # -- public API ------------------------------------------------------------

    def create_vault(self, master_password: Optional[str] = None) -> bool:
        """
        Create an empty vault file, asking before overwriting an existing one.

        Args:
            master_password: Password for the new vault. When falsy it is
                resolved via ``resolve_master_password()``.

        Returns:
            True if the vault was written, False if the user declined to overwrite.

        Raises:
            ValueError: if the password is shorter than MIN_MASTER_PASSWORD_LENGTH.
        """
        if self.vault_path.exists():
            response = input(f"Vault already exists at {self.vault_path}. Overwrite? (y/N): ")
            if response.lower() != 'y':
                print("Vault creation cancelled.")
                return False
        if not master_password:
            master_password = resolve_master_password()
        if len(master_password) < self.MIN_MASTER_PASSWORD_LENGTH:
            raise ValueError(
                f"Master password must be at least {self.MIN_MASTER_PASSWORD_LENGTH} characters long."
            )
        self._write_vault_file(self._encrypt_data({}, master_password))
        print(f"Vault created successfully at {self.vault_path}")
        return True

    def unlock_vault(self, master_password: Optional[str] = None,
                     vault_key: Optional[str] = None) -> bool:
        """
        Decrypt the vault file into memory.

        Args:
            master_password: Password to use. When falsy it is resolved via
                ``resolve_master_password(vault_key)``, which may itself raise
                ValueError if ID_ENC cannot be decrypted.
            vault_key: Optional Fernet key forwarded to that resolver.

        Returns:
            True on success. Returns False, after printing the reason, if the
            file is missing or decryption raises ValueError. Legacy flat
            vaults are migrated and re-saved on success.
        """
        if not self.vault_path.exists():
            print(f"Vault not found at {self.vault_path}")
            return False
        if not master_password:
            master_password = resolve_master_password(vault_key)
        try:
            encrypted_data = self.vault_path.read_bytes()
            self._vault_data = self._decrypt_data(encrypted_data, master_password)
            self._vault_data, migrated = self._migrate_flat_vault(self._vault_data)
            self._master_password = master_password
            if migrated:
                self._save_vault()
            print("Vault unlocked successfully.")
            return True
        except ValueError as e:
            print(f"Failed to unlock vault: {e}")
            return False

    def store_password(self, service: str, username: str, password: str, notes: str = "") -> bool:
        """
        Add or replace the entry for (*service*, *username*) and save the vault.

        The original ``created`` timestamp is preserved when an entry is updated.

        Returns:
            True once saved, or False if the vault is locked.
        """
        if not self._is_unlocked():
            print("Vault is not unlocked. Please unlock first.")
            return False
        now = datetime.now().isoformat()
        existing_entry = self._vault_data.get(service, {}).get(username, {})
        entry = {
            'username': username,
            'password': password,
            'notes': notes,
            'created': existing_entry.get('created', now),
            'modified': now,
        }
        self._vault_data.setdefault(service, {})[username] = entry
        self._save_vault()
        print(f"Password for '{service}' / '{username}' stored successfully.")
        return True

    def get_password(self, service: str, username: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """
        Look up stored credentials.

        Returns:
            The entry dict for *username* when given, otherwise the
            ``{username: entry}`` dict for *service*. Returns None if the
            service or user is unknown, or if the vault is locked.
        """
        if not self._is_unlocked():
            print("Vault is not unlocked. Please unlock first.")
            return None
        service_entries = self._vault_data.get(service)
        if service_entries is None:
            return None
        if username is not None:
            return service_entries.get(username)
        return service_entries

    def list_services(self) -> list:
        """
        Return sorted ``(service, username)`` pairs for every stored entry.

        Re-runs the legacy-format migration (saving if anything changed) and
        skips values that are not dicts. Returns ``[]`` if the vault is locked.
        """
        if not self._is_unlocked():
            print("Vault is not unlocked. Please unlock first.")
            return []
        self._vault_data, migrated = self._migrate_flat_vault(self._vault_data)
        if migrated:
            self._save_vault()
        pairs = []
        for service, users in self._vault_data.items():
            if not isinstance(users, dict):
                continue
            for username, entry in users.items():
                if isinstance(entry, dict):
                    pairs.append((service, username))
        return sorted(pairs)

    def delete_password(self, service: str, username: str) -> bool:
        """
        Remove one entry, dropping the service once it has no users left, and save.

        Returns:
            True if deleted. Returns False if the vault is locked or the
            service/user does not exist.
        """
        if not self._is_unlocked():
            print("Vault is not unlocked. Please unlock first.")
            return False
        if service not in self._vault_data:
            print(f"Service '{service}' not found in vault.")
            return False
        if username not in self._vault_data[service]:
            print(f"Username '{username}' not found under service '{service}'.")
            return False
        del self._vault_data[service][username]
        if not self._vault_data[service]:
            del self._vault_data[service]
        self._save_vault()
        print(f"Password for '{service}' / '{username}' deleted successfully.")
        return True

    def change_master_password(self, new_password: Optional[str] = None) -> bool:
        """
        Re-encrypt the vault under a new master password.

        Args:
            new_password: The new password. When falsy, the user is prompted
                (with confirmation) until a valid password is entered.

        Returns:
            True once saved. Returns False if the vault is locked or a supplied
            *new_password* is shorter than MIN_MASTER_PASSWORD_LENGTH. If saving
            fails, the previous in-memory password is restored and the error
            is re-raised.
        """
        if not self._is_unlocked():
            print("Vault is not unlocked. Please unlock first.")
            return False
        min_len = self.MIN_MASTER_PASSWORD_LENGTH
        if not new_password:
            while True:
                new_password = getpass.getpass("Enter NEW vault master password: ")
                confirm_password = getpass.getpass("Confirm new vault master password: ")
                if new_password != confirm_password:
                    print("Passwords don't match. Please try again.")
                    continue
                if len(new_password) < min_len:
                    print(f"Master password must be at least {min_len} characters long.")
                    continue
                break
        elif len(new_password) < min_len:
            # Apply the same policy to programmatic callers as to the prompt.
            print(f"Master password must be at least {min_len} characters long.")
            return False

        previous = getattr(self, '_master_password', None)
        self._master_password = new_password
        try:
            self._save_vault()
        except BaseException:
            # The file on disk still uses the old password; keep memory in sync.
            self._master_password = previous
            raise
        print("Master password changed successfully.")
        print("NOTE: ID_ENC / ID still hold the old password; re-run "
              "'python vault.py encrypt-master' and update them if you use them.")
        return True

    def lock_vault(self):
        """
        Drop the decrypted entries and the master password from this object.

        Python strings are immutable, so this only removes references; it
        cannot guarantee the secrets are scrubbed from process memory.
        """
        if self._is_unlocked():
            self._vault_data.clear()
            del self._vault_data
        if hasattr(self, '_master_password'):
            del self._master_password
        print("Vault locked.")


# =============================================================================
# CLI
# =============================================================================

def main():
    parser = argparse.ArgumentParser(description="Secure Password Vault")
    parser.add_argument("--vault-path", help="Path to vault file (overrides VPATH env var)")
    parser.add_argument("--vault-key",  help="Fernet key for master password decryption (overrides VAULT_KEY env var)")

    subparsers = parser.add_subparsers(dest="command", help="Available commands")

    subparsers.add_parser("generate-key",   help="Generate a new VAULT_KEY")
    subparsers.add_parser("encrypt-master", help="Encrypt master password with VAULT_KEY -> ID_ENC")
    subparsers.add_parser("decrypt-master", help="Decrypt ID_ENC to verify master password")
    subparsers.add_parser("setup",          help="Full first-time setup: generate key, encrypt password, show env vars")
    subparsers.add_parser("create",         help="Create a new vault")

    store_p = subparsers.add_parser("store", help="Store a password")
    store_p.add_argument("service",  help="Service name (e.g. MULBUILDDB1)")
    store_p.add_argument("username", help="Username")
    store_p.add_argument("--notes",  default="", help="Optional notes")

    get_p = subparsers.add_parser("get", help="Retrieve a password")
    get_p.add_argument("service",  help="Service name")
    get_p.add_argument("username", nargs="?", default=None, help="Username (optional)")

    subparsers.add_parser("list", help="List all services")

    del_p = subparsers.add_parser("delete", help="Delete a password")
    del_p.add_argument("service",  help="Service name")
    del_p.add_argument("username", help="Username")

    subparsers.add_parser("change-master", help="Change master password")

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    vault_key = getattr(args, 'vault_key', None) or os.getenv("VAULT_KEY")

    # ── Key/encryption utilities ──────────────────────────────────────────────

    if args.command == "generate-key":
        key = generate_vault_key()
        print(f"\nVAULT_KEY={key}")
        print("\nAdd VAULT_KEY to your project .env file.")
        print("Keep it out of source control.\n")
        return

    if args.command == "encrypt-master":
        key = vault_key or input("Enter VAULT_KEY: ").strip()
        pw  = getpass.getpass("Enter vault master password: ")
        enc = encrypt_master_password(pw, key)
        print(f"\nID_ENC={enc}")
        print("\nSet ID_ENC as a persistent OS environment variable (see setup script).\n")
        return

    if args.command == "decrypt-master":
        key     = vault_key or input("Enter VAULT_KEY: ").strip()
        id_enc  = os.getenv("ID_ENC") or input("Enter ID_ENC: ").strip()
        try:
            pw = decrypt_master_password(id_enc, key)
            print("Decryption successful. Master password verified.")
        except Exception as e:
            print(f"Decryption failed: {e}")
        return

    if args.command == "setup":
        print("\n=== Vault First-Time Setup ===\n")
        key = generate_vault_key()
        print(f"Generated VAULT_KEY:\n  {key}")
        print("\nAdd this to every project .env file:")
        print(f"  VAULT_KEY={key}\n")

        pw = getpass.getpass("Enter vault master password (min 12 chars): ")
        if len(pw) < 12:
            print("ERROR: Password must be at least 12 characters.")
            sys.exit(1)
        enc = encrypt_master_password(pw, key)

        vpath     = args.vault_path or "C:\\code\\vault\\vault.enc"
        vault_dir = str(Path(__file__).resolve().parent)

        print("\nSet these OS user environment variables (setup_vault.ps1 does this automatically):")
        print(f"  ID_ENC    = {enc}")
        print(f"  VPATH     = {vpath}")
        print(f"  VAULT_DIR = {vault_dir}")
        print("\nThen run:  python vault.py create\n")
        return

    # ── Vault operations ──────────────────────────────────────────────────────

    vault = SecureVault(args.vault_path)

    if args.command == "create":
        vault.create_vault()

    elif args.command == "store":
        if vault.unlock_vault(vault_key=vault_key):
            password = getpass.getpass(f"Enter password for '{args.service}' / '{args.username}': ")
            notes    = args.notes or input("Notes (optional, press Enter to skip): ").strip()
            vault.store_password(args.service, args.username, password, notes)
            vault.lock_vault()

    elif args.command == "get":
        if vault.unlock_vault(vault_key=vault_key):
            result = vault.get_password(args.service, args.username)
            if result is None:
                target = f"'{args.service}' / '{args.username}'" if args.username else f"'{args.service}'"
                print(f"No password found for {target}")
            elif args.username:
                entry = result
                print(f"Service:  {args.service}")
                print(f"Username: {entry['username']}")
                print(f"Password: {entry['password']}")
                if entry['notes']:
                    print(f"Notes:    {entry['notes']}")
                print(f"Created:  {entry['created']}")
                print(f"Modified: {entry['modified']}")
            else:
                for uname, entry in sorted(result.items()):
                    print(f"Service:  {args.service}")
                    print(f"Username: {entry['username']}")
                    print(f"Password: {entry['password']}")
                    if entry['notes']:
                        print(f"Notes:    {entry['notes']}")
                    print(f"Created:  {entry['created']}")
                    print(f"Modified: {entry['modified']}")
                    print()
            vault.lock_vault()

    elif args.command == "list":
        if vault.unlock_vault(vault_key=vault_key):
            pairs = vault.list_services()
            if pairs:
                print("Stored services:")
                for service, username in pairs:
                    print(f"  - {service} ({username})")
            else:
                print("No passwords stored in vault.")
            vault.lock_vault()

    elif args.command == "delete":
        if vault.unlock_vault(vault_key=vault_key):
            confirm = input(f"Delete password for '{args.service}' / '{args.username}'? (y/N): ")
            if confirm.lower() == 'y':
                vault.delete_password(args.service, args.username)
            vault.lock_vault()

    elif args.command == "change-master":
        if vault.unlock_vault(vault_key=vault_key):
            vault.change_master_password()
            vault.lock_vault()


if __name__ == "__main__":
    # Top-level CLI error boundary: turn exceptions into a one-line message on
    # stderr and a non-zero exit status.
    try:
        main()
    except KeyboardInterrupt:
        print("\nAborted.", file=sys.stderr)
        sys.exit(1)
    except PermissionError as e:
        print(f"ERROR: Permission denied -- {e}", file=sys.stderr)
        sys.exit(1)
    except FileNotFoundError as e:
        print(f"ERROR: File not found -- {e}", file=sys.stderr)
        sys.exit(1)
    except (RuntimeError, ValueError) as e:
        # ValueError is an expected failure here (e.g. a too-short master
        # password or an ID_ENC that cannot be decrypted), not an "unexpected" one.
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"ERROR: Unexpected error -- {e}", file=sys.stderr)
        sys.exit(1)
