# -*- coding: utf-8 -*-
"""
A股量化数据引擎 - 纯 Tushare + SQLite 重构版
- 全部数据源切换为 Tushare Pro
- 数据库表：daily_data（含行情、基本面、资金面）
- 每日增量更新，计算核心指标，输出 market_data.json
"""

import os
import sys

# Disable every proxy setting so outbound requests connect directly
# (the original note: avoids connection errors).
os.environ.update({"NO_PROXY": "*", "no_proxy": "*"})
for k in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
    os.environ.pop(k, None)

import sqlite3
import json
import time
import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import tushare as ts
from dotenv import load_dotenv  # 【新增】用于读取 .env 文件

# ================= Configuration =================

# Load environment variables from a .env file, if one is found
# (lookup rules are those of python-dotenv's load_dotenv()).
load_dotenv()

# Read the Tushare token from the environment. A missing token only prints a
# warning here and is stored as an empty string; main() checks it and aborts.
TUSHARE_TOKEN = os.getenv("TUSHARE_TOKEN") or ""
if not TUSHARE_TOKEN:
    print("⚠️ 警告：未读取到 TUSHARE_TOKEN，请检查 .env 文件")

OUTPUT_FILE = "market_data.json"  # JSON report written by main()
DB_PATH = "stock_data.db"  # SQLite database file
DATA_YEARS = 3  # look-back window in years (first download and indicator cutoff)
API_INTERVAL = 0.35  # seconds between API calls; headroom for the 5000-point rate limit

# Default symbol pool. Plain 6-digit codes are accepted; they are converted to
# Tushare format by code_to_ts_code().
# Each entry is (code, display name); used for the default targets and CODE_NAMES.
TARGET_LIST = [
    ("000001", "上证指数"),
    ("563300", "中证2000ETF"),
    ("512480", "半导体ETF"),
    ("159590", "软件ETF汇添富"),
    ("399001", "深证成指"),
    ("512500", "中证500ETF华夏"),
    ("159887", "银行ETF"),
    ("563530", "卫星ETF易方达"),
    ("512980", "传媒ETF"),
    ("159806", "新能源车ETF"),
    ("513010", "恒生科技ETF易方达"),
    ("561980", "半导体设备ETF"),
    ("159915", "创业板ETF易方达"),
    ("588000", "科创50ETF"),
    ("000300", "沪深300"),
    ("000852", "中证1000"),
    ("159995", "芯片ETF"),
    ("562500", "机器人ETF"),
    ("561910", "电池ETF"),
    ("515980", "人工智能ETF"),
    ("516510", "云计算ETF易方达"),
    ("516770", "游戏ETF华泰柏瑞"),
    ("399699", "金融科技"),
    ("159316", "港股通创新药ETF"),
]


def _load_target_codes() -> list[str]:
    """Return the list of codes to process.

    Codes are read from ``stock_pool.txt`` located next to this script when
    that file exists and yields at least one entry. Blank lines and lines
    starting with ``#`` are skipped, and anything after a ``#`` on a line is
    dropped. If the file is missing, empty, or unreadable, the codes of the
    built-in ``TARGET_LIST`` are returned instead.
    """
    pool_file = Path(__file__).parent / "stock_pool.txt"
    if pool_file.exists():
        try:
            stripped = (raw.strip() for raw in pool_file.read_text(encoding="utf-8").splitlines())
            codes = [
                entry.split("#")[0].strip()
                for entry in stripped
                if entry and not entry.startswith("#")
            ]
            if codes:
                print(f"  [配置] 已加载 stock_pool.txt，共 {len(codes)} 个标的")
                return codes
        except Exception as exc:
            # Best effort: any read/parse problem falls back to the defaults.
            print(f"  [警告] 读取 stock_pool.txt 失败: {exc}，将使用默认列表")

    print(f"  [配置] 使用内置默认列表，共 {len(TARGET_LIST)} 个标的")
    return [code for code, _name in TARGET_LIST]


