"""
update_position.py - Run this every 10 seconds to update ISS position
This reads the TLE file and generates position data as JSON
Optimized version with terminator caching (120 second intervals)
"""
from skyfield.api import load, wgs84, EarthSatellite
from datetime import datetime, timedelta, timezone
import json
import math
import os

def load_tle(filename='iss_tle.txt'):
    """Load a three-line TLE record (name, line 1, line 2) from *filename*.

    Blank lines and surrounding whitespace are ignored, so files that pad
    records with empty lines (or use CRLF line endings) still parse. Only the
    first three non-empty lines are used.

    Args:
        filename: Path to the TLE text file.

    Returns:
        A ``(name, line1, line2)`` tuple of stripped strings, or ``None`` if
        the file does not exist or has fewer than three non-empty lines. An
        error message is printed in either failure case.
    """
    if not os.path.exists(filename):
        print(f"Error: {filename} not found. Run download_tle.py first!")
        return None

    # Explicit encoding: TLEs are plain ASCII, and the platform default
    # encoding should not decide how the file is decoded.
    with open(filename, 'r', encoding='utf-8') as f:
        lines = [line.strip() for line in f.read().splitlines() if line.strip()]

    if len(lines) < 3:
        print("Error: Invalid TLE file format - need at least 3 lines")
        return None

    return lines[0], lines[1], lines[2]

def _wrap_longitude(lon):
    """Return *lon* (degrees) shifted by whole turns into the closed range [-180, 180].

    Values already at +180 or -180 are left unchanged (180 is not folded to
    -180), matching the normalisation the terminator sampling relies on.
    """
    while lon > 180:
        lon -= 360
    while lon < -180:
        lon += 360
    return lon


def _terminator_latitude(hour_angle_deg, declination_deg):
    """Return the latitude (degrees) at which the Sun sits on the horizon.

    Solves sin(lat)*sin(dec) + cos(lat)*cos(dec)*cos(H) = 0 for lat, i.e.
    tan(lat) = -cos(H) / tan(dec). tan(lat) is clamped to +/-20 (about
    +/-87.1 deg) so the curve stays finite near the equinoxes, when the
    terminator approaches the two meridians at H = +/-90 deg.
    """
    tan_dec = math.tan(math.radians(declination_deg))
    if tan_dec == 0.0:
        # Exactly at the equinox the latitude is undefined for every H;
        # fall back to the equator rather than divide by zero.
        return 0.0
    tan_lat = -math.cos(math.radians(hour_angle_deg)) / tan_dec
    tan_lat = max(-20.0, min(20.0, tan_lat))
    return math.degrees(math.atan(tan_lat))


def calculate_terminator(time, altitude_km=0):
    """Calculate the day/night terminator as seen from a given altitude.

    For ``altitude_km == 0`` this is the ground terminator (Sun altitude 0).
    For a positive altitude the curve is pushed away from the subsolar
    meridian by the horizon-dip angle acos(R / (R + h)). This is a simple
    longitude shift that approximates where sunlight still reaches that
    height.

    Args:
        time: Timezone-aware datetime at which to evaluate the Sun.
        altitude_km: Observer height above a 6371 km sphere; values <= 0 are
            treated as ground level.

    Returns:
        ``(segments, sun_declination)``. ``segments`` is a list of polylines,
        each a list of ``[lat, lon]`` pairs sampled every 5 deg of longitude.
        A new segment starts wherever consecutive output longitudes jump by
        more than 180 deg (antimeridian wrap). ``sun_declination`` is the
        Sun's apparent declination of date, in degrees (subsolar latitude).
    """
    ts = load.timescale()
    t = ts.from_datetime(time)

    eph = load('de421.bsp')
    sun = eph['sun']
    earth = eph['earth']

    # Sun's apparent RA/Dec referred to the true equator and equinox of date.
    # That is the frame in which Greenwich sidereal time is measured.
    sun_position = earth.at(t).observe(sun).apparent()
    ra, dec, _distance = sun_position.radec(epoch='date')
    sun_declination = dec.degrees

    # The Sun is overhead where the local hour angle GAST + lon - RA is zero.
    # The subsolar longitude therefore follows directly; it is exact, unlike
    # a grid search over candidate longitudes.
    subsolar_lon = _wrap_longitude((ra.hours - t.gast) * 15.0)

    # Earth-central angle between the sub-observer point and the horizon
    # seen from altitude_km.
    earth_radius_km = 6371.0
    if altitude_km > 0:
        horizon_dip_deg = math.degrees(math.acos(earth_radius_km / (earth_radius_km + altitude_km)))
    else:
        horizon_dip_deg = 0.0

    terminator_segments = []
    current_segment = []
    prev_lon = None

    for lon in range(-180, 181, 5):
        hour_angle = _wrap_longitude(lon - subsolar_lon)
        output_lat = _terminator_latitude(hour_angle, sun_declination)

        output_lon = lon
        if altitude_km > 0:
            # Move the curve away from the subsolar meridian on both sides.
            shift = horizon_dip_deg if hour_angle >= 0 else -horizon_dip_deg
            output_lon = _wrap_longitude(lon + shift)

        # A jump of more than half a turn means the curve wrapped across the
        # antimeridian: close the current polyline and start a new one.
        if prev_lon is not None and abs(output_lon - prev_lon) > 180 and current_segment:
            terminator_segments.append(current_segment)
            current_segment = []

        current_segment.append([output_lat, output_lon])
        prev_lon = output_lon

    if current_segment:
        terminator_segments.append(current_segment)

    return terminator_segments, sun_declination

