#!/usr/bin/env python3
"""
High-Performance TTL Semantic Validator
For large files (1GB+) - uses streaming parsing and hash-based comparison.

Approaches (in order of speed):
1. Hash comparison - fastest, just tells you equal/not equal
2. Sorted diff - medium, finds actual differences
3. Chunked processing - for files too large for memory

Author: Michael Dombaugh
"""

import sys
import hashlib
import mmap
import re
from pathlib import Path
from typing import Iterator, Set, Tuple, Optional, Dict
from dataclasses import dataclass, field
from collections import defaultdict
import argparse
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp


@dataclass(frozen=True, order=True)
class Triple:
    """
    Normalized (subject, predicate, object) triple for comparison.

    `frozen=True` makes instances immutable and hashable; `order=True`
    generates field-tuple `__eq__`/`__lt__`/etc. — exactly the semantics
    the previous hand-written `__hash__`, `__eq__` and `__lt__` provided,
    so Triples behave identically in sets, dicts and sorts.
    """
    subject: str
    predicate: str
    obj: str

    def to_line(self) -> str:
        """Serialize as a single tab-separated line (no trailing newline)."""
        return f"{self.subject}\t{self.predicate}\t{self.obj}"

    @classmethod
    def from_line(cls, line: str) -> 'Triple':
        """Parse a line produced by to_line().

        Uses maxsplit=2 so tab characters inside the object literal stay
        part of the object instead of raising ValueError on unpacking.
        """
        s, p, o = line.strip().split('\t', 2)
        return cls(s, p, o)


