"""
全球军费开支30年变化分析
数据来源: SIPRI (斯德哥尔摩国际和平研究所)
"""

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import requests
from io import BytesIO

# ============================================
# 第一部分：数据获取
# ============================================

def _read_milex_sheet(source):
    """Parse the constant-price worksheet of a SIPRI Milex workbook.

    Args:
        source: Anything ``pandas.read_excel`` accepts, e.g. a file path or
            a binary file-like object.

    Returns:
        DataFrame read from the "Constant (2023) US$" sheet, skipping the
        first five rows and using the first column as the index.
    """
    return pd.read_excel(
        source,
        sheet_name="Constant (2023) US$",  # constant 2023 US dollars
        skiprows=5,
        index_col=0
    )


def fetch_sipri_data():
    """Load the SIPRI military expenditure table.

    The data is read from the "Constant (2023) US$" worksheet, i.e. constant
    2023 prices rather than current prices.
    NOTE(review): values are presumably in millions of US$ -- confirm against
    the workbook header.

    Lookup order:
    1. The local file ``data/SIPRI-Milex-data.xlsx`` if it exists.
    2. A list of candidate download URLs on www.sipri.org, tried in order.

    If neither works, instructions for a manual download from
    https://milex.sipri.org/sipri are printed. The downloaded workbook must
    contain a sheet named "Constant (2023) US$" and be saved as
    ``data/SIPRI-Milex-data.xlsx``.

    Returns:
        The parsed DataFrame, or ``None`` when no source could be loaded.

    Raises:
        Whatever ``pandas.read_excel`` raises if the local file exists but
        cannot be parsed; only download attempts are treated as best-effort.
    """
    import os

    local_file = "data/SIPRI-Milex-data.xlsx"

    # A local copy always wins over the network.
    if os.path.exists(local_file):
        print(f"从本地文件加载数据: {local_file}")
        df = _read_milex_sheet(local_file)
        print(f"数据加载成功！共 {len(df)} 个国家/地区")
        return df

    # Candidate download locations, tried in order.
    urls = [
        "https://www.sipri.org/sites/default/files/SIPRI-Milex-data-1949-2024.xlsx",
        "https://www.sipri.org/sites/default/files/Data%20for%20all%20countries%201949%E2%80%932024.xlsx",
        "https://www.sipri.org/sites/default/files/milex_data_1949-2024.xlsx",
    ]

    for url in urls:
        print(f"尝试下载: {url}")
        try:
            response = requests.get(url, timeout=30)
            if response.status_code != 200:
                # Report the rejected attempt instead of skipping it silently.
                print(f"  失败: HTTP {response.status_code}")
                continue
            df = _read_milex_sheet(BytesIO(response.content))
        except Exception as e:
            # Deliberately broad: network errors and workbook parse errors
            # both just mean "try the next candidate".
            print(f"  失败: {e}")
            continue
        print(f"数据下载成功！共 {len(df)} 个国家/地区")
        return df

    # Every source failed: tell the user how to fetch the file by hand.
    print("\n" + "=" * 60)
    print("自动下载失败！请手动下载数据：")
    print("1. 访问 https://milex.sipri.org/sipri")
    print("2. 点击页面上的 'Download full database' 按钮")
    print("3. 将下载的Excel文件保存到: data/SIPRI-Milex-data.xlsx")
    print("4. 重新运行此脚本")
    print("=" * 60)
    return None


# ============================================
# 第二部分：数据清洗
# ============================================

def clean_data(df, start_year=1993, end_year=2023):
    """Clean the raw wide table and also return it in long format.

    Steps: keep only the year columns inside the requested window, normalise
    their labels to strings, coerce every cell to a number, then melt the
    table into one row per (country, year) and drop missing values.

    Args:
        df: Wide table whose index identifies the country and whose year
            columns are labelled either as ints (``1993``) or as strings
            (``"1993"``). Other columns are ignored.
        start_year: First year to keep (inclusive). Defaults to 1993.
        end_year: Last year to keep (inclusive). Defaults to 2023.

    Returns:
        A tuple ``(df_clean, df_long)``:
        - ``df_clean``: wide numeric table with string year columns.
        - ``df_long``: columns ``Country``, ``Year`` (int) and
          ``Military_Spending``, with missing values removed.
    """
    # Year labels may be ints or strings depending on how the sheet was read;
    # an int label wins when both forms are present.
    year_cols = []
    for year in range(start_year, end_year + 1):
        if year in df.columns:
            year_cols.append(year)
        elif str(year) in df.columns:
            year_cols.append(str(year))

    df_clean = df[year_cols].copy()

    # Normalise column labels to string years.
    df_clean.columns = [str(c) for c in df_clean.columns]

    # Coercion turns every non-numeric placeholder ('xxx', '. .', '...')
    # into NaN, so no separate replace step is needed.
    for col in df_clean.columns:
        df_clean[col] = pd.to_numeric(df_clean[col], errors='coerce')

    # Long format is more convenient for plotting.
    df_reset = df_clean.reset_index()
    id_col = df_reset.columns[0]  # whatever name the index column received

    df_long = df_reset.melt(
        id_vars=[id_col],
        var_name='Year',
        value_name='Military_Spending'
    )
    df_long.columns = ['Country', 'Year', 'Military_Spending']
    df_long['Year'] = df_long['Year'].astype(int)

    # Drop rows without a spending value.
    df_long = df_long.dropna(subset=['Military_Spending'])

    print(f"数据清洗完成！共 {len(df_long)} 条有效记录")
    return df_clean, df_long


