import os
import pandas as pd
import matplotlib.pyplot as plt

# ===== Config =====
DATA_CSV = 'output/combined.csv'  # input table loaded by main()
PLOTS_DIR = 'output/plots'        # main() writes every chart into this directory

# Known origin names, upper-cased, grouped by the region they belong to.
# location_mapper (below) flattens this into a name -> region lookup used by
# map_region(); any name not listed here is reported as "Other".
_REGION_ORIGINS = {
    "Channel Islands": (
        "JERSEY",
        "ALDERNEY",
    ),
    "UK mainland": (
        "LONDON GATWICK",
        "LONDON CITY",
        "SOUTHAMPTON",
        "BIRMINGHAM",
        "NEWCASTLE",
        "BRISTOL",
        "MANCHESTER",
        "NORWICH",
        "EXETER",
        "PORTSMOUTH",
        "POOLE",
    ),
    "France": (
        "ST MALO",
        "PARIS",
        "PARIS - CHARLES DE GAULLE",
    ),
}

location_mapper = {
    origin_name: region
    for region, origin_names in _REGION_ORIGINS.items()
    for origin_name in origin_names
}

# ===== Functions =====

def load_data(csv_path):
    """Read the combined CSV and add a parsed ``Datetime`` column.

    Parameters
    ----------
    csv_path : str, path-like or file-like
        Anything ``pandas.read_csv`` accepts. The table must have ``Date`` and
        ``Time`` columns whose values, joined with a single space, can be
        parsed by ``pandas.to_datetime``.

    Returns
    -------
    pandas.DataFrame
        The CSV contents plus a ``Datetime`` column. Rows are sorted
        chronologically (stable sort, so rows with equal timestamps keep their
        file order) and re-indexed 0..n-1.
    """
    df = pd.read_csv(csv_path)
    df['Datetime'] = pd.to_datetime(df['Date'] + ' ' + df['Time'])
    # Downstream plots draw time-series lines and treat the last row as the
    # most recent observation, so guarantee chronological order here rather
    # than relying on the order in which the CSV happened to be assembled.
    return df.sort_values('Datetime', kind='stable').reset_index(drop=True)