def _sample_ground_track(satellite, ts, now,
                         window=timedelta(minutes=100),
                         step=timedelta(seconds=30)):
    """Propagate *satellite* from ``now - window`` to ``now + window`` every *step*.

    Returns ``(past, future, current, samples)``:

    * ``past`` / ``future``: position dicts before / after *now*, with a
      ``None`` inserted wherever longitude jumps by more than 180 deg
      (antimeridian crossing), so a renderer can break the polyline.
    * ``current``: the position at exactly *now*.
    * ``samples``: every real position paired with its datetime, in time
      order (no ``None`` markers).
    """
    def position_at(moment):
        subpoint = wgs84.subpoint(satellite.at(ts.from_datetime(moment)))
        return {
            'lat': subpoint.latitude.degrees,
            'lon': subpoint.longitude.degrees,
            'alt': subpoint.elevation.km,
            'time': moment.isoformat(),
        }

    past, future = [], []
    past_samples, future_samples = [], []
    current = None
    prev_lon = None

    moment = now - window
    end = now + window
    while moment <= end:
        position = position_at(moment)
        crossed = prev_lon is not None and abs(position['lon'] - prev_lon) > 180
        if moment < now:
            if crossed:
                past.append(None)
            past.append(position)
            past_samples.append((position, moment))
        elif moment > now:
            if crossed:
                future.append(None)
            future.append(position)
            future_samples.append((position, moment))
        else:
            current = position
        prev_lon = position['lon']
        moment += step

    # Only needed when the window is not a whole number of steps.
    if current is None:
        current = position_at(now)

    samples = past_samples + [(current, now)] + future_samples
    return past, future, current, samples


def _find_sun_events(samples, ts, now, earth, sun,
                     earth_radius_km=6371.0, fallback_alt_km=400.0):
    """Find ISS shadow exits (sunrise) and entries (sunset) in time-ordered *samples*.

    A sample is sunlit when the Sun's apparent altitude, seen from the ISS,
    is above the Earth's limb. From height h above a sphere of radius R the
    limb lies acos(R / (R + h)) below the local horizontal. The ground-level
    rule ``altitude > 0`` would report sunset too early and sunrise too late.

    Each event is stamped with the first sample on the new side of the
    boundary, so it is accurate to one sampling step.

    Returns:
        ``(next_sunrise, next_sunset, last_sunrise, last_sunset)`` as
        datetimes or ``None``. "next" is the first event after *now*;
        "last" is the latest event at or before *now*.
    """
    next_sunrise = next_sunset = last_sunrise = last_sunset = None
    prev_in_sunlight = None

    for position, moment in samples:
        alt_km = position['alt']
        # Fall back to a nominal ISS height if propagation gave no usable
        # altitude; `not x > 0` also catches NaN.
        if not alt_km > 0:
            alt_km = fallback_alt_km

        observer = earth + wgs84.latlon(position['lat'], position['lon'],
                                        elevation_m=alt_km * 1000.0)
        sun_alt, _az, _dist = observer.at(ts.from_datetime(moment)).observe(sun).apparent().altaz()
        limb_dip_deg = math.degrees(math.acos(earth_radius_km / (earth_radius_km + alt_km)))
        in_sunlight = sun_alt.degrees > -limb_dip_deg

        if prev_in_sunlight is not None and in_sunlight != prev_in_sunlight:
            if in_sunlight:  # left the shadow: sunrise
                if moment <= now:
                    last_sunrise = moment
                elif next_sunrise is None:
                    next_sunrise = moment
            else:  # entered the shadow: sunset
                if moment <= now:
                    last_sunset = moment
                elif next_sunset is None:
                    next_sunset = moment

        prev_in_sunlight = in_sunlight

    return next_sunrise, next_sunset, last_sunrise, last_sunset