# ============================================
# 第三部分：可视化 - 全球总览
# ============================================

def plot_global_trend(df_long):
    """Figure 1: yearly sum of military spending with key-event annotations.

    Args:
        df_long: Long table with ``Year`` and ``Military_Spending`` columns.

    Returns:
        A plotly line figure of the per-year total, with an arrow annotation
        for each listed event whose year is present in the data.
    """
    # Sum over all rows per year, then scale by 1/1000 for the axis label
    # (labelled as billions of US$).
    yearly = df_long.groupby('Year')['Military_Spending'].sum().reset_index()
    yearly['Military_Spending_Billion'] = yearly['Military_Spending'] / 1000

    fig = px.line(
        yearly,
        x='Year',
        y='Military_Spending_Billion',
        title='全球军费开支总额 (1993-2023)',
        labels={'Military_Spending_Billion': '军费开支 (十亿美元)', 'Year': '年份'}
    )

    # Year -> label of the event to mark on the curve.
    milestones = {
        2001: '9/11事件',
        2008: '金融危机',
        2014: '克里米亚',
        2022: '俄乌战争',
    }

    for when, label in milestones.items():
        matches = yearly.loc[yearly['Year'] == when, 'Military_Spending_Billion'].values
        if len(matches) == 0:
            # No data for that year: nothing to point the arrow at.
            continue
        fig.add_annotation(
            x=when,
            y=matches[0],
            text=label,
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            ax=0,
            ay=-40
        )

    fig.update_layout(
        template='plotly_dark',
        font=dict(size=14),
        hovermode='x unified'
    )

    return fig


def plot_top10_2023(df_clean):
    """Figure 2: bar chart of the ten largest values in the 2023 column.

    Args:
        df_clean: Wide table indexed by country with a ``'2023'`` column.

    Returns:
        A plotly bar figure; the United States bar uses a distinct colour.
    """
    ranked = df_clean['2023'].dropna().sort_values(ascending=False)
    leaders = ranked.head(10)
    top10_df = pd.DataFrame({
        'Country': leaders.index,
        'Spending': leaders.values / 1000  # scaled for the "billions" label
    })

    # One colour per bar, in bar order, so the US bar stands out.
    highlight, regular = '#FF6B6B', '#4ECDC4'
    colors = []
    for country in top10_df['Country']:
        if country == 'United States of America':
            colors.append(highlight)
        else:
            colors.append(regular)

    fig = px.bar(
        top10_df,
        x='Country',
        y='Spending',
        title='2023年全球军费开支Top10',
        labels={'Spending': '军费开支 (十亿美元)', 'Country': '国家'},
        color='Country',
        color_discrete_sequence=colors
    )

    fig.update_layout(
        template='plotly_dark',
        showlegend=False,
        xaxis_tickangle=-45
    )

    return fig


# ============================================
# 第四部分：可视化 - 美国分析
# ============================================