# Codes handled by this run (stock_pool.txt when usable, otherwise the defaults).
TARGET_CODES = _load_target_codes()
# code -> display name, intended for log / output display
CODE_NAMES = dict(TARGET_LIST)
# =========================================


def get_pro():
    """Register ``TUSHARE_TOKEN`` with Tushare and return a Pro API client."""
    ts.set_token(TUSHARE_TOKEN)
    client = ts.pro_api()
    return client


def ts_sleep():
    """Pause for ``API_INTERVAL`` seconds to rate-limit API requests."""
    pause_seconds = API_INTERVAL
    time.sleep(pause_seconds)


def code_to_ts_code(code: str) -> str:
    """Convert a bare 6-character code to Tushare format (e.g. ``000001.SH``).

    Input that already contains a ``.`` or is not exactly 6 characters long is
    returned unchanged (after stripping whitespace). Otherwise the exchange
    suffix is chosen from the code prefix:

    * ``88xxxx``                      -> ``.SI`` (Shenwan industry index)
    * ``4xxxxx`` / ``8xxxxx`` / ``92xxxx`` -> ``.BJ`` (Beijing exchange, incl. 899 indices)
    * ``000300`` / ``000852`` / ``000001`` -> ``.SH`` (treated as indices)
    * ``51`` / ``52`` / ``56`` / ``58`` / ``60`` / ``68`` -> ``.SH``
    * everything else                 -> ``.SZ`` (15/16 ETFs, 399 indices, 00/30 stocks, default)

    Note that a bare ``000001`` is therefore resolved to the Shanghai index,
    not to the Shenzhen stock with the same number; pass ``000001.SZ``
    explicitly to get the latter.
    """
    s = str(code).strip()
    if "." in s or len(s) != 6:
        return s

    prefix = s[:2]
    # Shenwan industry indices must be matched before the generic "8" rule.
    if prefix == "88":
        return f"{s}.SI"
    # Beijing exchange: 4xxxxx / 8xxxxx (including 899 indices) and the 92xxxx range.
    if s[0] in ("4", "8") or prefix == "92":
        return f"{s}.BJ"
    # Hard-coded indices whose numbers collide with Shenzhen stock codes.
    if s in ("000300", "000852", "000001"):
        return f"{s}.SH"
    # Shanghai ETFs (51/52/56/58) and Shanghai stocks (60/68). "52" is included
    # so the mapping agrees with get_code_type(), which treats 52xxxx as ETF.
    if prefix in ("51", "52", "56", "58", "60", "68"):
        return f"{s}.SH"
    # Shenzhen ETFs (15/16), 399 indices, 00/30 stocks, and the default.
    return f"{s}.SZ"


def get_code_type(ts_code: str) -> str:
    """Classify a Tushare code as ``"stock"``, ``"etf"`` or ``"index"``.

    Rules, in order:

    1. ``.SI`` suffix, or a code starting with 399 / 899 / 88 -> index.
    2. Any other ``.BJ`` code -> stock (only the 899 range is an index there).
    3. Code starting with 51 / 52 / 56 / 58 / 15 / 16 -> etf.
    4. ``000xxx.SH`` -> index (e.g. 000001.SH, 000300.SH, 000852.SH).
    5. Everything else -> stock.
    """
    code = ts_code.split(".")[0]
    suffix = ts_code.split(".")[-1] if "." in ts_code else ""

    if suffix == "SI" or code.startswith(("399", "899", "88")):
        return "index"
    # Beijing-exchange securities other than the 899 indices are ordinary
    # stocks; treating them as indices would route them to the index endpoint.
    if suffix == "BJ":
        return "stock"
    if code.startswith(("51", "52", "56", "58", "15", "16")):
        return "etf"
    # Shanghai stock codes do not use the 000 prefix, so 000xxx.SH is an index.
    if suffix == "SH" and code.startswith("000"):
        return "index"
    return "stock"


