# -*- coding: utf-8 -*-
"""
TTL vs Database View Validator
================================
Validates that LEAFLET_IDs in a TTL file match those in a database view.

Reports:
  - IDs in DB view but missing from TTL
  - IDs in TTL but missing from DB view
  - Summary counts

Output: console + <ttl_stem>_validation_log.txt alongside the TTL file.

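Usage:
    python validate_ttl_vs_db.py [path_to_ttl_file] [DB_VIEW_NAME]

    path_to_ttl_file is optional; when omitted, the most recently modified
    pel_supplemental_v*.ttl in the project's input/ directory is used.
    DB_VIEW_NAME defaults to V_GRF_LEAFLET.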

Author: M Dombaugh
"""

import sys
import os
import re
from pathlib import Path
from datetime import datetime

from dotenv import load_dotenv
import oracledb
import getpass

# ── Oracle Instant Client (Windows) ──────────────────────────────────────────

# ── Environment variable loading ──────────────────────────────────────────────
# This script sits in skg/validation/, so the project root (skg/) is two
# directory levels above the resolved script path. Its .env is loaded by
# explicit path.
_this_file = Path(__file__).resolve()
_project_root = _this_file.parent.parent  # skg/
load_dotenv(_project_root / ".env")

# vault_client is imported from the directory named by VAULT_DIR, so that
# directory has to be on sys.path before the import below is attempted.
_vault_dir = os.getenv("VAULT_DIR")
if _vault_dir:
    if _vault_dir not in sys.path:
        sys.path.insert(0, _vault_dir)

# The vault helpers are optional: VAULT_AVAILABLE records whether the import
# succeeded so later code can fall back to other credential sources.
try:
    from vault_client import get_vault_credentials, get_wallet_password
except ImportError:
    VAULT_AVAILABLE = False
else:
    VAULT_AVAILABLE = True

# ── Oracle connection settings ────────────────────────────────────────────────
# ORACLE_WALLET_PATH is required. Path(None) would raise an opaque TypeError,
# so stop with an explicit message when the variable is missing entirely.
_wallet_env = os.getenv("ORACLE_WALLET_PATH")
if _wallet_env is None:
    raise SystemExit(
        "[ERROR] ORACLE_WALLET_PATH is not set - define it in the environment "
        "or in the project .env file."
    )

wallet_path = Path(_wallet_env)
dsn         = os.getenv("ORACLE_DSN")
oracle_user = os.getenv("ORACLE_USER")

# The vault service name is derived from the wallet folder name with a
# leading "Wallet_" / "wallet_" marker removed.
_wallet_folder = wallet_path.name
_service       = _wallet_folder.replace("Wallet_", "").replace("wallet_", "").strip()

# Database password resolution order:
#   1. ORACLE_PASSWORD environment variable
#   2. vault lookup (only when vault_client imported and a service name exists)
#   3. interactive prompt
oracle_password = os.getenv("ORACLE_PASSWORD")
if not oracle_password and VAULT_AVAILABLE and _service:
    _vault_user, oracle_password = get_vault_credentials(_service)
    # The vault-supplied user name is used only when ORACLE_USER was not set.
    if _vault_user and not oracle_user:
        oracle_user = _vault_user
if not oracle_password:
    oracle_password = getpass.getpass("Enter Oracle password: ")

# Wallet password resolution order: vault lookup -> interactive prompt.
wallet_password = None
if VAULT_AVAILABLE and _service:
    wallet_password = get_wallet_password(str(wallet_path))
if not wallet_password:
    wallet_password = getpass.getpass("Enter wallet password: ")


os.environ["NLS_LANG"] = "AMERICAN_AMERICA.AL32UTF8"

# ── Config ────────────────────────────────────────────────────────────────────
# View queried when no view name is passed on the command line.
DEFAULT_VIEW       = "V_GRF_LEAFLET"
# Column selected from the view; its values are compared against the TTL IDs.
KEY_COLUMN         = "LEAFLET_ID"
TTL_PREFIX         = "leaflet"        # prefix used in the TTL file (leaflet:9)
# Qualifier placed in front of the view name when the SELECT is built.
OWNER              = "GLOBAL_DISTRIBUTE"


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

# NLS_LANG is already exported once during module start-up (right after the
# credentials are resolved), so the assignment is not repeated here.

def connect():
    """Open and return an Oracle connection using the module-level settings.

    The user, password, DSN and wallet password come from the module-level
    variables resolved at import time. The wallet directory is passed as both
    ``config_dir`` and ``wallet_location``.

    Returns:
        The connection object returned by ``oracledb.connect``, with
        ``autocommit`` switched on.
    """
    wallet_dir = str(wallet_path)
    connection = oracledb.connect(
        user=oracle_user,
        password=oracle_password,
        dsn=dsn,
        config_dir=wallet_dir,
        wallet_location=wallet_dir,
        wallet_password=wallet_password,
    )
    connection.autocommit = True
    return connection


def fetch_db_ids(cursor, view_name: str) -> set:
    """Collect the key-column values of a database view as a set of strings.

    Runs ``SELECT <KEY_COLUMN> FROM <OWNER>.<view_name>`` on *cursor*, reads
    every row, and keeps each non-NULL value converted with ``str()`` and
    stripped of surrounding whitespace. The statement and the resulting count
    are echoed to the console.

    Args:
        cursor: An open database cursor providing ``execute`` and ``fetchall``.
        view_name: Name of the view to read, without the owner qualifier.

    Returns:
        The set of distinct ID strings found in the view.
    """
    # NOTE(review): view_name is interpolated into the SQL text as-is;
    # callers are expected to pass a trusted identifier.
    query = f"SELECT {KEY_COLUMN} FROM {OWNER}.{view_name}"
    print(f"[DB] Executing: {query}")
    cursor.execute(query)

    found = set()
    for record in cursor.fetchall():
        value = record[0]
        if value is None:
            # NULL keys cannot be matched against the TTL, so they are skipped.
            continue
        found.add(str(value).strip())

    print(f"[DB] Fetched {len(found):,} IDs from {view_name}")
    return found


def extract_ttl_ids(ttl_path: Path, prefix=None) -> set:
    """
    Extract all unique local IDs used with a given prefix in the TTL file.

    Matches patterns like:
        leaflet:9
        leaflet:100
    Anywhere on a line (subject position, object position, etc.).

    Args:
        ttl_path: Path of the Turtle file to scan (read as UTF-8, line by line).
        prefix:   Prefix name to look for, without the trailing colon.
                  Defaults to the module-level TTL_PREFIX when omitted.

    Returns:
        Set of local names (as strings) found after ``<prefix>:``.
    """
    if prefix is None:
        prefix = TTL_PREFIX

    # Match   <prefix>:<local name>
    # The local name is matched lazily and ends at the first whitespace,
    # ',', ';', '.', '>' or ']' (or end of line). A prefix declaration such as
    # "@prefix leaflet: <...>" is not matched because the colon is followed
    # by whitespace.
    # NOTE(review): because '.' is a terminator, a local name containing an
    # inner dot would be cut at that dot - confirm IDs never contain one.
    pattern = re.compile(rf'\b{re.escape(prefix)}:(\S+?)(?=[\s,;.>\]]|$)')

    ids = set()
    with open(ttl_path, "r", encoding="utf-8") as f:
        for line in f:
            for match in pattern.finditer(line):
                # Strip any Turtle punctuation that ended up in the capture
                # (e.g. when the character right after the colon is ';').
                local = match.group(1).strip().rstrip(";,.")
                if local:
                    ids.add(local)

    print(f"[TTL] Found {len(ids):,} unique IDs in {ttl_path.name}")
    return ids


def build_report(db_ids: set, ttl_ids: set, view_name: str, ttl_path: Path) -> list[str]:
    """Assemble the validation report as a list of text lines.

    The report has a header (timestamp, file, view, key column, prefix), a
    summary of counts, one section listing IDs present in the DB view but not
    in the TTL, one section listing IDs present in the TTL but not in the DB
    view, and a final verdict. The same lines are meant to be printed and
    written to the log file by the caller.

    Args:
        db_ids:    ID strings read from the database view.
        ttl_ids:   ID strings extracted from the TTL file.
        view_name: View name shown in the report (owner is added from OWNER).
        ttl_path:  TTL file path; its resolved form is shown in the header.

    Returns:
        The report lines, without trailing newlines.
    """
    def _shorter_first(value):
        # Shorter strings sort first, then lexicographically, so purely
        # numeric IDs come out in ascending numeric order.
        return (len(value), value)

    missing_from_ttl = sorted(db_ids - ttl_ids, key=_shorter_first)
    missing_from_db = sorted(ttl_ids - db_ids, key=_shorter_first)
    matched = db_ids & ttl_ids

    rule = "=" * 70

    def _section(title, id_values, empty_note):
        # One framed listing: rule / title / rule / entries (or note) / blank.
        block = [rule, title, rule]
        if id_values:
            block.extend(f"  {TTL_PREFIX}:{id_val}" for id_val in id_values)
        else:
            block.append(empty_note)
        block.append("")
        return block

    # ── Header and summary ──────────────────────────────────────────────────
    report = [
        rule,
        "TTL vs DATABASE VALIDATION REPORT",
        rule,
        f"Timestamp   : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        f"TTL file    : {ttl_path.resolve()}",
        f"DB view     : {OWNER}.{view_name}",
        f"Key column  : {KEY_COLUMN}",
        f"TTL prefix  : {TTL_PREFIX}:",
        "",
        "SUMMARY",
        "-" * 40,
        f"  IDs in DB view              : {len(db_ids):>8,}",
        f"  IDs in TTL file             : {len(ttl_ids):>8,}",
        f"  IDs in both (matched)       : {len(matched):>8,}",
        f"  In DB but MISSING from TTL  : {len(missing_from_ttl):>8,}",
        f"  In TTL but MISSING from DB  : {len(missing_from_db):>8,}",
        "",
    ]

    # ── In DB, not in TTL ───────────────────────────────────────────────────
    report += _section(
        f"IN DB ({view_name}) BUT NOT IN TTL  [{len(missing_from_ttl):,} items]",
        missing_from_ttl,
        "  (none - all DB IDs are present in the TTL)",
    )

    # ── In TTL, not in DB ───────────────────────────────────────────────────
    report += _section(
        f"IN TTL BUT NOT IN DB ({view_name})  [{len(missing_from_db):,} items]",
        missing_from_db,
        "  (none - all TTL IDs exist in the DB view)",
    )

    # ── Verdict ─────────────────────────────────────────────────────────────
    if missing_from_ttl or missing_from_db:
        verdict = "RESULT: DIFFERENCES FOUND - see sections above."
    else:
        verdict = "RESULT: PASS - TTL and DB view are fully synchronized."
    report += [rule, verdict, rule]

    return report


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────

def main():
    """Run the TTL-vs-database comparison from the command line.

    Steps:
      1. Resolve the TTL file (first argument, or the most recently modified
         pel_supplemental_v*.ttl in the project's input/ directory) and the
         view name (second argument upper-cased, or DEFAULT_VIEW).
      2. Extract the IDs from the TTL file and fetch the IDs from the view.
      3. Print the report and write it to <ttl_stem>_validation_log.txt next
         to the TTL file.

    Exits with status 1 when no TTL file can be determined or the given file
    does not exist.
    """
    # ── Args ─────────────────────────────────────────────────────────────────
    if len(sys.argv) >= 2:
        ttl_path = Path(sys.argv[1])
    else:
        # Default: newest pel_supplemental_v*.ttl in skg/input/
        _input_dir = _project_root / "input"
        _matches   = list(_input_dir.glob("pel_supplemental_v*.ttl"))
        if not _matches:
            print(f"[ERROR] No pel_supplemental_v*.ttl files found in {_input_dir}")
            # Derive the script name from the file itself so the hint cannot
            # drift from the actual file name.
            print(f"Usage: python {Path(__file__).name} <ttl_file> [VIEW_NAME]")
            sys.exit(1)
        ttl_path = max(_matches, key=lambda p: p.stat().st_mtime)
        print(f"[INFO] No TTL file specified - defaulting to: {ttl_path}")

    view_name = sys.argv[2].upper() if len(sys.argv) >= 3 else DEFAULT_VIEW

    if not ttl_path.exists():
        print(f"[ERROR] TTL file not found: {ttl_path}")
        sys.exit(1)

    # Log file lives next to the TTL file
    log_path = ttl_path.parent / f"{ttl_path.stem}_validation_log.txt"

    # ── Data gathering ────────────────────────────────────────────────────────
    print(f"\n[INFO] TTL file : {ttl_path}")
    print(f"[INFO] DB view  : {OWNER}.{view_name}")
    print()

    ttl_ids = extract_ttl_ids(ttl_path)

    # The cursor and connection are released in finally blocks so a failing
    # query does not leave the database session open.
    conn = connect()
    try:
        cursor = conn.cursor()
        try:
            # Larger fetch batches reduce round trips for the full-view read.
            cursor.arraysize    = 5000
            cursor.prefetchrows = 5000
            db_ids = fetch_db_ids(cursor, view_name)
        finally:
            cursor.close()
    finally:
        conn.close()

    # ── Build report ──────────────────────────────────────────────────────────
    report_lines = build_report(db_ids, ttl_ids, view_name, ttl_path)

    # Console output
    print()
    for line in report_lines:
        print(line)

    # File output
    with open(log_path, "w", encoding="utf-8") as lf:
        lf.write("\n".join(report_lines) + "\n")

    print(f"\n[INFO] Validation log written to: {log_path}")

if __name__ == "__main__":
    main()