class StreamingTTLParser:
    """
    Fast streaming TTL (Turtle) parser - doesn't build a full graph in memory.
    Yields normalized triples one at a time via parse_triples().

    NOTE(review): this is a pragmatic line-oriented parser, not a full
    Turtle implementation. It relies on a line-final '.', ';' or ',' being
    a real statement terminator; a literal that happens to end a line with
    one of those characters would be mis-detected. Multi-line literals and
    blank nodes are not handled. Confirm these cases cannot occur in the
    files this tool targets.
    """
    
    # Regex patterns compiled once at class-definition time.
    # @prefix name: <uri> .  -> captures (prefix name, namespace URI)
    PREFIX_PATTERN = re.compile(r'@prefix\s+(\w*):\s*<([^>]+)>\s*\.')
    # @base <uri> .          -> captures the base URI
    BASE_PATTERN = re.compile(r'@base\s+<([^>]+)>\s*\.')
    # <uri>                  -> bracketed URI (not referenced in this class's methods)
    URI_PATTERN = re.compile(r'<([^>]+)>')
    # prefix:local           -> captures (prefix, local part); empty prefix allowed
    PREFIXED_PATTERN = re.compile(r'(\w*):(\S+)')
    # "literal"^^dt or "lit"@lang (not referenced in this class's methods)
    LITERAL_PATTERN = re.compile(r'"([^"\\]*(?:\\.[^"\\]*)*)"(?:\^\^([^\s,;]+)|@(\w+))?')
    
    def __init__(self, filepath: str):
        # Path of the TTL file to parse lazily in parse_triples().
        self.filepath = filepath
        # prefix name -> namespace URI, filled in as @prefix lines are seen.
        self.prefixes: Dict[str, str] = {}
        # Base URI from @base (stored but not currently applied to relative URIs).
        self.base: str = ""
        
    def expand_uri(self, uri: str) -> str:
        """Expand prefixed URI to full URI.

        Bracketed URIs get their angle brackets stripped; a known prefix is
        replaced by its namespace; anything else is returned unchanged.
        """
        if uri.startswith('<') and uri.endswith('>'):
            return uri[1:-1]
        
        match = self.PREFIXED_PATTERN.match(uri)
        if match:
            prefix, local = match.groups()
            if prefix in self.prefixes:
                return self.prefixes[prefix] + local
        
        # Unknown prefix or not a URI at all - leave untouched.
        return uri
    
    def normalize_value(self, value: str) -> str:
        """Normalize a value (URI or literal) for comparison.

        URIs (bracketed or prefixed) are expanded to full form; quoted
        literals and bare values (numbers, booleans) pass through as-is.
        """
        value = value.strip()
        
        # URI in angle brackets
        if value.startswith('<'):
            return self.expand_uri(value)
        
        # Prefixed URI (the leading-quote check keeps literals containing
        # colons from being treated as prefixed names)
        if ':' in value and not value.startswith('"'):
            match = self.PREFIXED_PATTERN.match(value)
            if match:
                return self.expand_uri(value)
        
        # Literal - keep as-is but normalize quotes
        if value.startswith('"'):
            return value
        
        # Bare value (number, boolean, etc)
        return value
    
    def parse_triples(self) -> Iterator[Triple]:
        """Stream parse TTL file yielding normalized triples.

        Lines accumulate in `buffer` until a line ends with '.' (statement
        complete), ';' (same subject, new predicate next) or ',' (same
        subject and predicate, new object next). The carried
        subject/predicate implement Turtle's ';' and ',' abbreviations
        across buffer flushes.
        """
        
        current_subject = None     # subject carried across ';' / ',' boundaries
        current_predicate = None   # predicate carried across ',' boundaries
        
        # errors='replace' keeps parsing going on malformed UTF-8 bytes.
        with open(self.filepath, 'r', encoding='utf-8', errors='replace') as f:
            buffer = ""
            
            for line in f:
                line = line.strip()
                
                # Skip empty lines and comments
                if not line or line.startswith('#'):
                    continue
                
                # Handle prefix declarations
                prefix_match = self.PREFIX_PATTERN.match(line)
                if prefix_match:
                    prefix, uri = prefix_match.groups()
                    self.prefixes[prefix] = uri
                    continue
                
                # Handle base
                base_match = self.BASE_PATTERN.match(line)
                if base_match:
                    self.base = base_match.group(1)
                    continue
                
                # Accumulate line into buffer
                buffer += " " + line
                
                # Check if statement is complete (ends with . or ; or ,)
                # NOTE(review): endswith() is quote-unaware, so a literal that
                # ends a line with '.', ';' or ',' is treated as a terminator.
                if line.endswith('.'):
                    # Process complete statement(s)
                    yield from self._parse_statement(buffer.strip(), current_subject, current_predicate)
                    buffer = ""
                    current_subject = None
                    current_predicate = None
                elif line.endswith(';'):
                    # Same subject, new predicate coming
                    for triple in self._parse_statement(buffer.strip(), current_subject, current_predicate):
                        yield triple
                        current_subject = triple.subject
                    buffer = ""
                    current_predicate = None
                elif line.endswith(','):
                    # Same subject and predicate, new object coming
                    for triple in self._parse_statement(buffer.strip(), current_subject, current_predicate):
                        yield triple
                        current_subject = triple.subject
                        current_predicate = triple.predicate
                    buffer = ""
    
    def _parse_statement(self, statement: str, 
                         carry_subject: Optional[str],
                         carry_predicate: Optional[str]) -> Iterator[Triple]:
        """Parse a single TTL statement into triples.

        Args:
            statement: Buffered statement text (trailing terminator included).
            carry_subject: Subject carried from a previous ';'/',' flush,
                or None if this statement begins with its own subject.
            carry_predicate: Predicate carried from a previous ',' flush.

        Yields:
            One Triple per (predicate, object) pair found.
        """
        
        # Drop trailing terminator characters before tokenizing.
        statement = statement.rstrip('.;,').strip()
        if not statement:
            return
        
        tokens = self._tokenize(statement)
        if not tokens:
            return
        
        idx = 0
        subject = carry_subject
        predicate = carry_predicate
        
        # If no carried subject, first token is subject
        if subject is None and idx < len(tokens):
            subject = self.normalize_value(tokens[idx])
            idx += 1
        
        while idx < len(tokens):
            # Get predicate if needed
            if predicate is None and idx < len(tokens):
                pred_token = tokens[idx]
                if pred_token == 'a':
                    # Turtle keyword 'a' is shorthand for rdf:type.
                    predicate = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
                else:
                    predicate = self.normalize_value(pred_token)
                idx += 1
            
            # Get object
            if idx < len(tokens):
                obj = self.normalize_value(tokens[idx])
                idx += 1
                
                if subject and predicate:
                    yield Triple(subject, predicate, obj)
                
                # Check for continuation
                if idx < len(tokens):
                    next_token = tokens[idx]
                    if next_token == ';':
                        # ';': same subject, next token starts a new predicate.
                        predicate = None
                        idx += 1
                    elif next_token == ',':
                        # ',': same subject and predicate, next token is another object.
                        idx += 1
                    # Otherwise it's a new predicate
                    # NOTE(review): predicate is NOT reset on this fall-through,
                    # so the next token is consumed as another object, not as a
                    # predicate - confirm this matches the comment's intent.
    
    def _tokenize(self, statement: str) -> list:
        """Tokenize TTL statement handling quoted strings.

        Quoted literals (with optional ^^datatype / @lang suffix) and
        <bracketed> URIs stay single tokens; ';', ',' and '.' become their
        own one-character tokens; everything else splits on whitespace.
        """
        tokens = []
        i = 0
        current = ""
        
        # Strip any inline comments first (not inside quotes)
        clean_statement = ""
        in_quotes = False
        for j, c in enumerate(statement):
            # Toggle quote state on unescaped double quotes only.
            # NOTE(review): a literal ending in an escaped backslash (\\")
            # would fool this backslash check - confirm such data cannot occur.
            if c == '"' and (j == 0 or statement[j-1] != '\\'):
                in_quotes = not in_quotes
            if c == '#' and not in_quotes:
                break
            clean_statement += c
        statement = clean_statement.strip()
        
        while i < len(statement):
            char = statement[i]
            
            # Handle quoted strings
            if char == '"':
                if current:
                    tokens.append(current)
                    current = ""
                # Find end of quoted string (skipping backslash-escaped quotes)
                j = i + 1
                while j < len(statement):
                    if statement[j] == '"' and statement[j-1] != '\\':
                        break
                    j += 1
                # Include any datatype or language tag
                j += 1
                # '^^@' is a membership test: char is '^' or '@'.
                while j < len(statement) and statement[j] in '^^@':
                    if statement[j:j+2] == '^^':
                        j += 2
                        # Get datatype URI
                        if j < len(statement) and statement[j] == '<':
                            end = statement.find('>', j)
                            if end > 0:
                                j = end + 1
                        else:
                            # Prefixed datatype
                            while j < len(statement) and statement[j] not in ' \t;,.':
                                j += 1
                    elif statement[j] == '@':
                        j += 1
                        # NOTE(review): isalnum() stops at '-', so a subtag
                        # like @en-US would be cut after 'en' - verify inputs.
                        while j < len(statement) and statement[j].isalnum():
                            j += 1
                    else:
                        break
                tokens.append(statement[i:j])
                i = j
                continue
            
            # Handle angle-bracket URIs
            if char == '<':
                if current:
                    tokens.append(current)
                    current = ""
                end = statement.find('>', i)
                if end > 0:
                    tokens.append(statement[i:end+1])
                    i = end + 1
                    continue
            
            # Handle whitespace
            if char in ' \t\n':
                if current:
                    tokens.append(current)
                    current = ""
                i += 1
                continue
            
            # Handle punctuation
            if char in ';,.':
                if current:
                    tokens.append(current)
                    current = ""
                tokens.append(char)
                i += 1
                continue
            
            current += char
            i += 1
        
        if current:
            tokens.append(current)
        
        return tokens


