from datetime import datetime, date, timedelta
import json
import socket

import threading

from ctaTemplate import CtaTemplate
from vtObject import TickData, OrderData
import ctaEngine  # type:ignore

# Address the RTD TCP server binds to (local connections only) and its port.
HOST, PORT = "localhost", 55288

class MyJson(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, datetime):
            return o.strftime("%Y-%m-%d %X")
        elif isinstance(o, date):
            return o.strftime("%Y-%m-%d")
        elif isinstance(o, bytes):
            return o.decode()
        else:
            super().default(o)

class RTDServer(CtaTemplate):
    """RTD strategy that serves market, order, position and option data to Excel.

    A TCP server is started on ``HOST:PORT``. Each client connection carries a
    single request: a 4-byte big-endian length prefix followed by a UTF-8 JSON
    object whose ``info_title`` field (lower-cased) selects the
    ``handle_<title>_request`` method. The method's return value is encoded with
    ``MyJson`` and sent back, then the connection is closed.
    """

    # Upper bound in seconds that onStop waits for the accept thread to exit.
    _STOP_JOIN_TIMEOUT: float = 5.0
    # Maximum number of bytes requested per recv() call.
    _RECV_CHUNK: int = 65536

    def __init__(self):
        super().__init__()
        # Parameter panel: attribute name -> display label.
        self.paramMap = {"investor": "行情账号"}
        # Status panel: attribute name -> display label.
        self.varMap = {"is_server_running": "连接状态"}

        # Whether the RTD server should keep accepting connections.
        self.is_server_running: bool = False
        # Thread running the accept loop (start_server).
        self.server_thread: threading.Thread = None
        # Listening socket.
        self.server_socket: socket.socket = None
        # Serializes server start-up (bind/listen) against shutdown in onStop.
        self.lock: threading.Lock = threading.Lock()

        # Subscribed instruments as "symbol.exchange" keys.
        self.sub_key: list = []
        # Latest tick attributes per symbol ("...Main" aliases are added on request).
        self.tick_data: dict = {}
        # Order information keyed by memo (the client-supplied order id).
        self.order: dict = {}
        # Orders not completely filled yet (status 未成交 / 部分成交), keyed by memo.
        self.order_open_dict: dict = {}
        # Cached option contract data keyed by "symbol.exchange".
        self.options_dict: dict = {}

    def onStart(self) -> None:
        """Start the strategy and launch the RTD server in a background thread."""
        self.output("策略开始，当前策略更新时间：2024-07-24 17:03:51")
        self.is_server_running = True
        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.init_data()

        # Daemon so a stuck accept() can never keep the interpreter alive.
        self.server_thread = threading.Thread(target=self.start_server, daemon=True)
        self.server_thread.start()

    def init_data(self) -> None:
        """Reset the list of subscribed instruments."""
        self.sub_key = []

    def onStop(self) -> None:
        """Stop the strategy: unsubscribe all instruments and shut the server down."""
        # Iterate over a snapshot: un_subsymbol removes entries from self.sub_key,
        # and mutating the list being iterated would skip every other key.
        for key in list(self.sub_key):
            symbol, exchange = key.rsplit(".", 1)
            self.un_subsymbol(symbol, exchange)

        with self.lock:
            self.output(f"{self.name} 策略停止")
            self.trading = False

            self.is_server_running = False
            if self.server_socket:
                self.output("RTD 服务关闭")
                try:
                    # close() alone may not wake a thread blocked in accept() on
                    # some platforms (e.g. Linux); shutdown() first does. Platforms
                    # that reject shutdown() on a listening socket raise OSError,
                    # which is harmless here.
                    self.server_socket.shutdown(socket.SHUT_RDWR)
                except OSError:
                    pass
                self.server_socket.close()

        # Join OUTSIDE the lock: start_server takes the same lock before bind(),
        # so joining while holding it deadlocks if the thread has not bound yet.
        if self.server_thread:
            self.server_thread.join(timeout=self._STOP_JOIN_TIMEOUT)

    def start_server(self) -> None:
        """Bind and listen on HOST:PORT, then accept clients until stopped.

        Every accepted connection is served by handle_request in its own thread.
        """
        with self.lock:
            # onStop may already have run (and closed the socket) before we got here.
            if not self.is_server_running:
                return
            try:
                self.server_socket.bind((HOST, PORT))
                self.server_socket.listen(50)
            except OSError as exc:
                # e.g. the port is already in use: report it and mark the server
                # as not running instead of ending the thread with a traceback.
                self.is_server_running = False
                self.output(f"RTD 服务开启失败: {exc}")
                return
            self.output("RTD 服务开启，可于 Excel 中获取新数据")

        while self.is_server_running:
            try:
                client_socket, _ = self.server_socket.accept()
            except OSError:
                # Raised once onStop shuts down / closes the listening socket.
                break
            threading.Thread(
                target=self.handle_request, args=(client_socket,), daemon=True
            ).start()

    @classmethod
    def _recv_exact(cls, sock: socket.socket, size: int) -> bytes:
        """Read exactly ``size`` bytes from ``sock``.

        ``recv(n)`` may legally return fewer than ``n`` bytes, so this keeps
        reading (in bounded chunks) until the requested amount has arrived.

        Raises:
            ConnectionError: if the peer closes the connection early.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(min(remaining, cls._RECV_CHUNK))
            if not chunk:
                raise ConnectionError(
                    f"connection closed after {size - remaining} of {size} bytes"
                )
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def handle_request(self, client_socket: socket.socket) -> None:
        """Serve one client request and always close the connection afterwards.

        Wire format: a 4-byte big-endian payload length, then a UTF-8 JSON object
        with an ``info_title`` field. ``handle_<info_title>_request`` is called
        with the decoded object and its result is sent back as UTF-8 JSON.
        """
        try:
            length_bytes = self._recv_exact(client_socket, 4)
            data_length = int.from_bytes(length_bytes, byteorder="big")

            request_data = self._recv_exact(client_socket, data_length)
            json_data = json.loads(request_data.decode())
            title = json_data["info_title"].lower()

            func = getattr(self, "handle_" + title + "_request")
            result = func(json_data)

            response_data = json.dumps(result, cls=MyJson).encode("utf-8")
            client_socket.sendall(response_data)
        except Exception as exc:
            # Thread boundary: report the failure instead of letting the thread die.
            self.output(f"RTD 请求处理失败: {exc!r}")
        finally:
            client_socket.close()

    def sub_symbol(self, instruments, exchanges) -> None:
        """Subscribe to market data.

        Args:
            instruments: a symbol or a list of symbols.
            exchanges: an exchange or a list of exchanges, paired positionally
                with ``instruments``.

        Keys already in ``self.sub_key`` are skipped.
        """
        symbol_list = instruments if isinstance(instruments, list) else [instruments]
        exchange_list = exchanges if isinstance(exchanges, list) else [exchanges]

        for symbol, exchange in zip(symbol_list, exchange_list):
            key = symbol + "." + exchange
            if key in self.sub_key:
                continue
            ctaEngine.subMarketData(
                {
                    "sid": self,
                    "InstrumentID": str(symbol),
                    "ExchangeID": str(exchange),
                }
            )
            self.sub_key.append(key)

    def un_subsymbol(self, instruments, exchanges) -> None:
        """Cancel market-data subscriptions and drop their cached ticks.

        Arguments follow the same convention as ``sub_symbol``; keys that are
        not currently subscribed are ignored.
        """
        symbol_list = instruments if isinstance(instruments, list) else [instruments]
        exchange_list = exchanges if isinstance(exchanges, list) else [exchanges]

        for symbol, exchange in zip(symbol_list, exchange_list):
            key = symbol + "." + exchange
            if key not in self.sub_key:
                continue
            ctaEngine.unsubMarketData(
                {
                    "sid": self,
                    "InstrumentID": str(symbol),
                    "ExchangeID": str(exchange),
                }
            )
            self.sub_key.remove(key)
            self.tick_data.pop(symbol, None)

    def onTick(self, tick: TickData) -> None:
        """Cache the latest tick per symbol with its datetime rendered as text.

        Ticks whose datetime is missing or already a string, or whose last
        price is 0, are ignored.
        """
        if tick.datetime is None or isinstance(tick.datetime, str) or tick.lastPrice == 0:
            return

        tick.datetime = tick.datetime.strftime("%Y-%m-%d %H:%M:%S.%f")
        self.tick_data[tick.symbol] = tick.__dict__
        self.putEvent()

    def onOrder(self, order: OrderData, log: bool = False) -> None:
        """Record an order update; a successful submission also counts as one."""
        self.order[order.memo] = order.__dict__
        status = order.status
        self.putEvent()

        if status in ("未成交", "部分成交"):
            # Not (completely) filled yet -> still cancellable.
            self.order_open_dict[order.memo] = order
        else:
            self.order_open_dict.pop(order.memo, None)

    def onErr(self, error: dict) -> None:
        """Mark the order whose orderID matches ``error["orderID"]`` as an error order."""
        super().onErr(error)

        orderID = error["orderID"]
        # Iterate a snapshot: self.order is also written from request-handler threads.
        for info in list(self.order.values()):
            if isinstance(info, dict) and info.get("orderID") == orderID:
                memo = info["memo"]
                # Original author's note: this is only reached when cancelling an
                # order that is already fully filled; other error orders do not
                # call it. The error text is stored under "vtOrderID".
                err_memo = error["errMsg"]
                self.order[memo] = {
                    "orderID": orderID,
                    "vtOrderID": err_memo,
                    "status": "错单",
                    "memo": memo,
                }
                self.order_open_dict.pop(memo, None)

    def handle_unsubsymbol_request(self, json_data: dict) -> str:
        """Unsubscribe the requested symbols; "...Main" names map to their underlying contract.

        ``symbol`` / ``exchange`` may be parallel lists or single values.
        """
        symbol_origin = json_data["symbol"]
        exchange = json_data["exchange"]
        # A bare string would otherwise be zipped character by character.
        if not isinstance(symbol_origin, list):
            symbol_origin = [symbol_origin]
        if not isinstance(exchange, list):
            exchange = [exchange]

        symbol_list = [
            self.get_contract(ex, sym).underlyingSymbol if sym.endswith("Main") else sym
            for sym, ex in zip(symbol_origin, exchange)
        ]

        self.un_subsymbol(symbol_list, exchange)
        return "取消订阅"

    def handle_historydata_request(self, json_data: dict) -> list:
        """Fetch historical K-line data.

        For intraday bars (``minutes < 1440``) the ``period`` (in days) is
        fetched in chunks of at most 6 days, walking backwards from
        ``startdate`` (``YYYYMMDD``); each chunk is reversed before being
        appended. Otherwise a single daily request is made.
        """
        symbol_origin = json_data["symbol"]
        exchange = json_data["exchange"]

        symbol = (
            self.get_contract(exchange, symbol_origin).underlyingSymbol
            if symbol_origin.endswith("Main")
            else symbol_origin
        )

        minutes = int(json_data["minutes"])
        period = int(json_data["period"])

        if minutes >= 1440:
            start_date = json_data["startdate"]
            return ctaEngine.getKLineData(symbol, exchange, start_date, 0, period)

        # Split the requested span into units of at most 6 days; the remainder
        # (if any) is fetched first.
        time_gap = 6
        days_list = [time_gap] * (period // time_gap)
        if (remainder := period % time_gap) != 0:
            days_list.insert(0, remainder)

        bars_list = []
        start_date = json_data["startdate"]
        start_time = json_data["starttime"]
        for _days in days_list:
            # Treat a missing result as "no bars" rather than crashing on reverse().
            bars: list = ctaEngine.getKLineData(
                symbol, exchange, start_date, _days, 0, start_time, 1
            ) or []
            bars.reverse()
            bars_list.extend(bars)
            start_date = self.translate_date(start_date, _days)

        return bars_list

    def translate_date(self, accept_date, days):
        """Return ``accept_date`` (``YYYYMMDD``) moved ``days`` days earlier, as ``YYYYMMDD``."""
        start_date = datetime.strptime(accept_date, "%Y%m%d")
        new_start_date = start_date - timedelta(days=days)
        return new_start_date.strftime("%Y%m%d")

    def handle_order_request(self, json_data: dict) -> dict:
        """Return all known orders keyed by memo."""
        # Shallow snapshot: onOrder may insert keys while json.dumps iterates.
        return dict(self.order)

    def handle_cancelorder_request(self, json_data: dict) -> str:
        """Cancel an open order by memo and describe the outcome."""
        accept_orderid = str(json_data["orderID"])
        order = self.order_open_dict.get(accept_orderid)

        if order:
            ctaEngine.cancelOrder(order.orderID)
            self.order_open_dict.pop(accept_orderid, None)

        if accept_orderid not in self.order:
            return "没有该发单信息"

        result_pre = self.order[accept_orderid]
        # Type checks come first: only dict entries carry a "status" field.
        if isinstance(result_pre, (int, float, str, complex)):
            return "错单"
        if not isinstance(result_pre, dict):
            return result_pre

        status = result_pre.get("status")
        if status == "未成交":
            return "已报撤单，未成交"
        if status == "已撤销":
            return status
        return f"无法撤单，当前订单状态{status}"

    def handle_tick_request(self, json_data: dict) -> dict:
        """Subscribe to the requested "symbol.exchange" keys and return cached ticks.

        For "...Main" symbols, the underlying contract is subscribed and its
        latest tick is also exposed under the "...Main" name.
        """
        symbol_list = []
        exchange_list = []

        for key in json_data["key_list"]:
            symbol, exchange = key.rsplit(".", 1)
            if symbol.endswith("Main"):
                symbol_main = self.get_contract(exchange, symbol).underlyingSymbol
                self.tick_data[symbol] = self.tick_data.get(symbol_main)
                symbol = symbol_main
            symbol_list.append(symbol)
            exchange_list.append(exchange)

        self.sub_symbol(symbol_list, exchange_list)

        # Shallow snapshot: onTick may insert keys while json.dumps iterates.
        return dict(self.tick_data)

    def handle_sendorder_request(self, json_data: dict) -> str:
        """Submit an order described by the request ("...Main" symbols are resolved)."""
        investorID = json_data.get("investor") or self.get_investor()
        accept_orderId = json_data["orderID"]
        symbol_origin = json_data["symbol"]
        exchange = json_data["exchange"]
        symbol = (
            self.get_contract(exchange, symbol_origin).underlyingSymbol
            if symbol_origin.endswith("Main")
            else symbol_origin
        )

        return self.send_order_RTD(
            accept_orderid=accept_orderId,
            order_type=json_data["orderType"],
            price=json_data["price"],
            volume=json_data["volume"],
            symbol=symbol,
            exchange=exchange,
            investorID=investorID,
        )

    def send_order_RTD(
        self,
        accept_orderid: int,
        order_type: str,
        price: float,
        volume: int,
        symbol: str,
        exchange: str,
        investorID: str,
    ) -> str:
        """Send an order once per ``accept_orderid`` (used as memo) and record it.

        A repeated ``accept_orderid`` is not re-sent. An ``orderID`` of -1 is
        recorded with the status "参数错误或网络异常".
        """
        # NOTE(review): accept_orderid is stored as received, while
        # handle_cancelorder_request looks orders up by str(orderID). If a client
        # sends numeric ids, the keys may not match; confirm the memo type used
        # by the engine.
        if accept_orderid not in self.order:
            self.trading = True

            orderID = self.sendOrder(
                orderType=order_type,
                price=price,
                volume=volume,
                symbol=symbol,
                exchange=exchange,
                investor=str(investorID),
                memo=accept_orderid,
            )

            self.order[accept_orderid] = {
                "memo": accept_orderid,
                "orderID": orderID,
                "status": "已发送" if orderID != -1 else "参数错误或网络异常",
            }

        return "收到发单"

    def handle_position_request(self, json_data: dict) -> list:
        """Return the positions of each requested investor, concatenated.

        ``investor`` may be a list or a single id; empty entries fall back to
        ``self.get_investor()``. Bytes ``HedgeFlag`` values are decoded.
        """
        investor_list = json_data["investor"]
        # A bare id would otherwise be iterated character by character.
        if not isinstance(investor_list, list):
            investor_list = [investor_list]

        result = []
        for investor_origin in investor_list:
            investor = investor_origin or self.get_investor()
            pos = ctaEngine.getInvestorPosition(str(investor)) or []
            for item in pos:
                hedge_flag = item.get("HedgeFlag")
                if isinstance(hedge_flag, bytes):
                    item["HedgeFlag"] = hedge_flag.decode("UTF-8")
            result.extend(pos)

        return result

    def handle_account_request(self, json_data: dict):
        """Return account data for the requested (or default) investor."""
        investorID = json_data.get("investor") or self.get_investor()
        return ctaEngine.getInvestorAccount(str(investorID))

    def handle_contract_request(self, json_data: dict) -> dict:
        """Return instrument info per "symbol.exchange" key (None without an ExpireDate)."""
        result = {}

        for key in json_data["key_list"]:
            symbol, exchange = key.rsplit(".", 1)
            contract: dict = ctaEngine.getInstrument(exchange, symbol)
            # Presumably unknown instruments come back empty / without an expiry
            # date; report those as None instead of raising KeyError.
            result[key] = contract if contract and contract.get("ExpireDate") else None

        return result

    def handle_greeks_request(self, json_data: dict) -> dict:
        """Subscribe the requested options and return their cached contract/price data."""
        for key in json_data["key_list"]:
            symbol, exchange = key.rsplit(".", 1)
            # sub_symbol skips keys that are already subscribed.
            self.sub_symbol(symbol, exchange)
            self.get_greeks_RTD(symbol, exchange)

        # Shallow snapshot so concurrent inserts cannot break json.dumps.
        return dict(self.options_dict)

    def get_greeks_RTD(self, symbol: str, exchange: str) -> None:
        """Load (once) the option's contract data and refresh its prices."""
        key = symbol + "." + exchange

        if not self.options_dict.get(key):
            self.options_dict[key] = self.get_greeks_contract(symbol, exchange)

        info = self.options_dict[key]
        if info.get("option_type") == "NotOption":
            # No strike/underlying data to price against; report zero prices.
            info["underlying_price"] = 0
            info["option_price"] = 0
            return

        option_price, underlying_price = self.get_underlying_price(symbol, exchange)
        info["underlying_price"] = underlying_price
        info["option_price"] = option_price

    def get_greeks_contract(self, symbol: str, exchange: str) -> dict:
        """Build the cached contract record for an option.

        Returns ``{"option_type": "NotOption"}`` when ``OptionsType`` is not
        "1" (call) or "2" (put). Otherwise it subscribes to the underlying if
        that instrument is known, or else to the opposite-type option with the
        same strike (used for a synthetic underlying price).
        """
        contract_info: dict = ctaEngine.getInstrument(exchange, symbol) or {}

        options_type = contract_info.get("OptionsType")
        if isinstance(options_type, bytes):
            options_type = options_type.decode("UTF-8")
            contract_info["OptionsType"] = options_type

        if options_type not in {"1", "2"}:
            return {"option_type": "NotOption"}

        option_type = "Call" if options_type == "1" else "Put"

        expire_date = datetime.strptime(contract_info.get("ExpireDate", ""), "%Y%m%d")
        expire_date_days = (expire_date - datetime.now()).days + 1

        underlying_symbol = contract_info.get("UnderlyingInstrID", "")
        strike = contract_info.get("StrikePrice", 0)
        product = contract_info.get("ProductID", "")

        opposite_option_symbol = self.find_opposite_option_symbol(
            exchange, product, option_type, strike, underlying_symbol
        )
        underlying_info = ctaEngine.getInstrument(exchange, underlying_symbol) or {}
        symbol_to_sub = (
            opposite_option_symbol
            if underlying_info.get("Instrument", "") == ""
            else underlying_symbol
        )
        # No opposite option may exist; subscribing None would raise TypeError.
        if symbol_to_sub:
            self.sub_symbol(symbol_to_sub, exchange)

        return {
            "expire_date": expire_date,
            "expire_date_days": expire_date_days,
            "strike": strike,
            "underlying_symbol": underlying_symbol,
            "option_type": option_type,
            "underlying_price": 0.0,
            "option_price": 0.0,
            "price_tick": contract_info.get("PriceTick", 0),
            "instrument_name": contract_info.get("InstrumentName", ""),
            "multiple": contract_info.get("VolumeMultiple", 0),
            "min_limit_order_volume": contract_info.get("MinLimitOrderVolume", 0),
            "max_limit_order_volume": contract_info.get("MaxLimitOrderVolume", 0),
            "product": product,
            "opposite_option_symbol": opposite_option_symbol,
        }

    def get_underlying_price(self, symbol: str, exchange: str) -> list:
        """Return ``[option_price, underlying_price]`` for a cached option.

        Uses the underlying's last price when its tick is cached. Otherwise it
        uses put-call parity with the opposite option (``C - P + K``). It
        returns ``[0, 0]`` if the option itself has no tick, and an underlying
        price of 0 when neither source is available.
        """
        option_tick = self.tick_data.get(symbol, False)
        if not option_tick:
            return [0, 0]

        option_price = option_tick["lastPrice"]
        info = self.options_dict[symbol + "." + exchange]
        underlying_tick = self.tick_data.get(info["underlying_symbol"], None)
        opposite_option_tick = self.tick_data.get(info["opposite_option_symbol"], None)

        # Default when no price source is cached yet (previously UnboundLocalError).
        underlying_price = 0
        if underlying_tick is not None:
            underlying_price = underlying_tick["lastPrice"]
        elif opposite_option_tick is not None:
            # No underlying tick: synthetic futures price from the opposite option.
            option_signal = 1 if info["option_type"] == "Call" else -1
            opposite_option_price = opposite_option_tick["lastPrice"]
            underlying_price = (
                option_signal * (option_price - opposite_option_price) + info["strike"]
            )

        return [option_price, underlying_price]

    def find_opposite_option_symbol(
        self,
        exchange: str,
        product: str,
        option_type: str,
        strike: float,
        underlying_symbol: str,
    ) -> "str | None":
        """Return the option of the opposite type (same strike/underlying), or None."""
        contract_raw = (
            ctaEngine.getInstListByExchAndProduct(str(exchange), str(product)) or []
        )
        opposite_option_type = "1" if option_type == "Put" else "2"

        def _options_type(contract: dict):
            # OptionsType may arrive as bytes (get_greeks_contract decodes it too).
            value = contract.get("OptionsType")
            return value.decode("UTF-8") if isinstance(value, bytes) else value

        return next(
            (
                contract["Instrument"]
                for contract in contract_raw
                if contract.get("StrikePrice") == strike
                and _options_type(contract) == opposite_option_type
                and contract.get("UnderlyingInstrID") == underlying_symbol
            ),
            None,
        )

    def handle_products_request(self, json_data: dict):
        """Return the raw instrument list for the requested exchange and product."""
        exchange = json_data["exchange"]
        product = json_data["product"]
        return ctaEngine.getInstListByExchAndProduct(str(exchange), str(product))
    