import re
import sys
import os
from collections import defaultdict
from datetime import datetime

# Lexical form flagged by this script: a date and time separated by whitespace
# instead of the "T" that xsd:dateTime requires.
_INVALID_DATETIME_RE = re.compile(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}')

# One character of a bare Turtle token (prefixed name, keyword, number, blank
# node label). Backslash escapes such as "\#" or "\." are allowed.
_BARE = r"(?:\\.|[^\s;,()\[\]\"'<>#\\])"
# Same, but excluding an unescaped "." so that a statement-terminating dot
# written flush against a name ("ex:o.") is not swallowed into the name.
_BARE_END = r"(?:\\.|[^\s;,()\[\]\"'<>#\\.])"
_BARE_TOKEN = _BARE + '*' + _BARE_END

# String literal bodies; long (triple-quoted) forms are tried before short ones.
_LITERAL_BODY = (
    r'"""(?P<long_dq>(?:[^"\\]|\\.|"(?!""))*)"""'
    r"|'''(?P<long_sq>(?:[^'\\]|\\.|'(?!''))*)'''"
    r'|"(?P<dq>(?:[^"\\\n]|\\.)*)"'
    r"|'(?P<sq>(?:[^'\\\n]|\\.)*)'"
)
# Optional datatype (^^xsd:dateTime / ^^<iri>) or language tag (@en-GB).
_LITERAL_SUFFIX = (
    r'(?:\^\^(?:<[^>\s]*>|' + _BARE_TOKEN + r')'
    + r'|@[A-Za-z]+(?:-[A-Za-z0-9]+)*)?'
)

_TOKEN_RE = re.compile('|'.join([
    r'(?P<comment>#[^\n]*)',
    r'(?P<literal>(?:' + _LITERAL_BODY + r')' + _LITERAL_SUFFIX + r')',
    r'(?P<iri><[^>\s]*>)',
    r'(?P<punct>[;,.\[\]()])',
    r'(?P<bare>' + _BARE_TOKEN + r')',
]))

_TOKEN_KINDS = ('comment', 'literal', 'iri', 'punct', 'bare')
_LITERAL_GROUPS = ('long_dq', 'long_sq', 'dq', 'sq')


def _iter_tokens(text):
    """Yield (kind, token_text, lexical_form, line_num) for each Turtle token.

    ``kind`` is one of _TOKEN_KINDS; ``lexical_form`` is the unquoted string
    content for literals and None otherwise; ``line_num`` is the 1-based line
    on which the token starts. Characters that fit no token are skipped.
    """
    line_num = 1
    prev = 0
    for match in _TOKEN_RE.finditer(text):
        # Count newlines since the previous token start (including any inside
        # a multi-line literal) to keep line numbers exact.
        line_num += text.count('\n', prev, match.start())
        prev = match.start()
        kind = next(k for k in _TOKEN_KINDS if match.group(k) is not None)
        lexical = None
        if kind == 'literal':
            lexical = next(match.group(g) for g in _LITERAL_GROUPS
                           if match.group(g) is not None)
        yield kind, match.group(), lexical, line_num