def plot_usa_trend(df_long):
    """Figure 3: US military spending over time with presidential terms shaded.

    Args:
        df_long: Long table with ``Country``, ``Year`` and
            ``Military_Spending`` columns.

    Returns:
        A plotly line figure with one shaded, labelled band per term.
    """
    is_usa = df_long['Country'] == 'United States of America'
    usa_data = df_long[is_usa].copy()
    usa_data['Spending_Billion'] = usa_data['Military_Spending'] / 1000

    fig = px.line(
        usa_data,
        x='Year',
        y='Spending_Billion',
        title='美国军费开支变化 (1993-2023)',
        labels={'Spending_Billion': '军费开支 (十亿美元)', 'Year': '年份'}
    )

    # Two alternating translucent fills used for the term bands.
    blue = 'rgba(0,0,255,0.1)'
    red = 'rgba(255,0,0,0.1)'
    terms = [
        (1993, 2001, '克林顿', blue),
        (2001, 2009, '小布什', red),
        (2009, 2017, '奥巴马', blue),
        (2017, 2021, '特朗普', red),
        (2021, 2024, '拜登', blue),
    ]

    for term_start, term_end, president, fill in terms:
        fig.add_vrect(
            x0=term_start,
            x1=term_end,
            fillcolor=fill,
            layer='below',
            line_width=0,
            annotation_text=president,
            annotation_position='top left'
        )

    fig.update_layout(template='plotly_dark')
    return fig


def plot_usa_vs_world(df_clean):
    """Figure 4: donut chart of the US versus the rest of the 2023 top ten.

    Args:
        df_clean: Wide table indexed by country with a ``'2023'`` column.

    Returns:
        A plotly figure with a two-slice pie (US, sum of the other nine).

    Raises:
        KeyError: If 'United States of America' is not among the ten
            largest 2023 values.
    """
    us_label = 'United States of America'

    spending_2023 = df_clean['2023'].dropna()
    top10 = spending_2023.sort_values(ascending=False).head(10)

    usa_spending = top10[us_label]
    others_spending = top10.drop(us_label).sum()

    donut = go.Pie(
        labels=['美国', '其他Top10国家总和'],
        values=[usa_spending, others_spending],
        hole=0.4,
        marker_colors=['#FF6B6B', '#4ECDC4']
    )
    fig = go.Figure(data=[donut])

    fig.update_layout(
        title='美国 vs 其他Top10国家军费对比 (2023)',
        template='plotly_dark'
    )

    return fig


# ============================================
# 第五部分：可视化 - 中国分析
# ============================================

def plot_china_trend(df_long):
    """Figure 5: China's military spending over time.

    Args:
        df_long: Long table with ``Country``, ``Year`` and
            ``Military_Spending`` columns.

    Returns:
        A plotly line figure drawn in red on the dark template.
    """
    only_china = df_long['Country'] == 'China'
    china_data = df_long[only_china].copy()
    china_data['Spending_Billion'] = china_data['Military_Spending'] / 1000

    axis_labels = {'Spending_Billion': '军费开支 (十亿美元)', 'Year': '年份'}
    fig = px.line(
        china_data,
        x='Year',
        y='Spending_Billion',
        title='中国军费开支变化 (1993-2023)',
        labels=axis_labels
    )

    fig.update_traces(line_color='#FF6B6B')
    fig.update_layout(template='plotly_dark')

    return fig


def plot_china_vs_usa(df_long):
    """Figure 6: US and China spending plus their ratio on a second axis.

    Args:
        df_long: Long table with ``Country``, ``Year`` and
            ``Military_Spending`` columns.

    Returns:
        A plotly figure with both spending curves on the primary y-axis and
        the dashed US/China ratio on the secondary y-axis. Only years present
        for both countries are plotted (inner merge on ``Year``).
    """
    def spending_of(country, label):
        # Year/value pairs for one country, value column renamed to *label*.
        rows = df_long[df_long['Country'] == country]
        return rows[['Year', 'Military_Spending']].rename(
            columns={'Military_Spending': label}
        )

    usa = spending_of('United States of America', 'USA')
    china = spending_of('China', 'China')

    comparison = usa.merge(china, on='Year')
    for side in ('USA', 'China'):
        comparison[side] = comparison[side] / 1000
    comparison['Ratio'] = comparison['USA'] / comparison['China']

    # Single plot with a secondary y-axis for the ratio.
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # (column, legend name, line style, goes on secondary axis?)
    curves = [
        ('USA', '美国', dict(color='#4ECDC4'), False),
        ('China', '中国', dict(color='#FF6B6B'), False),
        ('Ratio', '美/中比值', dict(color='#FFE66D', dash='dash'), True),
    ]
    for column, legend, line_style, on_secondary in curves:
        fig.add_trace(
            go.Scatter(
                x=comparison['Year'],
                y=comparison[column],
                name=legend,
                line=line_style
            ),
            secondary_y=on_secondary
        )

    fig.update_layout(
        title='中美军费对比 (1993-2023)',
        template='plotly_dark'
    )
    fig.update_yaxes(title_text='军费开支 (十亿美元)', secondary_y=False)
    fig.update_yaxes(title_text='美国/中国 比值', secondary_y=True)

    return fig


# ============================================
# 第六部分：可视化 - 欧洲分析
# ============================================