def hash_triples_file(filepath: str) -> Tuple[str, int]:
    """
    Compute a canonical, order-independent hash of all triples in a TTL file.

    Instead of collecting every per-triple digest into a list and sorting it
    (O(n) memory - gigabytes for the 1GB+ inputs this tool targets), each
    triple's digest is folded into a commutative running sum. Equal
    multisets of triples produce equal sums regardless of file order, so the
    result is still canonical, but memory use is O(1). The triple count is
    mixed into the final digest as an extra guard.

    The returned hash is only meaningful relative to another hash produced
    by this same function (as quick_compare does); it is not a stable
    persisted format.

    Args:
        filepath: Path to the TTL file.

    Returns:
        (hex_hash, triple_count)
    """
    parser = StreamingTTLParser(filepath)

    digest_sum = 0  # commutative accumulator: sum of per-triple digests
    count = 0

    for triple in parser.parse_triples():
        # md5 is used for speed as a mixing function, not for security.
        triple_str = f"{triple.subject}\t{triple.predicate}\t{triple.obj}"
        triple_digest = hashlib.md5(triple_str.encode()).digest()
        digest_sum = (digest_sum + int.from_bytes(triple_digest, 'big')) % (1 << 256)
        count += 1

        if count % 1000000 == 0:
            print(f"  Processed {count:,} triples...", file=sys.stderr)

    # Fold the sum and the count into one final hex digest.
    final_hash = hashlib.sha256(f"{digest_sum}:{count}".encode()).hexdigest()

    return final_hash, count