def find_datetime_literals_simple(ttl_file_path):
    """
    Find literals written as 'YYYY-MM-DD HH:MM:SS' (a space instead of the
    "T" separator xsd:dateTime requires) and group them by predicate.

    Uses a lightweight Turtle tokenizer instead of rdflib, so it also works on
    files a strict parser would reject. It tracks the subject / predicate /
    object position across line breaks, ';', ',', '.', blank-node property
    lists ``[...]`` and collections ``(...)``. As a result:

    * a triple written on one line (``s p "..." .``) is attributed to ``p``,
      not to the subject ``s``;
    * every matching literal is reported, including several on one line and
      objects placed on the line after their predicate;
    * comments and @prefix/@base/PREFIX/BASE directives are ignored.

    NOTE: TriG graph blocks (``{ ... }``) are not understood.

    Args:
        ttl_file_path: Path of the Turtle file to scan. It is decoded as UTF-8,
            and a leading byte-order mark is ignored.

    Returns:
        A dict mapping each predicate, spelled as in the file (e.g.
        ``mt:createdAt`` or ``<http://...>``), to a list of dicts. Each dict
        has the keys:

        * ``'value'``: the literal's text;
        * ``'line_num'``: the 1-based line where the literal starts;
        * ``'line'``: that line, stripped.

        Returns ``None`` (after printing a message) if the file cannot be
        read or decoded.
    """
    try:
        with open(ttl_file_path, 'r', encoding='utf-8-sig') as file:
            text = file.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading TTL file: {e}")
        return None

    lines = text.split('\n')
    predicate_examples = defaultdict(list)

    expect = 'subject'         # position the next term fills
    predicate = None           # predicate the next object belongs to
    stack = []                 # (predicate, expect) restored when [ or ( closes
    in_at_directive = False    # inside "@prefix ... ." / "@base ... ."
    directive_terms_left = 0   # terms still owed to SPARQL-style PREFIX/BASE

    for kind, token, lexical, line_num in _iter_tokens(text):
        if kind == 'comment':
            continue
        if in_at_directive:
            if kind == 'punct' and token == '.':
                in_at_directive = False
            continue
        if directive_terms_left:
            directive_terms_left -= 1
            continue

        if kind == 'punct':
            if token in ('[', '('):
                # A bracket in subject position is the subject itself, so the
                # statement continues with a predicate afterwards; otherwise
                # it is an object.
                resume = 'predicate' if expect == 'subject' else 'object'
                stack.append((predicate, resume))
                # "[" opens a property list (predicate first); "(" opens a
                # list of objects that share the enclosing predicate.
                expect = 'predicate' if token == '[' else 'object'
            elif token in (']', ')'):
                if stack:
                    predicate, expect = stack.pop()
            elif token == ';':
                expect = 'predicate'
            elif token == ',':
                expect = 'object'
            else:  # '.' ends the statement
                expect = 'subject'
                predicate = None
                stack.clear()
            continue

        # A term: literal, IRI or bare token.
        if expect == 'subject':
            if kind == 'bare' and not stack:
                if token in ('@prefix', '@base'):
                    in_at_directive = True
                    continue
                keyword = token.upper()
                if keyword in ('PREFIX', 'BASE'):
                    # PREFIX takes "name: <iri>", BASE takes "<iri>"; no dot.
                    directive_terms_left = 2 if keyword == 'PREFIX' else 1
                    continue
            expect = 'predicate'
        elif expect == 'predicate':
            predicate = token
            expect = 'object'
        elif (kind == 'literal' and predicate is not None
              and _INVALID_DATETIME_RE.fullmatch(lexical)):
            predicate_examples[predicate].append({
                'value': lexical,
                'line_num': line_num,
                'line': lines[line_num - 1].strip(),
            })

    return dict(predicate_examples)


def display_results(predicate_examples, max_examples=3):
    """
    Print a human-readable summary of invalid datetime literals.

    Args:
        predicate_examples: Mapping of predicate -> list of occurrence dicts,
            each with at least 'value' and 'line_num' keys (the shape returned
            by find_datetime_literals_simple). An empty mapping (or None)
            prints a success message instead.
        max_examples: Number of occurrences to list per predicate before the
            rest are summarised as "... and N more". Negative values are
            treated as 0.
    """
    if not predicate_examples:
        print("\n✓ No invalid datetime values found (format: 'YYYY-MM-DD HH:MM:SS' with space)")
        return

    # Clamp so a negative value cannot produce a surprising slice.
    max_examples = max(0, max_examples)

    total_occurrences = sum(len(v) for v in predicate_examples.values())
    print(f"\n✗ Found {total_occurrences} invalid datetime values across {len(predicate_examples)} predicates")
    print("  Invalid format: 'YYYY-MM-DD HH:MM:SS' (space between date and time)")
    print("  Valid format:   'YYYY-MM-DDTHH:MM:SS' (T between date and time)")
    print("-" * 80)

    for i, (predicate, occurrences) in enumerate(predicate_examples.items(), 1):
        print(f"\n{i}. Predicate: {predicate}")
        print(f"   Occurrences: {len(occurrences)}")

        # Show only the first few examples to keep the output short.
        for occ in occurrences[:max_examples]:
            print(f"   Line {occ['line_num']}: \"{occ['value']}\"")

        hidden = len(occurrences) - max_examples
        if hidden > 0:
            print(f"   ... and {hidden} more")

    print("\n" + "-" * 80)
    print("FIX: Update your SQL to use TO_CHAR(date_col, 'YYYY-MM-DD\"T\"HH24:MI:SS')")
    print("-" * 80)