def plot_europe_trend(df_long):
    """Figure 7: spending of selected European countries over time.

    The plotted countries are Germany, France, the United Kingdom, Italy
    and Poland.

    Args:
        df_long: Long table with ``Country``, ``Year`` and
            ``Military_Spending`` columns.

    Returns:
        A plotly line figure with one line per country and a dashed red
        vertical marker at 2022.
    """
    europe_countries = ['Germany', 'France', 'United Kingdom', 'Italy', 'Poland']

    selected = df_long['Country'].isin(europe_countries)
    europe_data = df_long[selected].copy()
    europe_data['Spending_Billion'] = europe_data['Military_Spending'] / 1000

    axis_labels = {
        'Spending_Billion': '军费开支 (十亿美元)',
        'Year': '年份',
        'Country': '国家',
    }
    fig = px.line(
        europe_data,
        x='Year',
        y='Spending_Billion',
        color='Country',
        title='欧洲主要国家军费变化 (1993-2023)',
        labels=axis_labels
    )

    # Mark 2022 with a labelled vertical line.
    fig.add_vline(
        x=2022,
        line_dash='dash',
        line_color='red',
        annotation_text='俄乌战争爆发'
    )

    fig.update_layout(template='plotly_dark')

    return fig


# ============================================
# 第七部分：有趣发现
# ============================================

def plot_gdp_ratio_top10(df_clean):
    """Figure 8: military spending as a share of GDP for ten countries.

    The figures are hard-coded in this function rather than computed.
    NOTE(review): the values are hand-entered demo data for 2023 -- verify
    them against an authoritative source before publishing.

    Args:
        df_clean: Unused; kept so the signature matches the other
            chart builders.

    Returns:
        A plotly bar figure with bars ordered from highest to lowest share.
    """
    gdp_ratio_data = {
        'Country': ['乌克兰', '以色列', '沙特', '俄罗斯', '美国', '韩国', '巴基斯坦', '印度', '英国', '法国'],
        'GDP_Ratio': [37.0, 5.3, 6.0, 5.9, 3.5, 2.7, 3.7, 2.4, 2.3, 2.1]
    }
    df_ratio = pd.DataFrame(gdp_ratio_data)

    # The literal above is not in rank order; sort so the "Top10" chart
    # actually reads from highest to lowest.
    df_ratio = df_ratio.sort_values('GDP_Ratio', ascending=False).reset_index(drop=True)

    fig = px.bar(
        df_ratio,
        x='Country',
        y='GDP_Ratio',
        title='军费占GDP比例 Top10 (2023)',
        labels={'GDP_Ratio': '占GDP比例 (%)', 'Country': '国家'},
        color='GDP_Ratio',
        color_continuous_scale='Reds'
    )

    fig.update_layout(template='plotly_dark')

    return fig


def _growth_multiples(df_clean):
    """Return the 2023/1993 spending multiple per country, largest first.

    Countries missing either value, or whose 1993 value is not positive,
    are excluded so the ranking never contains ``inf`` or ``NaN``.
    """
    base = df_clean['1993']
    latest = df_clean['2023']
    # A zero baseline would yield inf, which dropna() does not remove and
    # which would otherwise sort to the very top of the ranking.
    usable = base.notna() & latest.notna() & (base > 0)
    return (latest[usable] / base[usable]).sort_values(ascending=False)


def plot_growth_rate_top10(df_clean):
    """Figure 9: the ten largest 2023/1993 spending multiples.

    Args:
        df_clean: Wide numeric table indexed by country with ``'1993'`` and
            ``'2023'`` columns.

    Returns:
        A plotly bar figure of the ten countries with the highest multiple.
    """
    growth_top10 = _growth_multiples(df_clean).head(10)

    growth_df = pd.DataFrame({
        'Country': growth_top10.index,
        'Growth_Rate': growth_top10.values
    })

    fig = px.bar(
        growth_df,
        x='Country',
        y='Growth_Rate',
        title='军费30年增长倍数 Top10 (1993-2023)',
        labels={'Growth_Rate': '增长倍数', 'Country': '国家'},
        color='Growth_Rate',
        color_continuous_scale='Viridis'
    )

    fig.update_layout(template='plotly_dark', xaxis_tickangle=-45)

    return fig


# ============================================
# 第八部分：动态排名图（Racing Bar Chart）
# ============================================