def extract_sorted_triples(filepath: str, output_path: str) -> int:
    """
    Parse every triple out of a TTL file and write them, sorted, to a
    temporary text file suitable for a streaming diff.

    Args:
        filepath: Source TTL file.
        output_path: Destination path for the sorted tab-separated lines.

    Returns:
        Number of triples written.
    """
    ttl_parser = StreamingTTLParser(filepath)
    lines = []

    print(f"  Parsing {filepath}...", file=sys.stderr)
    for parsed in ttl_parser.parse_triples():
        lines.append(parsed.to_line())
        # Progress heartbeat every million triples.
        if len(lines) % 1000000 == 0:
            print(f"    {len(lines):,} triples...", file=sys.stderr)

    print(f"  Sorting {len(lines):,} triples...", file=sys.stderr)
    lines.sort()

    print(f"  Writing to {output_path}...", file=sys.stderr)
    with open(output_path, 'w') as out:
        out.writelines(line + '\n' for line in lines)

    return len(lines)


def compare_sorted_files(file1: str, file2: str) -> Tuple[list, list]:
    """
    Compare two sorted triple files line by line (merge-join).

    Memory efficient - streams both files, never loading either fully.

    Blank lines are skipped outright. Previously the raw line was
    overwritten with its stripped value, so a blank line mid-file became
    falsy, terminated the merge loop early, and misclassified every
    remaining line of the other file as a difference.

    Args:
        file1: Path to the first sorted triple file.
        file2: Path to the second sorted triple file.

    Returns:
        (only_in_file1, only_in_file2) as lists of stripped lines.
    """
    def _next_line(fh):
        # Return the next non-empty stripped line, or None at EOF.
        for raw in fh:
            stripped = raw.strip()
            if stripped:
                return stripped
        return None

    only_in_1 = []
    only_in_2 = []

    with open(file1, 'r') as f1, open(file2, 'r') as f2:
        line1 = _next_line(f1)
        line2 = _next_line(f2)

        # Standard sorted-merge walk: advance whichever side is smaller.
        while line1 is not None and line2 is not None:
            if line1 == line2:
                line1 = _next_line(f1)
                line2 = _next_line(f2)
            elif line1 < line2:
                only_in_1.append(line1)
                line1 = _next_line(f1)
            else:
                only_in_2.append(line2)
                line2 = _next_line(f2)

        # Drain whichever file still has lines.
        while line1 is not None:
            only_in_1.append(line1)
            line1 = _next_line(f1)
        while line2 is not None:
            only_in_2.append(line2)
            line2 = _next_line(f2)

    return only_in_1, only_in_2