def _ask_yes_no(prompt):
    """
    Ask *prompt* on stdin and return True only for a 'y' / 'yes' answer.

    End of input (stdin redirected from a pipe or /dev/null) counts as "no"
    instead of aborting the script with EOFError.
    """
    try:
        answer = input(prompt)
    except EOFError:
        return False
    return answer.strip().lower() in ('y', 'yes')


def _write_report(output_file, ttl_file_path, results):
    """Write *results* (predicate -> occurrences) to *output_file* as UTF-8 text."""
    total = sum(len(v) for v in results.values())
    # Explicit UTF-8: predicates/IRIs may contain non-ASCII characters that the
    # platform default encoding (e.g. cp1252 on Windows) cannot represent.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("Invalid DateTime Report\n")
        f.write(f"{'=' * 60}\n")
        f.write(f"Source: {ttl_file_path}\n")
        f.write(f"Generated: {datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}\n\n")
        f.write(f"Total invalid datetimes: {total}\n")
        f.write(f"Affected predicates: {len(results)}\n\n")

        for predicate, occurrences in results.items():
            f.write(f"\nPredicate: {predicate}\n")
            f.write(f"Count: {len(occurrences)}\n")
            for occ in occurrences:
                f.write(f"  Line {occ['line_num']}: \"{occ['value']}\"\n")


def main():
    """
    Command-line entry point: scan a TTL file for space-separated datetimes.

    Usage: ``python generate_ttl_validation_dates.py <ttl_file_path>``.

    Prints a summary. If anything was found, it offers to save a text report
    named ``<basename>_invalid_datetime_report.txt`` in the current directory.

    Exit status:

    * 1 for a usage error, a missing or unreadable input file, or a failure
      while writing the report;
    * 0 otherwise, including when the user declines to continue with a
      non-.ttl file.
    """
    if len(sys.argv) < 2:
        print("Usage: python generate_ttl_validation_dates.py <ttl_file_path>")
        print("Example: python generate_ttl_validation_dates.py output.ttl")
        sys.exit(1)

    ttl_file_path = sys.argv[1]

    if not os.path.exists(ttl_file_path):
        print(f"Error: File '{ttl_file_path}' not found.")
        sys.exit(1)

    if not ttl_file_path.lower().endswith('.ttl'):
        print(f"Warning: File '{ttl_file_path}' does not have .ttl extension.")
        if not _ask_yes_no("Do you want to continue anyway? (y/n): "):
            sys.exit(0)

    print(f"Analyzing TTL file: {ttl_file_path}")
    print("Looking for INVALID datetime values (space instead of T separator)")

    results = find_datetime_literals_simple(ttl_file_path)
    if results is None:
        # The reader already printed why; report failure to the shell rather
        # than exiting with status 0.
        sys.exit(1)

    display_results(results)

    if results and _ask_yes_no("\nSave results to file? (y/n): "):
        base_name = os.path.splitext(os.path.basename(ttl_file_path))[0]
        output_file = f"{base_name}_invalid_datetime_report.txt"
        try:
            _write_report(output_file, ttl_file_path, results)
        except OSError as e:
            print(f"Error writing report '{output_file}': {e}")
            sys.exit(1)
        print(f"\nResults saved to: {output_file}")


if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import.
    main()
