#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Clean TTL Validator with Readable Output
Handles errors gracefully and provides clear, actionable reports
"""

import re
import sys
import time
import json
import logging
from pathlib import Path
from collections import defaultdict, Counter
from datetime import datetime, timedelta
from urllib.parse import urlparse
import warnings

# Suppress rdflib verbose warnings
# NOTE(review): this blanket filter hides ALL warnings process-wide, not just
# rdflib's — confirm that is intended before reusing this module as a library.
warnings.filterwarnings('ignore')
logging.getLogger('rdflib').setLevel(logging.ERROR)  # silence rdflib's parse-time log chatter

import rdflib
from rdflib import Graph, Namespace, Literal, URIRef, BNode
from rdflib.namespace import RDF, RDFS, XSD, OWL


class CleanTTLValidator:
    """Three-phase validator for Turtle (TTL) files.

    Phase 1 scans the raw text line-by-line with regexes for common syntax
    problems (incomplete/invalid date literals, URIs containing spaces,
    malformed ``@prefix`` declarations), skipping multi-line triple-quoted
    literals so their content cannot trigger false positives.
    Phase 2 asks rdflib to parse the whole file (stderr suppressed).
    Phase 3 samples the parsed triples and checks typed literals against
    their declared datatypes.

    Results are printed as a human-readable report and exported to
    ``ttl_validation_report.json``.
    """

    def __init__(self, verbose=False):
        # Kept for API compatibility; example output is always capped by
        # max_examples regardless of this flag.
        self.verbose = verbose
        # issues[category][issue_type] -> {'count', 'examples', 'first_line'}
        self.issues = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'examples': [], 'first_line': None}))
        self.stats = defaultdict(int)   # counters, e.g. 'total_lines'
        self.max_examples = 5           # examples stored per issue type
        self.error_limit = 1000000      # stop collecting after this many issues
        self.seen_errors = set()        # dedupe repeated parser error messages

    def validate_file(self, file_path):
        """Run all three validation phases on *file_path*.

        Returns True when no issues were found, False otherwise (including
        when the file does not exist).
        """
        start_time = time.time()
        file_path = Path(file_path)

        if not file_path.exists():
            print(f"[ERROR] File {file_path} does not exist")
            return False

        file_size_gb = file_path.stat().st_size / (1024**3)

        print(f"\nTTL Validator")
        print("=" * 60)
        print(f"File: {file_path.name}")
        print(f"Size: {file_size_gb:.2f} GB")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("=" * 60)
        print()

        # Phase 1: Quick syntax scan
        phase1_start = time.time()
        print("Phase 1: Syntax scanning...")
        self._quick_syntax_scan(file_path)
        phase1_time = time.time() - phase1_start
        print(f"  Phase 1 completed in: {self._format_time(phase1_time)}")

        # Phase 2: RDF parsing with error suppression
        phase2_start = time.time()
        print("\nPhase 2: RDF validation...")
        graph = self._parse_with_error_handling(file_path)
        phase2_time = time.time() - phase2_start
        print(f"  Phase 2 completed in: {self._format_time(phase2_time)}")

        # Phase 3: Content validation (if parsing succeeded).
        # BUGFIX: compare against None explicitly — an rdflib Graph defines
        # __len__, so a successfully parsed but empty graph is falsy and a
        # plain `if graph:` would silently skip this phase.
        if graph is not None:
            phase3_start = time.time()
            print("\nPhase 3: Content validation...")
            self._validate_content(graph)
            phase3_time = time.time() - phase3_start
            print(f"  Phase 3 completed in: {self._format_time(phase3_time)}")

        # Generate clean report
        elapsed = time.time() - start_time
        return self._generate_clean_report(elapsed)

    def _format_time(self, seconds):
        """Format a duration in seconds as 'Xm Ys', or 'Ys' under a minute."""
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        if minutes > 0:
            return f"{minutes}m {secs}s"
        else:
            return f"{secs}s"

    def _remove_string_literals(self, line):
        """Blank out string literals in *line* to avoid false positives in URI checking.

        Triple-quoted strings are removed first so their embedded double
        quotes do not confuse the single-quote pass.
        """
        # Remove triple-quoted strings first
        line = re.sub(r'""".*?"""', '""', line, flags=re.DOTALL)

        # Remove single-quoted strings
        line = re.sub(r'"[^"]*"', '""', line)

        return line

    def _quick_syntax_scan(self, file_path):
        """Phase 1: line-by-line regex scan for common syntax issues.

        Collects at most ``self.error_limit`` issues; lines inside multi-line
        triple-quoted literals are skipped entirely.
        """
        patterns = {
            # Date patterns that are invalid - but NOT inside triple quotes
            'incomplete_date': re.compile(r'(?<!"")"(\d{4}-\d{2})"(?!\s*""")(?:\^\^[^>]*date)'),
            'datetime_with_space': re.compile(r'(?<!"")"(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})"(?!\s*""")'),
            'invalid_date_format': re.compile(r'(?<!"")"(\d{2}/\d{2}/\d{4}|\d{2}\.\d{2}\.\d{4})"(?!\s*""")(?:\^\^[^>]*date)'),

            # Check prefix declarations
            'malformed_prefix_decl': re.compile(r'@prefix\s+([^:]+)(?!:)\s*<'),
        }

        # URI pattern - applied to the line only after string literals are
        # blanked out, so spaces inside quoted values are not flagged.
        uri_with_space_pattern = re.compile(r'<([^>]*\s+[^>]*)>')

        line_count = 0
        issues_found = 0
        in_triple_quotes = False
        triple_quote_buffer = ""
        last_progress_time = time.time()

        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                for line_num, line in enumerate(f, 1):
                    line_count += 1

                    # Progress indicator: report time spent on each 500k-line
                    # chunk.  BUGFIX: reset the chunk timer after printing —
                    # previously it was never updated, so the reported elapsed
                    # time was cumulative since the start of the scan.
                    if line_num % 500000 == 0:
                        current_time = time.time()
                        print(f"  Scanned {line_num:,} lines... (elapsed: {self._format_time(current_time - last_progress_time)})", end='\r')
                        last_progress_time = current_time

                    # Handle multi-line triple-quoted strings
                    if in_triple_quotes:
                        triple_quote_buffer += line
                        if '"""' in line:
                            # An even number (>= 2) of delimiters in the buffer
                            # means the multi-line string is complete.
                            quote_count = triple_quote_buffer.count('"""')
                            if quote_count >= 2 and quote_count % 2 == 0:
                                in_triple_quotes = False
                                triple_quote_buffer = ""
                        continue
                    else:
                        # Check if line starts a triple-quoted string
                        if '"""' in line:
                            quote_count = line.count('"""')
                            if quote_count % 2 == 1:  # Odd number means we're starting a multi-line string
                                in_triple_quotes = True
                                triple_quote_buffer = line
                                continue

                    # Now check patterns on the current line
                    for issue_type, pattern in patterns.items():
                        for match in pattern.findall(line):
                            if issues_found < self.error_limit:
                                self._add_issue('syntax', issue_type, {
                                    'line': line_num,
                                    'value': match if isinstance(match, str) else match[0],
                                    'context': line.strip()[:80]
                                })
                                issues_found += 1

                    # URI-with-space check on the literal-stripped line.
                    # (No in_triple_quotes guard needed: that branch above
                    # always `continue`s.)
                    cleaned_line = self._remove_string_literals(line)
                    for match in uri_with_space_pattern.findall(cleaned_line):
                        if issues_found < self.error_limit:
                            self._add_issue('syntax', 'uri_with_space', {
                                'line': line_num,
                                'value': match,
                                'context': line.strip()[:80]
                            })
                            issues_found += 1

                    # Stop if too many errors
                    if issues_found >= self.error_limit:
                        print(f"\n  [WARNING] Stopped collecting errors after {self.error_limit:,} issues")
                        break

        except Exception as e:
            # Record unreadable-file problems as a report entry instead of
            # aborting the whole validation run.
            self._add_issue('file', 'read_error', {'error': str(e)})

        self.stats['total_lines'] = line_count
        print(f"\n  [OK] Scanned {line_count:,} lines")

    def _parse_with_error_handling(self, file_path):
        """Phase 2: parse the file with rdflib, suppressing its stderr noise.

        Returns the populated Graph on success, or None on a fatal parse
        error.  rdflib's turtle parser aborts on the first hard error, which
        is recorded under the 'parsing' category.
        """
        graph = Graph()

        try:
            # Redirect stderr to suppress rdflib verbose output
            import io
            import contextlib

            parse_start = time.time()
            print("  Parsing TTL file...")

            with contextlib.redirect_stderr(io.StringIO()):
                # Parse the file
                graph.parse(file_path, format='turtle')

            parse_time = time.time() - parse_start
            print(f"  [OK] Successfully parsed {len(graph):,} triples in {self._format_time(parse_time)}")

            return graph

        except Exception as e:
            error_msg = str(e)
            print(f"  [ERROR] Fatal parsing error: {error_msg[:100]}...")
            self._add_issue('parsing', 'fatal_error', {'error': error_msg[:200]})
            return None

    def _validate_content(self, graph):
        """Phase 3: check typed literals in a sample of the parsed triples.

        Only the first ``sample_size`` triples yielded by the graph are
        inspected for speed; per-issue counts are extrapolated to an
        estimated file-wide total.
        """
        print("  Checking literals...")

        # Sample validation - check a subset for performance
        total_triples = len(graph)
        sample_size = min(100000, total_triples)

        literal_issues = defaultdict(int)

        check_start = time.time()
        for i, (s, p, o) in enumerate(graph):
            if i >= sample_size:
                break

            if isinstance(o, Literal) and o.datatype:
                datatype = str(o.datatype)
                value = str(o)

                # Check for common issues
                if 'date' in datatype.lower():
                    if len(value) == 7 and re.match(r'\d{4}-\d{2}$', value):
                        literal_issues['Incomplete date (YYYY-MM)'] += 1
                    elif ' ' in value and ':' in value and re.match(r'\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2}', value):
                        literal_issues['DateTime with space separator'] += 1

                elif 'integer' in datatype.lower():
                    try:
                        int(value)
                    except ValueError:
                        # BUGFIX: narrowed from a bare `except:`, which also
                        # swallowed KeyboardInterrupt/SystemExit.
                        literal_issues['Invalid integer value'] += 1

        check_time = time.time() - check_start

        # Extrapolate sampled counts to the whole graph and record them.
        for issue_type, count in literal_issues.items():
            if count > 0:
                estimated_total = int(count * total_triples / sample_size)
                self._add_issue('content', issue_type, {
                    'sample_count': count,
                    'estimated_total': estimated_total
                })

        print(f"  [OK] Validated sample of {sample_size:,} triples in {self._format_time(check_time)}")

    def _add_issue(self, category, issue_type, details):
        """Record one occurrence of an issue, keeping a capped example list.

        The first line number seen for an issue type is preserved; at most
        ``self.max_examples`` simplified examples are stored.
        """
        issue = self.issues[category][issue_type]
        issue['count'] += 1

        # Store first occurrence line number
        if 'line' in details and issue['first_line'] is None:
            issue['first_line'] = details['line']

        # Store limited examples
        if len(issue['examples']) < self.max_examples:
            # Simplify the details for readability
            example = {}
            if 'value' in details:
                example['value'] = str(details['value'])[:50]
            if 'line' in details:
                example['line'] = details['line']
            if 'error' in details:
                example['error'] = details['error'][:100]

            issue['examples'].append(example)

    def _generate_clean_report(self, elapsed_time):
        """Print the human-readable report and export the JSON version.

        Returns True when no issues were found, False otherwise.
        """
        print()
        print("=" * 60)
        print("VALIDATION REPORT")
        print("=" * 60)
        print()

        # Calculate totals
        total_errors = sum(
            sum(issue['count'] for issue in issues.values())
            for cat, issues in self.issues.items()
            if cat in ['syntax', 'parsing', 'content']
        )

        # Summary
        print("SUMMARY")
        print("-" * 40)
        print(f"Total Issues Found: {total_errors:,}")
        print(f"Validation Time: {self._format_time(elapsed_time)} ({elapsed_time:.1f} seconds)")
        print(f"Lines Processed: {self.stats.get('total_lines', 0):,}")
        print(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print()

        # Detailed issues by category
        if total_errors > 0:
            print("ISSUES BY CATEGORY")
            print("-" * 40)

            for category, issues in self.issues.items():
                if issues:
                    print(f"\n{category.upper()} ISSUES:")

                    # Sort by count
                    sorted_issues = sorted(issues.items(), key=lambda x: x[1]['count'], reverse=True)

                    for issue_type, data in sorted_issues[:10]:  # Top 10 issues
                        print(f"\n  * {issue_type}: {data['count']:,} occurrences")

                        if data['first_line'] is not None:
                            print(f"    First seen: line {data['first_line']:,}")

                        if data['examples']:
                            print("    Examples:")
                            for i, example in enumerate(data['examples'][:3], 1):
                                if 'line' in example and 'value' in example:
                                    print(f"      {i}. Line {example['line']:,}: {example['value']}")
                                elif 'value' in example:
                                    print(f"      {i}. {example['value']}")
                                elif 'error' in example:
                                    print(f"      {i}. {example['error']}")

        # Recommendations
        if total_errors > 0:
            print()
            print("-" * 40)
            print("RECOMMENDATIONS:")

            # BUGFIX: normalise issue-type names so both the Phase-1 syntax
            # keys ('incomplete_date', 'datetime_with_space') and the Phase-3
            # content keys ('Incomplete date (YYYY-MM)', ...) trigger their
            # recommendation — the old case-sensitive substring test never
            # matched the syntax keys.  Each recommendation is printed once.
            printed = set()
            for category, issues in self.issues.items():
                for issue_type in issues:
                    key = issue_type.lower().replace('_', ' ')

                    if 'incomplete date' in key and 'incomplete date' not in printed:
                        printed.add('incomplete date')
                        print("\n  * Fix incomplete dates:")
                        print("    - Change 'YYYY-MM' to 'YYYY-MM-DD'")
                        print("    - Example: '2025-07' -> '2025-07-01'")

                    elif 'datetime with space' in key and 'datetime with space' not in printed:
                        printed.add('datetime with space')
                        print("\n  * Fix datetime format:")
                        print("    - Replace space with 'T' in datetime values")
                        print("    - Example: '2025-07-01 12:00:00' -> '2025-07-01T12:00:00'")

                    elif 'invalid integer' in key and 'invalid integer' not in printed:
                        printed.add('invalid integer')
                        print("\n  * Fix integer values:")
                        print("    - Ensure integer literals contain only digits")
                        print("    - Remove any decimal points or non-numeric characters")

        # Final verdict
        print()
        print("=" * 60)
        if total_errors == 0:
            print("[SUCCESS] VALIDATION PASSED - No issues found!")
            result = True
        else:
            print(f"[FAILED] VALIDATION FAILED - {total_errors:,} issues need attention")
            result = False
        print("=" * 60)
        print()

        # Export detailed report
        self._export_json_report(elapsed_time)

        return result

    def _export_json_report(self, elapsed_time):
        """Write the collected statistics and issues to a JSON report file."""
        report = {
            'timestamp': datetime.now().isoformat(),
            'validation_time_seconds': elapsed_time,
            'validation_time_formatted': self._format_time(elapsed_time),
            'statistics': dict(self.stats),
            'issues': {
                category: {
                    issue_type: {
                        'count': data['count'],
                        'first_line': data['first_line'],
                        'examples': data['examples'][:5]
                    }
                    for issue_type, data in issues.items()
                }
                for category, issues in self.issues.items()
            }
        }

        report_file = 'ttl_validation_report.json'
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

        print(f"Detailed report saved to: {report_file}")


def main():
    """Command-line entry point: validate the TTL file named in argv[1].

    Exits 0 on a clean validation, 1 on issues, usage error, fatal
    failure, or user interrupt.
    """
    if len(sys.argv) != 2:
        print("Usage: python ttl_validator.py <ttl_file>")
        sys.exit(1)

    print(f"\nStarting validation at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    validator = CleanTTLValidator()

    try:
        passed = validator.validate_file(sys.argv[1])
    except KeyboardInterrupt:
        print("\n\n[WARNING] Validation interrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report anything unexpected and fail.
        print(f"\n[ERROR] Fatal error: {e}")
        sys.exit(1)
    else:
        # SystemExit is a BaseException, so raising it here (or inside the
        # try in the original) is never swallowed by `except Exception`.
        sys.exit(0 if passed else 1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()