def create_racing_bar_data(df_clean):
    """Build the per-year top-ten table used by the racing bar chart.

    Args:
        df_clean: Wide numeric table indexed by country with string year
            columns. Years from 1993 to 2023 that are missing are skipped.

    Returns:
        DataFrame with one row per (year, ranked country) and the columns
        ``Year`` (int), ``Country``, ``Spending`` (value / 1000) and
        ``Rank`` (1 = largest).
    """
    records = []

    for yr in range(1993, 2024):
        column = str(yr)
        if column not in df_clean.columns:
            continue

        leaders = df_clean[column].dropna().sort_values(ascending=False).head(10)
        records.extend(
            {
                'Year': yr,
                'Country': name,
                'Spending': amount / 1000,  # scaled for the "billions" label
                'Rank': position,
            }
            for position, (name, amount) in enumerate(leaders.items(), start=1)
        )

    return pd.DataFrame(records)


def plot_racing_bar(df_clean):
    """Figure 10: animated horizontal bar chart of the yearly top ten.

    Args:
        df_clean: Wide numeric table passed on to
            ``create_racing_bar_data``.

    Returns:
        A plotly bar figure animated over ``Year``, with the play button's
        frame duration set to 500 and transition duration set to 300.
    """
    racing_data = create_racing_bar_data(df_clean)

    # Fixed x-range across frames: 10% headroom above the overall maximum.
    x_upper = racing_data['Spending'].max() * 1.1

    fig = px.bar(
        racing_data,
        x='Spending',
        y='Country',
        color='Country',
        animation_frame='Year',
        orientation='h',
        range_x=[0, x_upper],
        title='全球军费Top10动态排名 (1993-2023)',
        labels={'Spending': '军费开支 (十亿美元)', 'Country': '国家'}
    )

    fig.update_layout(
        template='plotly_dark',
        showlegend=False,
        yaxis={'categoryorder': 'total ascending'}
    )

    # Tune animation speed through the first button of the first menu.
    play_options = fig.layout.updatemenus[0].buttons[0].args[1]
    play_options['frame']['duration'] = 500
    play_options['transition']['duration'] = 300

    return fig


# ============================================
# 主程序
# ============================================

def main():
    """Run the whole pipeline: load, clean, plot, and save every chart.

    Each chart is written to ``output/<name>.html``; every chart except the
    animated racing bar is also written to ``output/<name>.png``. The
    ``output`` directory is created here if it does not exist, so the
    function also works when the module is imported rather than run as a
    script.

    Returns:
        ``(df_clean, df_long, charts)`` on success, where ``charts`` maps
        chart names to plotly figures; ``None`` if the data could not be
        loaded.
    """
    import os

    print("=" * 50)
    print("全球军费开支30年变化分析")
    print("=" * 50)

    # 1. Load the data.
    df_raw = fetch_sipri_data()
    if df_raw is None:
        print("数据获取失败，请检查网络连接")
        return

    # 2. Clean the data.
    df_clean, df_long = clean_data(df_raw)

    # 3. Build every chart.
    print("\n正在生成图表...")

    charts = {
        '01_global_trend': plot_global_trend(df_long),
        '02_top10_2023': plot_top10_2023(df_clean),
        '03_usa_trend': plot_usa_trend(df_long),
        '04_usa_vs_world': plot_usa_vs_world(df_clean),
        '05_china_trend': plot_china_trend(df_long),
        '06_china_vs_usa': plot_china_vs_usa(df_long),
        '07_europe_trend': plot_europe_trend(df_long),
        '08_gdp_ratio': plot_gdp_ratio_top10(df_clean),
        '09_growth_rate': plot_growth_rate_top10(df_clean),
        '10_racing_bar': plot_racing_bar(df_clean),
    }

    # The writes below fail if the target directory is missing.
    os.makedirs("output", exist_ok=True)

    # 4. Save interactive HTML versions.
    print("\n保存图表...")
    for name, fig in charts.items():
        filename = f"output/{name}.html"
        fig.write_html(filename)
        print(f"  ✓ {filename}")

    # 5. Save static images (for video use).
    print("\n保存静态图片...")
    for name, fig in charts.items():
        if 'racing' not in name:  # the animated chart has no static version
            filename = f"output/{name}.png"
            fig.write_image(filename, width=1920, height=1080, scale=2)
            print(f"  ✓ {filename}")

    print("\n" + "=" * 50)
    print("分析完成！所有图表已保存到 output/ 目录")
    print("=" * 50)

    # Hand the data back for further analysis.
    return df_clean, df_long, charts


if __name__ == "__main__":
    # Make sure the working directories exist before anything runs.
    import os

    for directory in ("output", "data"):
        os.makedirs(directory, exist_ok=True)

    # Run the analysis; the result is kept for interactive inspection.
    results = main()