def quick_compare(file1: str, file2: str) -> Tuple[bool, int, int]:
    """
    Quick hash-based comparison: only says whether the files are equal.

    Fastest option for large files - no diff details are produced.

    Returns:
        (files_equal, triple_count_1, triple_count_2)
    """
    file_hashes = []
    triple_counts = []
    for path in (file1, file2):
        print(f"Hashing {path}...", file=sys.stderr)
        digest, triple_count = hash_triples_file(path)
        file_hashes.append(digest)
        triple_counts.append(triple_count)

    return file_hashes[0] == file_hashes[1], triple_counts[0], triple_counts[1]


def full_compare(file1: str, file2: str, temp_dir: str = "/tmp") -> Tuple[bool, list, list, int, int]:
    """
    Full comparison with triple-level diff output.

    Extracts both files to sorted temp files, then merge-compares them so
    arbitrarily large inputs never sit in memory at once.

    Temp files are created with tempfile.mkstemp() so each run gets unique
    names; the previous fixed names ("ttl_compare_1.txt") made two
    concurrent comparisons clobber each other's intermediate data.

    Args:
        file1: First TTL file.
        file2: Second TTL file.
        temp_dir: Directory for the sorted intermediate files.

    Returns:
        (is_equal, only_in_file1, only_in_file2, count1, count2)
    """
    import os
    import tempfile

    # mkstemp returns an open fd we don't need: close it immediately and
    # let extract_sorted_triples() reopen the path for writing.
    fd1, temp1 = tempfile.mkstemp(prefix="ttl_compare_1_", suffix=".txt", dir=temp_dir)
    os.close(fd1)
    fd2, temp2 = tempfile.mkstemp(prefix="ttl_compare_2_", suffix=".txt", dir=temp_dir)
    os.close(fd2)

    try:
        print("Extracting and sorting file 1...", file=sys.stderr)
        count1 = extract_sorted_triples(file1, temp1)

        print("Extracting and sorting file 2...", file=sys.stderr)
        count2 = extract_sorted_triples(file2, temp2)

        print("Comparing...", file=sys.stderr)
        only_in_1, only_in_2 = compare_sorted_files(temp1, temp2)

        is_equal = len(only_in_1) == 0 and len(only_in_2) == 0
        return is_equal, only_in_1, only_in_2, count1, count2

    finally:
        # Always clean up the intermediates, even on error.
        for temp_path in (temp1, temp2):
            if os.path.exists(temp_path):
                os.remove(temp_path)


def print_diff_report(only_in_1: list, only_in_2: list, 
                      count1: int, count2: int,
                      file1: str, file2: str,
                      max_show: int = 100):
    """Print a human-readable diff report to stdout.

    Args:
        only_in_1: Tab-separated triple lines present only in file 1.
        only_in_2: Tab-separated triple lines present only in file 2.
        count1: Total triple count of file 1.
        count2: Total triple count of file 2.
        file1: Display name of file 1.
        file2: Display name of file 2.
        max_show: Cap on differences printed per section.
    """

    def _show_section(label: str, diff_lines: list):
        # Shared renderer for the two "only in file N" sections; the two
        # copies of this loop were previously duplicated verbatim.
        print(f"\n--- ONLY IN {label} (showing first {min(len(diff_lines), max_show)}) ---")
        for line in diff_lines[:max_show]:
            parts = line.split('\t')
            # Malformed lines (not exactly 3 fields) are silently skipped,
            # matching the original behavior.
            if len(parts) == 3:
                s, p, o = parts
                # Shorten URIs to their last path segment for display.
                s_short = s.split('/')[-1] if '/' in s else s
                p_short = p.split('/')[-1] if '/' in p else p
                print(f"  {s_short} | {p_short} | {o[:60]}")
        if len(diff_lines) > max_show:
            print(f"  ... and {len(diff_lines) - max_show:,} more")

    print("=" * 80)
    print("TTL COMPARISON REPORT")
    print("=" * 80)
    print(f"\nFile 1: {file1} ({count1:,} triples)")
    print(f"File 2: {file2} ({count2:,} triples)")

    if not only_in_1 and not only_in_2:
        print("\n✓ FILES ARE SEMANTICALLY EQUIVALENT")
    else:
        print(f"\n✗ FILES DIFFER")
        print(f"   Only in file 1: {len(only_in_1):,} triples")
        print(f"   Only in file 2: {len(only_in_2):,} triples")

        if only_in_1:
            _show_section("FILE 1", only_in_1)
        if only_in_2:
            _show_section("FILE 2", only_in_2)

    print("\n" + "=" * 80)