def init_db(conn: sqlite3.Connection) -> None:
    """Create the ``daily_data`` table if needed and migrate older schemas.

    The table is keyed by ``(ts_code, trade_date)`` and holds quotes,
    valuation fields, money-flow fields and the adjustment factor, all as
    nullable REAL columns. If the table already exists but lacks some of
    those value columns, the missing ones are added with ``ALTER TABLE``.

    Args:
        conn: Open SQLite connection; the changes are committed before return.
    """
    value_cols = (
        "open", "high", "low", "close", "vol", "amount",
        "pe_ttm", "pb", "total_mv",
        "buy_md_vol", "sell_md_vol", "buy_lg_vol", "sell_lg_vol",
        "buy_elg_vol", "sell_elg_vol",
        "buy_elg_amount", "sell_elg_amount", "buy_lg_amount", "sell_lg_amount",
        "adj_factor",
    )
    column_defs = ",\n            ".join(f"{col} REAL" for col in value_cols)
    conn.execute(f"""
        CREATE TABLE IF NOT EXISTS daily_data (
            ts_code TEXT NOT NULL,
            trade_date TEXT NOT NULL,
            {column_defs},
            PRIMARY KEY (ts_code, trade_date)
        )
    """)
    # NOTE(review): this index covers the same columns as the primary key and
    # looks redundant; it is kept so existing databases stay unchanged.
    conn.execute("CREATE INDEX IF NOT EXISTS idx_daily_ts_date ON daily_data(ts_code, trade_date)")

    # Migration: inspect the actual schema and add only the columns that are
    # missing, instead of issuing ALTER TABLE blindly and ignoring every error.
    existing = {row[1] for row in conn.execute("PRAGMA table_info(daily_data)")}
    for col in value_cols:
        if col not in existing:
            conn.execute(f"ALTER TABLE daily_data ADD COLUMN {col} REAL")
    conn.commit()
    print("  [DB] daily_data 表已就绪")


def get_latest_date(conn: sqlite3.Connection, ts_code: str) -> str | None:
    """查询该标的在 DB 中的最新交易日"""
    cur = conn.execute("SELECT MAX(trade_date) FROM daily_data WHERE ts_code = ?", (ts_code,))
    row = cur.fetchone()
    return row[0] if row and row[0] else None


def _safe_call(pro, name: str, **kwargs) -> pd.DataFrame | None:
    """Call the Tushare endpoint *name* with rate limiting and error capture.

    Args:
        pro: Tushare Pro client; the endpoint is looked up as an attribute.
        name: Endpoint name, e.g. ``"daily"``.
        **kwargs: Passed through to the endpoint unchanged.

    Returns:
        The resulting DataFrame, or None when the endpoint does not exist,
        the call raises, or the result is None/empty.
    """
    try:
        method = getattr(pro, name, None)
        if method is None:
            return None
        try:
            df = method(**kwargs)
        finally:
            # Throttle after every request, failed ones included, so a run of
            # errors does not turn into back-to-back calls against the API.
            ts_sleep()
        return df if df is not None and not df.empty else None
    except Exception as e:
        # Best effort: report the failure and let the caller skip this data.
        print(f"    [API] {name} 失败: {e}")
        return None


def fetch_daily(pro, ts_code: str, start: str, end: str, ctype: str) -> pd.DataFrame | None:
    """Fetch daily bars for one symbol.

    The endpoint depends on *ctype*: ``fund_daily`` for ETFs, ``daily`` for
    stocks, and for indices ``index_daily`` — except ``.SI`` codes, which try
    ``sw_daily`` first and fall back to ``index_daily`` when that yields nothing.

    Returns:
        A frame restricted to trade_date/open/high/low/close/vol/amount, or
        None when no data came back or any of those columns is missing.
    """
    window = {"ts_code": ts_code, "start_date": start, "end_date": end}

    if ctype == "etf":
        df = _safe_call(pro, "fund_daily", **window)
    elif ctype == "index":
        df = None
        if ts_code.endswith(".SI"):
            df = _safe_call(pro, "sw_daily", **window)
        if df is None or df.empty:
            df = _safe_call(pro, "index_daily", **window)
    else:
        df = _safe_call(pro, "daily", **window)

    if df is None or df.empty:
        return None

    # Unified column set shared by all endpoints.
    wanted = ["trade_date", "open", "high", "low", "close", "vol", "amount"]
    if any(column not in df.columns for column in wanted):
        return None
    return df[wanted]