def plot_temperature_humidity(df, output_dir):
    """Save a dual-axis chart of humidity (bars) and temperature (line).

    Reads the ``Datetime``, ``Humidity`` and ``Temp_c`` columns of *df* and
    writes ``temperature_humidity_over_time.png`` into *output_dir*, creating
    the directory first if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    timestamps = df['Datetime']
    humidity_colour = 'skyblue'
    temperature_colour = 'red'

    fig, humidity_axis = plt.subplots(figsize=(10, 5))

    # Left y-axis: humidity as translucent bars. Bar width is in x data units;
    # on a datetime axis matplotlib's unit is days, so 0.5 is half a day.
    humidity_axis.bar(
        timestamps,
        df['Humidity'],
        width=0.5,
        alpha=0.6,
        color=humidity_colour,
        label='Humidity (%)',
    )
    humidity_axis.set_xlabel('DateTime')
    humidity_axis.set_ylabel('Humidity (%)', color=humidity_colour)
    humidity_axis.tick_params(axis='y', labelcolor=humidity_colour)
    humidity_axis.grid(True)

    # Right y-axis (shares the x-axis): temperature as a marked line.
    temperature_axis = humidity_axis.twinx()
    temperature_axis.plot(
        timestamps,
        df['Temp_c'],
        marker='o',
        color=temperature_colour,
        label='Temperature (°C)',
    )
    temperature_axis.set_ylabel('Temperature (°C)', color=temperature_colour)
    temperature_axis.tick_params(axis='y', labelcolor=temperature_colour)

    plt.title('Temperature and Humidity Over Time')

    # Each axis only knows its own artists, so gather both sets of entries
    # and draw a single combined legend.
    legend_handles, legend_labels = [], []
    for axis in (humidity_axis, temperature_axis):
        axis_handles, axis_labels = axis.get_legend_handles_labels()
        legend_handles.extend(axis_handles)
        legend_labels.extend(axis_labels)
    humidity_axis.legend(legend_handles, legend_labels, loc='best')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'temperature_humidity_over_time.png'))
    plt.close()


def plot_arrivals(df, output_dir):
    """Save a line chart of boat and plane arrival counts over time.

    Reads the ``Datetime``, ``Amount_boats`` and ``Amount_planes`` columns of
    *df* (which is not modified) and writes ``arrivals_over_time.png`` into
    *output_dir*, creating the directory if it does not exist.
    """
    # Create the output directory here as plot_temperature_humidity does, so
    # this function also works when called on its own rather than via main().
    os.makedirs(output_dir, exist_ok=True)

    plt.figure(figsize=(10, 5))
    plt.plot(df['Datetime'], df['Amount_boats'], label='Boat Arrivals', marker='o')
    plt.plot(df['Datetime'], df['Amount_planes'], label='Plane Arrivals', marker='o')
    plt.title('Arrivals Over Time')
    plt.xlabel('DateTime')
    plt.ylabel('Number of Arrivals')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'arrivals_over_time.png'))
    plt.close()


def plot_arrivals_with_prediction(df, output_dir, *, window=3):
    """Plot arrivals with trailing rolling means and a naive next-day forecast.

    The forecast for each series is the last value of its trailing rolling
    mean, placed one day after the latest timestamp in *df*. The dashed
    rolling-mean lines are extended to that forecast point.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``Datetime``, ``Amount_boats`` and ``Amount_planes``.
        It is not modified.
    output_dir : str
        Directory that receives ``arrivals_over_time_with_prediction.png``;
        created if missing.
    window : int, optional
        Number of rows in the trailing rolling mean (default 3). Because
        ``min_periods=1``, the first rows average over however many rows
        exist so far.

    Raises
    ------
    ValueError
        If *df* has no rows, since there is nothing to forecast from.
    """
    if df.empty:
        raise ValueError(
            "plot_arrivals_with_prediction needs at least one row of data"
        )

    os.makedirs(output_dir, exist_ok=True)

    # Order by time so the rolling mean runs forwards and the last row really
    # is the latest observation. Local Series are used so no helper columns
    # are added to the caller's DataFrame.
    ordered = df.sort_values('Datetime', kind='stable')
    times = ordered['Datetime'].reset_index(drop=True)
    boats = ordered['Amount_boats'].reset_index(drop=True)
    planes = ordered['Amount_planes'].reset_index(drop=True)

    rolling_boats = boats.rolling(window=window, min_periods=1).mean()
    rolling_planes = planes.rolling(window=window, min_periods=1).mean()

    predicted_boats = rolling_boats.iloc[-1]
    predicted_planes = rolling_planes.iloc[-1]
    next_day = times.iloc[-1] + pd.Timedelta(days=1)

    # Append the forecast point to each rolling-mean line. Concatenating
    # plain Series avoids building a row full of missing values.
    extended_times = pd.concat([times, pd.Series([next_day])], ignore_index=True)
    extended_boats = pd.concat(
        [rolling_boats, pd.Series([predicted_boats])], ignore_index=True
    )
    extended_planes = pd.concat(
        [rolling_planes, pd.Series([predicted_planes])], ignore_index=True
    )

    plt.figure(figsize=(10, 5))

    # Actual arrivals.
    plt.plot(times, boats, label='Boat Arrivals', marker='o')
    plt.plot(times, planes, label='Plane Arrivals', marker='o')

    # Rolling means, extended to the forecast date.
    plt.plot(extended_times, extended_boats, label='Rolling Avg Boats', linestyle='--')
    plt.plot(extended_times, extended_planes, label='Rolling Avg Planes', linestyle='--')

    # Forecast points.
    plt.scatter(next_day, predicted_boats, color='blue', marker='x', s=100, label='Predicted Boats')
    plt.scatter(next_day, predicted_planes, color='orange', marker='x', s=100, label='Predicted Planes')

    plt.title('Arrivals Over Time With Prediction')
    plt.xlabel('DateTime')
    plt.ylabel('Number of Arrivals')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'arrivals_over_time_with_prediction.png'))
    plt.close()


def map_region(origin):
    """Return the region for a free-text origin such as ``"Jersey, CI"``.

    Only the text before the first comma is used. It is stripped and
    upper-cased, then looked up in ``location_mapper``. Non-string values
    (e.g. NaN) and names missing from the mapping both yield ``"Other"``.
    """
    if isinstance(origin, str):
        place, _sep, _rest = origin.partition(",")
        return location_mapper.get(place.strip().upper(), "Other")
    return "Other"


def plot_popular_origins(df, output_dir):
    """Save a grouped bar chart of popular boat/plane origins per region.

    Each ``Popular_boat`` and ``Popular_plane`` value is mapped to a region
    with ``map_region``, which gives "Other" for unknown or non-string values.
    The number of rows per region is then drawn for boats and planes side by
    side. *df* is not modified.

    Writes ``popular_origins_bar_chart.png`` into *output_dir*, creating the
    directory if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Fixed category order: every region always appears (0 when absent), and
    # bars are laid out identically from run to run.
    region_order = ['UK mainland', 'Channel Islands', 'France', 'Other']

    # Map into local Series instead of adding columns to the caller's df.
    boat_regions = df['Popular_boat'].apply(map_region)
    plane_regions = df['Popular_plane'].apply(map_region)

    combined = pd.DataFrame({
        'Boats': boat_regions.value_counts().reindex(region_order, fill_value=0),
        'Planes': plane_regions.value_counts().reindex(region_order, fill_value=0),
    })

    ax = combined.plot(kind='bar', figsize=(8, 5))
    ax.set_title('Popular Origins of Boats and Planes by Region')
    ax.set_xlabel('Region')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', labelrotation=0)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Save and close the figure that owns this axes explicitly, rather than
    # whatever pyplot considers "current".
    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'popular_origins_bar_chart.png'))
    plt.close(fig)

# ===== Main Function =====

def main():
    """Load ``DATA_CSV`` and write every chart into ``PLOTS_DIR``.

    Progress is printed to stdout before each step.
    """
    print("Loading data...")
    df = load_data(DATA_CSV)

    os.makedirs(PLOTS_DIR, exist_ok=True)

    # (progress message, plotting function), run in this order on the same
    # DataFrame.
    plot_steps = (
        ("Plotting temperature and humidity...", plot_temperature_humidity),
        ("Plotting arrivals...", plot_arrivals),
        ("Plotting arrivals with prediction...", plot_arrivals_with_prediction),
        ("Plotting popular origins...", plot_popular_origins),
    )
    for message, plot_fn in plot_steps:
        print(message)
        plot_fn(df, PLOTS_DIR)

    print("All plots saved to:", PLOTS_DIR)


if __name__ == "__main__":
    main()