def main():
    """Command-line entry point: parse arguments, compare, exit 0/1."""
    arg_parser = argparse.ArgumentParser(
        description='High-performance TTL semantic comparison for large files',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Modes:
  --quick     Hash comparison only (fastest, just equal/not equal)
  --full      Full diff with triple-level differences (default)

Examples:
  %(prog)s file1.ttl file2.ttl --quick          # Fast pass/fail
  %(prog)s file1.ttl file2.ttl                  # Full diff report
  %(prog)s file1.ttl file2.ttl --export-diff    # Save diffs to files
        """
    )
    arg_parser.add_argument('file1', help='First TTL file')
    arg_parser.add_argument('file2', help='Second TTL file')
    arg_parser.add_argument('--quick', action='store_true',
                            help='Quick hash comparison (no diff details)')
    arg_parser.add_argument('--export-diff', action='store_true',
                            help='Export differences to files')
    arg_parser.add_argument('--max-show', type=int, default=100,
                            help='Max differences to display (default: 100)')
    arg_parser.add_argument('--temp-dir', default='/tmp',
                            help='Temp directory for sorting (default: /tmp)')
    arg_parser.add_argument('--quiet', action='store_true',
                            help='Only output PASS/FAIL')

    args = arg_parser.parse_args()

    # Bail out early if either input file is missing.
    for path in (args.file1, args.file2):
        if not Path(path).exists():
            print(f"ERROR: File not found: {path}", file=sys.stderr)
            sys.exit(1)

    if args.quick:
        # Hash-only mode: cheap pass/fail, no diff details.
        is_equal, count1, count2 = quick_compare(args.file1, args.file2)

        if args.quiet:
            print("PASS" if is_equal else "FAIL")
            sys.exit(0 if is_equal else 1)

        print(f"\nFile 1: {count1:,} triples")
        print(f"File 2: {count2:,} triples")
        if is_equal:
            print("✓ FILES ARE SEMANTICALLY EQUIVALENT")
        else:
            print("✗ FILES DIFFER (use --full for details)")
        sys.exit(0 if is_equal else 1)

    # Full mode: sorted extraction plus merge diff.
    is_equal, only_in_1, only_in_2, count1, count2 = full_compare(
        args.file1, args.file2, args.temp_dir
    )

    if args.quiet:
        print("PASS" if is_equal else "FAIL")
        sys.exit(0 if is_equal else 1)

    print_diff_report(only_in_1, only_in_2, count1, count2,
                     args.file1, args.file2, args.max_show)

    if args.export_diff:
        for out_name, diff_lines in (('only_in_file1.txt', only_in_1),
                                     ('only_in_file2.txt', only_in_2)):
            if not diff_lines:
                continue
            with open(out_name, 'w') as out:
                for diff_line in diff_lines:
                    out.write(diff_line + '\n')
            print(f"Exported {len(diff_lines):,} differences to {out_name}")

    sys.exit(0 if is_equal else 1)


# Script entry point: delegate to main(), which terminates via sys.exit().
if __name__ == '__main__':
    main()