def generate_iss_track(recalculate_terminator=True, cached_terminator_data=None):
    """Generate ISS ground-track, terminator and sunrise/sunset data.

    The track covers 100 minutes before to 100 minutes after the current UTC
    time, sampled every 30 seconds.

    Args:
        recalculate_terminator: If true, compute fresh terminator curves;
            otherwise reuse *cached_terminator_data*.
        cached_terminator_data: Mapping with ``'terminator'``,
            ``'terminator_iss'`` and ``'sun_declination'`` keys. If it is
            ``None`` the terminator is recalculated regardless of
            *recalculate_terminator*.

    Returns:
        A JSON-serialisable dict with ``current``, ``past``, ``future``,
        ``terminator``, ``terminator_iss``, ``sun_declination``, the four
        sunrise/sunset ISO timestamps (or ``None``) and ``timestamp``.
        ``past`` and ``future`` contain ``None`` entries at antimeridian
        crossings. Returns ``None`` if the TLE is missing or invalid, or
        propagation yields NaN.
    """
    tle_data = load_tle()
    if tle_data is None:
        return None
    name, line1, line2 = tle_data

    ts = load.timescale()
    try:
        satellite = EarthSatellite(line1, line2, name, ts)
        print("Satellite created successfully")
    except Exception as e:
        print(f"Error creating satellite from TLE: {e}")
        return None

    now = datetime.now(timezone.utc)
    past_positions, future_positions, current_position, samples = _sample_ground_track(satellite, ts, now)

    if math.isnan(current_position['lat']) or math.isnan(current_position['lon']):
        print("ERROR: Got NaN values for position!")
        print(f"Current position data: {current_position}")
        print("This usually means the TLE data is invalid or corrupted.")
        return None

    if recalculate_terminator or cached_terminator_data is None:
        print("Calculating terminator...")
        terminator, sun_declination = calculate_terminator(now, altitude_km=0)
        terminator_iss, _ = calculate_terminator(now, altitude_km=400)
    else:
        print("Using cached terminator")
        terminator = cached_terminator_data['terminator']
        terminator_iss = cached_terminator_data['terminator_iss']
        sun_declination = cached_terminator_data['sun_declination']

    eph = load('de421.bsp')
    next_sunrise, next_sunset, last_sunrise, last_sunset = _find_sun_events(
        samples, ts, now, eph['earth'], eph['sun'])

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    return {
        'current': current_position,
        'past': past_positions,
        'future': future_positions,
        'terminator': terminator,
        'terminator_iss': terminator_iss,
        'sun_declination': sun_declination,
        'next_sunrise': iso_or_none(next_sunrise),
        'next_sunset': iso_or_none(next_sunset),
        'last_sunrise': iso_or_none(last_sunrise),
        'last_sunset': iso_or_none(last_sunset),
        'timestamp': now.isoformat()
    }

TERMINATOR_CACHE_FILE = 'terminator_cache.json'
TERMINATOR_CACHE_MAX_AGE_S = 120
# Keys generate_iss_track reads from the cache, plus the freshness stamp.
_TERMINATOR_CACHE_KEYS = ('timestamp', 'terminator', 'terminator_iss', 'sun_declination')


def load_terminator_cache(path=TERMINATOR_CACHE_FILE,
                          max_age_seconds=TERMINATOR_CACHE_MAX_AGE_S,
                          now=None):
    """Return cached terminator data if it is complete and fresh, else ``None``.

    The cache counts as fresh when its ``timestamp`` is at most
    *max_age_seconds* old and not in the future. A naive timestamp is
    interpreted as UTC. A missing file is silent; any other read or parse
    problem is printed and treated as a cache miss.
    """
    if now is None:
        now = datetime.now(timezone.utc)
    try:
        with open(path, 'r', encoding='utf-8') as f:
            cache = json.load(f)
        if not isinstance(cache, dict) or not all(k in cache for k in _TERMINATOR_CACHE_KEYS):
            print("Cache read error: incomplete terminator cache")
            return None
        last_calc_time = datetime.fromisoformat(cache['timestamp'])
    except FileNotFoundError:
        return None
    except (OSError, ValueError, TypeError) as e:
        print(f"Cache read error: {e}")
        return None

    if last_calc_time.tzinfo is None:
        last_calc_time = last_calc_time.replace(tzinfo=timezone.utc)

    # Negative age means the clock moved backwards; don't trust that cache.
    age = (now - last_calc_time).total_seconds()
    if 0 <= age < max_age_seconds:
        return cache
    return None


def write_json_atomic(path, payload):
    """Serialise *payload* to *path* via a temporary file and ``os.replace``.

    A concurrent reader of *path* sees either the old or the new complete
    file, never a partially written one. On failure the temporary file is
    removed and the exception re-raised.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f)
        os.replace(tmp_path, path)
    except Exception:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main():
    """Refresh iss_data.json, reusing the terminator cache while it is fresh."""
    print("Updating ISS position...")

    cached_terminator_data = load_terminator_cache()
    recalculate_terminator = cached_terminator_data is None

    data = generate_iss_track(recalculate_terminator, cached_terminator_data)
    if not data:
        print("Failed to generate position data")
        return

    if recalculate_terminator:
        write_json_atomic(TERMINATOR_CACHE_FILE, {
            'timestamp': data['timestamp'],
            'terminator': data['terminator'],
            'terminator_iss': data['terminator_iss'],
            'sun_declination': data['sun_declination']
        })

    write_json_atomic('iss_data.json', data)

    print(f"Position updated at {data['timestamp']}")
    print(f"ISS Position: {data['current']['lat']:.4f}°, {data['current']['lon']:.4f}°")


if __name__ == '__main__':
    main()