def fetch_adj_factor(pro, ts_code: str, start: str, end: str, ctype: str) -> pd.DataFrame | None:
    """复权因子（仅股票）"""
    if ctype != "stock":
        return None
    return _safe_call(pro, "adj_factor", ts_code=ts_code, start_date=start, end_date=end)


def fetch_daily_basic(pro, ts_code: str, start: str, end: str, ctype: str) -> pd.DataFrame | None:
    """PE/PB/总市值（仅股票）"""
    if ctype != "stock":
        return None
    return _safe_call(pro, "daily_basic", ts_code=ts_code, start_date=start, end_date=end,
                      fields="trade_date,pe_ttm,pb,total_mv")


def fetch_moneyflow(pro, ts_code: str, start: str, end: str, ctype: str) -> pd.DataFrame | None:
    """资金流向（仅股票），含 vol 和 amount，用于计算主力净流入额"""
    if ctype != "stock":
        return None
    return _safe_call(pro, "moneyflow", ts_code=ts_code, start_date=start, end_date=end,
                      fields="trade_date,buy_md_vol,sell_md_vol,buy_lg_vol,sell_lg_vol,buy_elg_vol,sell_elg_vol,"
                             "buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount")


def merge_and_upsert(conn: sqlite3.Connection, ts_code: str, df_daily: pd.DataFrame,
                     df_adj: pd.DataFrame | None, df_basic: pd.DataFrame | None,
                     df_mf: pd.DataFrame | None, ctype: str) -> int:
    """合并各数据源并写入 daily_data"""
    if df_daily is None or df_daily.empty:
        return 0
    df = df_daily.copy()

    if df_adj is not None and not df_adj.empty:
        df = df.merge(df_adj[["trade_date", "adj_factor"]], on="trade_date", how="left")
    else:
        df["adj_factor"] = 1.0

    if df_basic is not None and not df_basic.empty:
        df = df.merge(df_basic, on="trade_date", how="left")
    for c in ("pe_ttm", "pb", "total_mv"):
        if c not in df.columns:
            df[c] = np.nan

    if df_mf is not None and not df_mf.empty:
        df = df.merge(df_mf, on="trade_date", how="left")
    mf_cols = ("buy_md_vol", "sell_md_vol", "buy_lg_vol", "sell_lg_vol", "buy_elg_vol", "sell_elg_vol",
               "buy_elg_amount", "sell_elg_amount", "buy_lg_amount", "sell_lg_amount")
    for c in mf_cols:
        if c not in df.columns:
            df[c] = np.nan

    df["adj_factor"] = df["adj_factor"].ffill().fillna(1.0)

    count = 0
    for _, row in df.iterrows():
        conn.execute("""
            INSERT OR REPLACE INTO daily_data
            (ts_code, trade_date, open, high, low, close, vol, amount,
             pe_ttm, pb, total_mv, buy_md_vol, sell_md_vol, buy_lg_vol, sell_lg_vol,
             buy_elg_vol, sell_elg_vol, buy_elg_amount, sell_elg_amount, buy_lg_amount, sell_lg_amount, adj_factor)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            ts_code,
            str(row["trade_date"]),
            float(row["open"]) if pd.notna(row["open"]) else None,
            float(row["high"]) if pd.notna(row["high"]) else None,
            float(row["low"]) if pd.notna(row["low"]) else None,
            float(row["close"]) if pd.notna(row["close"]) else None,
            float(row["vol"]) if pd.notna(row["vol"]) else None,
            float(row["amount"]) if pd.notna(row["amount"]) else None,
            float(row["pe_ttm"]) if pd.notna(row["pe_ttm"]) else None,
            float(row["pb"]) if pd.notna(row["pb"]) else None,
            float(row["total_mv"]) if pd.notna(row["total_mv"]) else None,
            float(row["buy_md_vol"]) if pd.notna(row.get("buy_md_vol")) else None,
            float(row["sell_md_vol"]) if pd.notna(row.get("sell_md_vol")) else None,
            float(row["buy_lg_vol"]) if pd.notna(row.get("buy_lg_vol")) else None,
            float(row["sell_lg_vol"]) if pd.notna(row.get("sell_lg_vol")) else None,
            float(row["buy_elg_vol"]) if pd.notna(row.get("buy_elg_vol")) else None,
            float(row["sell_elg_vol"]) if pd.notna(row.get("sell_elg_vol")) else None,
            float(row["buy_elg_amount"]) if pd.notna(row.get("buy_elg_amount")) else None,
            float(row["sell_elg_amount"]) if pd.notna(row.get("sell_elg_amount")) else None,
            float(row["buy_lg_amount"]) if pd.notna(row.get("buy_lg_amount")) else None,
            float(row["sell_lg_amount"]) if pd.notna(row.get("sell_lg_amount")) else None,
            float(row["adj_factor"]) if pd.notna(row["adj_factor"]) else 1.0,
        ))
        count += 1
    conn.commit()
    return count


def incremental_update(conn: sqlite3.Connection, pro, raw_code: str) -> bool:
    """Bring one symbol up to date in ``daily_data``.

    Fetching starts the day after the latest stored trade date, or
    ``DATA_YEARS`` years back when nothing is stored yet, and ends today.

    Returns:
        True when the symbol is already current or rows were written;
        False when no daily bars could be fetched.
    """
    ts_code = code_to_ts_code(raw_code)
    ctype = get_code_type(ts_code)
    today = datetime.datetime.now().strftime("%Y%m%d")
    latest = get_latest_date(conn, ts_code)

    if not latest:
        first_day = datetime.datetime.now() - datetime.timedelta(days=DATA_YEARS * 365)
        start_date = first_day.strftime("%Y%m%d")
    else:
        start_date = (pd.to_datetime(latest) + pd.Timedelta(days=1)).strftime("%Y%m%d")

    # YYYYMMDD strings compare chronologically.
    if start_date > today:
        print(f"  [{ts_code}] 已是最新，跳过")
        return True

    bars = fetch_daily(pro, ts_code, start_date, today, ctype)
    if bars is None or bars.empty:
        print(f"  [{ts_code}] ({ctype}) 行情拉取失败，跳过")
        return False

    # Supplementary sources, fetched in the same order as before.
    adj, basic, moneyflow = (
        fetch(pro, ts_code, start_date, today, ctype)
        for fetch in (fetch_adj_factor, fetch_daily_basic, fetch_moneyflow)
    )

    written = merge_and_upsert(conn, ts_code, bars, adj, basic, moneyflow, ctype)
    print(f"  [{ts_code}] ({ctype}) 更新 {written} 条")
    return True


# ---------- 指标计算 ----------
def calculate_macd(df: pd.DataFrame, col: str = "qfq_close", fast: int = 12, slow: int = 26, signal: int = 9) -> pd.DataFrame:
    """Return a copy of *df*, sorted by ``trade_date``, with MACD columns added.

    Added columns: ``dif`` (fast EMA minus slow EMA of *col*), ``dea``
    (EMA of ``dif`` over *signal*) and ``bar`` (``(dif - dea) * 2``).
    The input frame is not modified.
    """
    def ema(series: pd.Series, span: int) -> pd.Series:
        return series.ewm(span=span, adjust=False).mean()

    out = df.sort_values("trade_date").reset_index(drop=True)
    prices = out[col]
    dif = ema(prices, fast) - ema(prices, slow)
    dea = ema(dif, signal)
    out["dif"] = dif
    out["dea"] = dea
    out["bar"] = (dif - dea) * 2
    return out


def percentile_rank(series: pd.Series, value: float) -> float | None:
    """百分位 0~100，越大表示当前值越高（越贵）"""
    valid = series.dropna()
    if valid.empty or pd.isna(value):
        return None
    return float((valid < value).sum() / len(valid) * 100)


# ---------- 事件数据 (T-1/实时) ----------
def fetch_margin_detail(pro, ts_code: str) -> dict | None:
    """Return the most recent margin-financing net buy for a stock.

    Walks back from today over 5 calendar days and uses the first day that has
    a ``margin_detail`` row for *ts_code*.

    Returns:
        ``{"date": YYYYMMDD, "rz_net_buy": rzmre - rzche}`` (missing fields
        count as 0), or None for non-stocks or when no day yields data.
    """
    if get_code_type(ts_code) != "stock":
        return None
    for d in range(5):
        trade_date = (datetime.datetime.now() - datetime.timedelta(days=d)).strftime("%Y%m%d")
        # Ask for this symbol only instead of downloading the whole market's
        # table for every symbol; _safe_call also logs and throttles failures.
        df = _safe_call(pro, "margin_detail", trade_date=trade_date, ts_code=ts_code)
        if df is None:
            continue
        try:
            # Keep the local filter as a guard in case extra rows come back.
            sub = df[df["ts_code"] == ts_code]
            if sub.empty:
                continue
            row = sub.iloc[0]
            rz_buy = float(row["rzmre"]) if pd.notna(row.get("rzmre")) else 0
            rz_repay = float(row["rzche"]) if pd.notna(row.get("rzche")) else 0
        except (KeyError, TypeError, ValueError) as e:
            # Malformed response: report it and try the previous day.
            print(f"    [API] margin_detail 失败: {e}")
            continue
        return {"date": trade_date, "rz_net_buy": rz_buy - rz_repay}
    return None


def fetch_top_list(pro, ts_code: str) -> dict | None:
    """Return the latest ``top_list`` (Dragon-Tiger list) appearance of a stock.

    Walks back from today over 5 calendar days and uses the first day with data.

    Returns:
        ``{"date": YYYYMMDD, "net_buy": sum of net_amount, "reason": str}``,
        or None for non-stocks or when the stock was not listed on any of
        those days.
    """
    if get_code_type(ts_code) != "stock":
        return None
    for d in range(5):
        trade_date = (datetime.datetime.now() - datetime.timedelta(days=d)).strftime("%Y%m%d")
        # _safe_call reports API errors and throttles after failures as well,
        # instead of swallowing them silently.
        df = _safe_call(pro, "top_list", trade_date=trade_date, ts_code=ts_code)
        if df is None:
            continue
        try:
            net = float(df["net_amount"].sum()) if "net_amount" in df.columns else 0
            reason = df["reason"].iloc[0] if "reason" in df.columns else ""
        except (TypeError, ValueError) as e:
            # Malformed response: report it and try the previous day.
            print(f"    [API] top_list 失败: {e}")
            continue
        return {"date": trade_date, "net_buy": net, "reason": str(reason)}
    return None


def fetch_stk_holdernumber(pro, ts_code: str) -> dict | None:
    """Return the latest shareholder count of a stock and its change vs. the prior record.

    Looks at ``stk_holdernumber`` records of the last two years, newest
    ``end_date`` first.

    Returns:
        ``{"holder_num": int, "change_pct": float | None}`` where
        ``change_pct`` is the percentage change rounded to 2 decimals, or None
        when there is only one record or the previous count is 0. Returns None
        for non-stocks, when no data is available, or on malformed data.
    """
    if get_code_type(ts_code) != "stock":
        return None
    end = datetime.datetime.now().strftime("%Y%m%d")
    start = (datetime.datetime.now() - datetime.timedelta(days=365 * 2)).strftime("%Y%m%d")
    # _safe_call reports API errors and throttles after failures as well.
    df = _safe_call(pro, "stk_holdernumber", ts_code=ts_code, start_date=start, end_date=end)
    if df is None or "holder_num" not in df.columns:
        return None
    try:
        df = df.sort_values("end_date", ascending=False).reset_index(drop=True)
        curr = int(df["holder_num"].iloc[0])
        if len(df) < 2:
            return {"holder_num": curr, "change_pct": None}
        prev = int(df["holder_num"].iloc[1])
        if prev == 0:
            # Avoid division by zero; the change is undefined.
            return {"holder_num": curr, "change_pct": None}
        return {"holder_num": curr, "change_pct": round((curr - prev) / prev * 100, 2)}
    except (KeyError, TypeError, ValueError) as e:
        # Missing end_date column or non-numeric counts: report and give up.
        print(f"    [API] stk_holdernumber 失败: {e}")
        return None


# ---------- 输出 JSON ----------
def _num_or_none(value, ndigits: int | None = None) -> float | None:
    """Return *value* as float (optionally rounded), or None when it is missing (None/NaN)."""
    if value is None or pd.isna(value):
        return None
    number = float(value)
    return round(number, ndigits) if ndigits is not None else number


def _main_net_inflow(row: pd.Series) -> float | None:
    """Return (extra-large buy - sell) + (large buy - sell) for one row.

    Returns None when ``buy_elg_amount`` is missing; any other missing amount
    counts as 0. Per the original note the unit is 10k CNY (万元).
    """
    if pd.isna(row.get("buy_elg_amount")):
        return None

    def amount(col: str) -> float:
        # NaN is truthy, so ``value or 0`` would let it through; test explicitly.
        value = row.get(col)
        return float(value) if pd.notna(value) else 0.0

    elg_net = amount("buy_elg_amount") - amount("sell_elg_amount")
    lg_net = amount("buy_lg_amount") - amount("sell_lg_amount")
    return round(elg_net + lg_net, 2)


def build_market_data(conn: sqlite3.Connection, pro) -> dict:
    """Compute indicators for every target symbol and assemble the report.

    For each code in ``TARGET_CODES`` the last ``DATA_YEARS`` years of rows
    are read from ``daily_data``; MACD (on forward-adjusted close), the 20-day
    average amount, PE/PB percentiles and the latest main-force net inflow are
    computed, then margin, top-list and shareholder-count data are fetched.

    Missing numeric values are emitted as None rather than NaN, so the result
    serialises to valid JSON.

    Returns:
        Mapping of ts_code to its report item; symbols without stored data
        are omitted.
    """
    result = {}
    cutoff = (datetime.datetime.now() - datetime.timedelta(days=DATA_YEARS * 365)).strftime("%Y%m%d")

    for raw_code in TARGET_CODES:
        ts_code = code_to_ts_code(raw_code)
        if ts_code in result:
            # Duplicate pool entry: already computed, skip the repeated API calls.
            continue
        ctype = get_code_type(ts_code)

        df = pd.read_sql_query(
            "SELECT * FROM daily_data WHERE ts_code = ? AND trade_date >= ? ORDER BY trade_date",
            conn, params=(ts_code, cutoff)
        )
        if df.empty:
            print(f"  [{ts_code}] 无历史数据，跳过")
            continue

        df["amount"] = df["amount"].fillna(0)
        if "adj_factor" in df.columns:
            # Forward-adjust: scale each close by latest factor / that day's
            # factor. Zero or missing factors are treated as 1.0.
            adj = df["adj_factor"].replace(0, np.nan).fillna(1.0)
            adj_latest = adj.iloc[-1]
            df["qfq_close"] = df["close"] * adj_latest / adj if adj_latest > 0 else df["close"]
        else:
            df["qfq_close"] = df["close"]

        df = calculate_macd(df)
        df["ma20_amount"] = df["amount"].rolling(20).mean()
        # Fewer than 20 rows leave the rolling mean undefined; report 0 then.
        last_ma20 = float(df["ma20_amount"].iloc[-1]) if pd.notna(df["ma20_amount"].iloc[-1]) else 0

        pe_ttm = df["pe_ttm"].iloc[-1] if "pe_ttm" in df.columns else None
        pb = df["pb"].iloc[-1] if "pb" in df.columns else None
        pe_pct = percentile_rank(df["pe_ttm"], pe_ttm) if pd.notna(pe_ttm) else None
        pb_pct = percentile_rank(df["pb"], pb) if pd.notna(pb) else None

        row_last = df.iloc[-1]
        # Money-flow data exists for stocks only.
        main_net_inflow = _main_net_inflow(row_last) if ctype == "stock" else None

        item = {
            "ts_code": ts_code,
            "type": ctype,
            "update_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "klines_count": len(df),
            "last_close": _num_or_none(row_last["close"]),
            "indicators": {
                "macd": {
                    "dif": _num_or_none(row_last["dif"], 3),
                    "dea": _num_or_none(row_last["dea"], 3),
                    "bar": _num_or_none(row_last["bar"], 3),
                },
                "ma20_amount": float(last_ma20),
                # Divides by 1e5 to express the amount in units of 亿.
                "ma20_amount_str": f"{last_ma20 / 1e5:.2f}亿",
            },
            "valuation": {
                "pe_ttm": _num_or_none(pe_ttm),
                "pb": _num_or_none(pb),
                "pe_percentile": round(pe_pct, 1) if pe_pct is not None else None,
                "pb_percentile": round(pb_pct, 1) if pb_pct is not None else None,
            },
            "money_flow": {"main_net_inflow": main_net_inflow},
            "margin": {},
            "lhb": {},
            "holders": {},
        }

        # Event data; each fetcher returns None for non-stocks or when unavailable.
        margin = fetch_margin_detail(pro, ts_code)
        if margin:
            item["margin"]["rz_net_buy"] = margin["rz_net_buy"]

        lhb = fetch_top_list(pro, ts_code)
        if lhb:
            item["lhb"]["net_buy"] = lhb["net_buy"]
            item["lhb"]["reason"] = lhb.get("reason", "")

        holders = fetch_stk_holdernumber(pro, ts_code)
        if holders is not None and holders.get("change_pct") is not None:
            item["holders"]["change_pct"] = holders["change_pct"]

        result[ts_code] = item

    return result


def main():
    """Run the pipeline: update the database, compute indicators, write the JSON report.

    Aborts with a message when ``TUSHARE_TOKEN`` is not configured. The SQLite
    connection is closed even if an update or calculation step raises.
    """
    if not TUSHARE_TOKEN:
        print("请先配置 TUSHARE_TOKEN")
        return

    print("=" * 50)
    print("A股量化数据引擎 - 纯 Tushare + SQLite")
    print("=" * 50)

    pro = get_pro()
    Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(DB_PATH)
    try:
        init_db(conn)

        # Convert to ts_code and de-duplicate, preserving order.
        codes = list(dict.fromkeys(code_to_ts_code(c) for c in TARGET_CODES))

        print(f"\n[1] 增量更新 ({len(codes)} 只)...")
        # Iterate the de-duplicated list so a symbol listed twice is updated
        # only once; converted codes pass through code_to_ts_code unchanged.
        for ts_code in codes:
            incremental_update(conn, pro, ts_code)

        print("\n[2] 计算指标，生成 market_data.json ...")
        market_data = build_market_data(conn, pro)
    finally:
        conn.close()

    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(market_data, f, ensure_ascii=False, indent=2)

    print(f"\n[OK] 完成！已写入 {OUTPUT_FILE}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()