#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'chengzhi'
import asyncio
import json
import math
import os
import time
from datetime import date, datetime
from typing import Union
import aiohttp
import requests
import pandas as pd
from tqsdk.channel import TqChan
from tqsdk.datetime import _get_trading_day_start_time, _get_trading_day_end_time, _get_trading_day_from_timestamp
from tqsdk.diff import _merge_diff, _get_obj
from tqsdk.entity import Entity
from tqsdk.exceptions import BacktestFinished
from tqsdk.objs import Kline, Tick
from tqsdk.rangeset import _rangeset_range_union, _rangeset_difference, _rangeset_union
from tqsdk.utils import _generate_uuid, _query_for_quote
class BtQuote(Entity):
""" Quote 是一个行情对象 """
def __init__(self, api):
self._api = api
self.price_tick: float = float("nan")
class TqBacktest(object):
"""
    TqSdk backtest class

    Pass an instance of this class to the TqApi constructor and the strategy will run in backtest mode.

    In backtest mode a K-line is updated exactly twice, once when it is created and once when it closes; it is never updated in between.

    In backtest mode the update frequency of a quote is determined by the subscribed ticks and K-line durations:

    * If ticks are subscribed, the quote of that instrument is generated from the ticks and updated at the same frequency, but it contains **only the following fields**:
      datetime/ask&bid_price1/ask&bid_volume1/last_price/highest/lowest/average/volume/amount/open_interest/
      price_tick/price_decs/volume_multiple/max&min_limit&market_order_volume/underlying_symbol/strike_price

    * If no ticks are subscribed but K-lines are, the quote is generated from the K-lines and updated once per K-line period. If several K-line durations are subscribed for the same instrument,
      the quote updates whenever any of them updates. For a quote generated from K-lines the ask/bid is the close price plus/minus one price tick, highest/lowest/average/amount
      are always nan, and volume is always 0.

    * If neither ticks nor K-lines are subscribed, or the subscribed K-line duration is longer than one minute, TqBacktest **automatically subscribes to 1-minute K-lines** to generate the quote.

    * If no ticks are subscribed but K-lines are, the quote of that instrument contains **only the following fields**:
      datetime/ask&bid_price1/ask&bid_volume1/last_price/open_interest/
      price_tick/price_decs/volume_multiple/max&min_limit&market_order_volume/underlying_symbol/strike_price

    **Note**: if a quote is not subscribed, simulated trading subscribes it automatically when an order for that instrument is placed. Following the quote update rules above, if the instrument has no K-line subscription, or only K-lines longer than one minute, **a 1-minute K-line subscription is added automatically**.

    Simulated trading fills an order only when its price reaches the opposite side of the book; for example, a buy order fills only if its price is greater than or equal to the ask price. If an order cannot fill immediately, it is re-evaluated on the next quote update.

    In backtest mode each wait_update advances market time by at most one step.

    When the backtest ends, a BacktestFinished exception is raised.

    When backtesting **combination (spread) instruments**, only tick data can be subscribed: K-lines cannot be used, because K-lines are built from the last price and the exchange data for combination instruments contains no last price.
"""
def __init__(self, start_dt: Union[date, datetime], end_dt: Union[date, datetime]) -> None:
"""
        Create the TqSdk backtest instance

        Args:
            start_dt (date/datetime): backtest start time; a date is interpreted as a trading day, a datetime as an exact point in time

            end_dt (date/datetime): backtest end time; a date is interpreted as a trading day, a datetime as an exact point in time
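
        A minimal usage sketch follows (illustrative only; it assumes the public tqsdk
        entry points TqApi, TqAuth and TqSim, and uses placeholder credentials and a
        placeholder contract symbol):

        Example::

            from datetime import date
            from tqsdk import TqApi, TqAuth, TqSim, TqBacktest

            # run the strategy between the 2021-01-04 and 2021-01-08 trading days
            api = TqApi(TqSim(), auth=TqAuth("信易账户", "账户密码"),
                        backtest=TqBacktest(start_dt=date(2021, 1, 4), end_dt=date(2021, 1, 8)))
            quote = api.get_quote("SHFE.cu2109")  # placeholder symbol
            while True:
                api.wait_update()
                if api.is_changing(quote):
                    print(quote.datetime, quote.last_price)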
"""
if isinstance(start_dt, datetime):
self._start_dt = int(start_dt.timestamp() * 1e9)
elif isinstance(start_dt, date):
self._start_dt = _get_trading_day_start_time(
int(datetime(start_dt.year, start_dt.month, start_dt.day).timestamp()) * 1000000000)
else:
raise Exception("回测起始时间(start_dt)类型 %s 错误, 请检查 start_dt 数据类型是否填写正确" % (type(start_dt)))
if isinstance(end_dt, datetime):
self._end_dt = int(end_dt.timestamp() * 1e9)
elif isinstance(end_dt, date):
self._end_dt = _get_trading_day_end_time(
int(datetime(end_dt.year, end_dt.month, end_dt.day).timestamp()) * 1000000000)
else:
raise Exception("回测结束时间(end_dt)类型 %s 错误, 请检查 end_dt 数据类型是否填写正确" % (type(end_dt)))
self._current_dt = self._start_dt
        # record the start/end time of the current trading day
self._trading_day_start = _get_trading_day_start_time(_get_trading_day_from_timestamp(self._current_dt))
self._trading_day_end = _get_trading_day_end_time(_get_trading_day_from_timestamp(self._current_dt))
async def _run(self, api, sim_send_chan, sim_recv_chan, md_send_chan, md_recv_chan):
"""回测task"""
self._api = api
        # download the historical table of main-continuous contracts
rsp = requests.get(os.getenv("TQ_CONT_TABLE_URL", "https://files.shinnytech.com/continuous_table.json"),
headers=self._api._base_headers,
timeout=30)
rsp.raise_for_status()
self._continuous_table = {f"KQ.m@{k}": pd.DataFrame(data=v, columns=['date', 'underlying'])
for k, v in rsp.json().items()}
self._logger = api._logger.getChild("TqBacktest") # 调试信息输出
self._sim_send_chan = sim_send_chan
self._sim_recv_chan = sim_recv_chan
self._md_send_chan = md_send_chan
self._md_recv_chan = md_recv_chan
self._pending_peek = False
        self._data = Entity()  # data store
self._data._instance_entity([])
self._prototype = {
"quotes": {
"#": BtQuote(self._api), # 行情的数据原型
},
"klines": {
"*": {
"*": {
"data": {
"@": Kline(self._api), # K线的数据原型
}
}
}
},
"ticks": {
"*": {
"data": {
"@": Tick(self._api), # Tick的数据原型
}
}
}
}
        self._sended_to_api = {}  # rangesets of data already sent to the api, keyed by (symbol, dur); only klines are recorded
        self._serials = {}  # all chart serials requested by the user; subscribing to a quote requests 1-minute klines by default
        # gc iterates over self._serials to work out which data the user still needs, so entries of self._serials must not be deleted
        self._generators = {}  # generator objects for each requested chart serial, created one-to-one with self._serials; a generator is removed once its serial reaches its last kline
        self._had_any_generator = False  # whether any generator object ever existed during the backtest
        self._sim_recv_chan_send_count = 0  # number of diffs sent downstream; gc runs once every 10000 sends
        self._quotes = {}  # per-instrument state: min_duration is the smallest subscribed duration; sended_init_quote marks whether the instrument's initial quote has been sent
self._diffs = []
self._is_first_send = True
md_task = self._api.create_task(self._md_handler())
try:
await self._send_snapshot()
async for pack in self._sim_send_chan:
if pack["aid"] == "ins_query":
await self._md_send_chan.send(pack)
                    # In backtest, a non-empty query must go through _ensure_query:
                    # 1. During api initialization an init request used to be sent (no longer sent since version 2.5.0), followed by a peek_message; if we did not wait for the result here, the peek_message would find no data to send and the backtest would end
                    # 2. After sending a request the api calls wait_update to refresh data; if we did not wait for the result here, market time might be advanced
                    # An empty query means a request to clear data; it can be forwarded directly without waiting for a reply
if pack["query"] != "":
await self._ensure_query(pack)
await self._send_diff()
elif pack["aid"] == "subscribe_quote":
                    # todo: in backtest, if the user subscribes to daily klines first and then to the quote, the quote returned is stamped with the daily kline datetime instead of the real current market time
self._diffs.append({
"ins_list": pack["ins_list"]
})
for ins in pack["ins_list"].split(","):
await self._ensure_quote(ins)
                    await self._send_diff()  # handle the previously unhandled peek_message
elif pack["aid"] == "set_chart":
if pack["ins_list"]:
                        # The backtest module guarantees that all data for one market timestamp is sent to the api in a single batch, so when market data is updated the charts info stays the same as at initialization (i.e. it is not modified)
                        self._diffs.append({
                            "charts": {
                                pack["chart_id"]: {
                                    # Both ids are set to 0 so that when the api checks them they are not -1, i.e. the "data fully received" check passes directly
                                    "left_id": 0,
                                    "right_id": 0,
                                    "more_data": False,  # send False to the api to indicate that all data has been sent, so the api's "data fully received" check passes
                                    "state": pack
                                }
                            }
                        })
                        await self._ensure_serial(pack["ins_list"], pack["duration"], pack["chart_id"])
else:
self._diffs.append({
"charts": {
pack["chart_id"]: None
}
})
                    await self._send_diff()  # handle the previously unhandled peek_message
elif pack["aid"] == "peek_message":
self._pending_peek = True
await self._send_diff()
finally:
            # close all generators
for s in self._generators.values():
await s.aclose()
md_task.cancel()
await asyncio.gather(md_task, return_exceptions=True)
async def _md_handler(self):
async for pack in self._md_recv_chan:
await self._md_send_chan.send({
"aid": "peek_message"
})
for d in pack.get("data", []):
_merge_diff(self._data, d, self._prototype, False)
                # forward received quotes downstream
                quotes = d.get("quotes", {})
                if quotes:
                    quotes = self._update_valid_quotes(quotes)  # strip fields that backtest quotes must not send downstream
                    self._diffs.append({"quotes": quotes})
                # received symbols should also be forwarded downstream
if d.get("symbols"):
self._diffs.append({"symbols": d["symbols"]})
def _update_valid_quotes(self, quotes):
        # Return quotes stripped down to contract-information fields only, to avoid sending future data downstream.
        # Fields the backtest module generates itself:
        invalid_keys = {f"{d}{i+1}" for d in ['ask_price', 'ask_volume', 'bid_price', 'bid_volume'] for i in range(5)}
        invalid_keys |= {'datetime', 'last_price', 'highest', 'lowest', 'average', 'volume', 'amount', 'open_interest'}
        # Fields the backtest module never generates and downstream services should not receive either:
        invalid_keys |= {'open', 'close', 'settlement', 'lowest', 'lower_limit', 'upper_limit', 'pre_open_interest', 'pre_settlement', 'pre_close', 'expired'}
for symbol, quote in quotes.items():
[quote.pop(k, None) for k in invalid_keys]
if symbol.startswith("KQ.m"):
quote.pop("underlying_symbol", None)
if quote.get('expire_datetime'):
                # The expired field is removed from every quote above and only re-added when an expire_datetime field is present
                quote['expired'] = quote.get('expire_datetime') * 1e9 <= self._trading_day_start
return quotes
async def _send_snapshot(self):
"""发送初始合约信息"""
async with TqChan(self._api, last_only=True) as update_chan: # 等待与行情服务器连接成功
self._data["_listener"].add(update_chan)
while self._data.get("mdhis_more_data", True):
await update_chan.recv()
        # when sending the initial market data (the contract information snapshot)
quotes = {}
for ins, quote in self._data["quotes"].items():
if not ins.startswith("_"):
trading_time = quote.get("trading_time", {})
quotes[ins] = {
"open": None, # 填写None: 删除api中的这个字段
"close": None,
"settlement": None,
"lower_limit": None,
"upper_limit": None,
"pre_open_interest": None,
"pre_settlement": None,
"pre_close": None,
"ins_class": quote.get("ins_class", ""),
"instrument_id": quote.get("instrument_id", ""),
"exchange_id": quote.get("exchange_id", ""),
"margin": quote.get("margin"), # 用于内部实现模拟交易, 不作为api对外可用数据(即 Quote 类中无此字段)
"commission": quote.get("commission"), # 用于内部实现模拟交易, 不作为api对外可用数据(即 Quote 类中无此字段)
"price_tick": quote["price_tick"],
"price_decs": quote["price_decs"],
"volume_multiple": quote["volume_multiple"],
"max_limit_order_volume": quote["max_limit_order_volume"],
"max_market_order_volume": quote["max_market_order_volume"],
"min_limit_order_volume": quote["min_limit_order_volume"],
"min_market_order_volume": quote["min_market_order_volume"],
"underlying_symbol": quote["underlying_symbol"],
"strike_price": quote["strike_price"],
"expired": quote.get('expire_datetime', float('nan')) <= self._trading_day_start, # expired 默认值就是 False
"trading_time": {"day": trading_time.get("day", []), "night": trading_time.get("night", [])},
"expire_datetime": quote.get("expire_datetime"),
"delivery_month": quote.get("delivery_month"),
"delivery_year": quote.get("delivery_year"),
"option_class": quote.get("option_class", ""),
"product_id": quote.get("product_id", ""),
}
        # patch in the historical main-continuous contract information
dt = datetime.fromtimestamp(self._trading_day_end / 1e9).strftime("%Y%m%d")
quotes.update(self._get_history_cont_quotes(dt))
self._diffs.append({
"quotes": quotes,
"ins_list": "",
"mdhis_more_data": False,
"_tqsdk_backtest": {
"start_dt": self._start_dt,
"current_dt": self._current_dt,
"end_dt": self._end_dt
}
})
async def _send_diff(self):
"""发送数据到 api, 如果 self._diffs 不为空则发送 self._diffs, 不推进行情时间, 否则将时间推进一格, 并发送对应的行情"""
if self._pending_peek:
if not self._diffs:
quotes = await self._generator_diffs(False)
else:
quotes = await self._generator_diffs(True)
for ins, diff in quotes.items():
self._quotes[ins]["sended_init_quote"] = True
for d in diff:
self._diffs.append({
"quotes": {
ins: d
}
})
if self._diffs:
            # add the backtest field (start/end/current time) to the data being sent, marking this market data advance as driven by backtest
if self._is_first_send:
self._diffs.append({
"_tqsdk_backtest": {
"start_dt": self._start_dt,
"current_dt": self._current_dt,
"end_dt": self._end_dt
}
})
self._is_first_send = False
else:
self._diffs.append({
"_tqsdk_backtest": {
"current_dt": self._current_dt
}
})
            # On a trading-day switch, add the historical main-continuous contract information to diffs
            if self._current_dt > self._trading_day_end:
                # Use the trading-day end time so the computation is done only once per trading-day switch,
                # which is cheaper than recomputing the current trading day from _current_dt before every diff send
self._trading_day_start = _get_trading_day_start_time(_get_trading_day_from_timestamp(self._current_dt))
self._trading_day_end = _get_trading_day_end_time(_get_trading_day_from_timestamp(self._current_dt))
self._diffs.append({
"quotes": self._get_history_cont_quotes(
datetime.fromtimestamp(self._trading_day_end / 1e9).strftime("%Y%m%d")
)
})
self._diffs.append({
"quotes": {k: {'expired': v.get('expire_datetime', float('nan')) <= self._trading_day_start}
for k, v in self._data.get('quotes').items()}
})
self._sim_recv_chan_send_count += 1
if self._sim_recv_chan_send_count > 10000:
self._sim_recv_chan_send_count = 0
self._diffs.append(self._gc_data())
rtn_data = {
"aid": "rtn_data",
"data": self._diffs,
}
self._diffs = []
self._pending_peek = False
await self._sim_recv_chan.send(rtn_data)
async def _generator_diffs(self, keep_current):
"""
        keep_current=True means market time must not be advanced; keep_current=False means it may be advanced.
        That is, market time is only advanced when self._diffs is empty and keep_current is False.
"""
quotes = {}
while self._generators:
            # self._generators stores the generators; self._serials keeps their auxiliary information
            min_request_key = min(self._generators.keys(), key=lambda serial: self._serials[serial]["timestamp"])
            timestamp = self._serials[min_request_key]["timestamp"]  # the earliest market timestamp among all subscribed data
            quotes_diff = self._serials[min_request_key]["quotes"]
            if timestamp < self._current_dt and self._quotes.get(min_request_key[0], {}).get("sended_init_quote"):
                # If instrument A is subscribed first and then A's daily kline is subscribed, A's market time would move backwards, e.g. 2021-01-04 09:31:59.999999 -> 2021-01-01 18:00:00.000000.
                # If the current timestamp is earlier than _current_dt, this quote_diff must not be sent downstream.
                # However, if instrument A (with a night session) is subscribed first and time stops at the night-session open, and then instrument B (no night session) is subscribed,
                # B's quote (stamped with the previous day's close time) should still be sent; otherwise get_quote(B) would wait for a quote and push time straight to the next day session.
                quotes_diff = None
            # Advance time: each call advances at most one market timestamp (possibly zero while backfilling data) and guarantees that all data <= that timestamp has been sent.
            # If the timestamp is later than the current backtest time, check whether diffs already holds data; otherwise this timestamp's data is not yet fully collected in diffs, so keep appending.
            if timestamp > self._current_dt:
                if self._diffs or keep_current:  # if diffs already holds data, leave the loop and send it to the downstream api
                    break
                else:
                    self._current_dt = timestamp  # otherwise advance the backtest time to the latest market timestamp
diff = self._serials[min_request_key]["diff"]
self._diffs.append(diff)
            # for kline requests, record which data has already been sent to the api
for symbol in diff.get("klines", {}):
for dur in diff["klines"][symbol]:
for kid in diff["klines"][symbol][dur]["data"]:
rs = self._sended_to_api.setdefault((symbol, int(dur)), [])
kid = int(kid)
self._sended_to_api[(symbol, int(dur))] = _rangeset_range_union(rs, (kid, kid + 1))
quote_info = self._quotes[min_request_key[0]]
if quotes_diff and (quote_info["min_duration"] != 0 or min_request_key[1] == 0):
quotes[min_request_key[0]] = quotes_diff
await self._fetch_serial(min_request_key)
        if self._had_any_generator and not self._generators and not self._diffs:  # raise BacktestFinished when there is no more data to send, i.e. nothing was ever subscribed or the last record of every subscribed serial has been delivered
self._api._print("回测结束")
self._logger.debug("backtest finished")
if self._current_dt < self._end_dt:
                self._current_dt = 2145888000000000000  # a timestamp far beyond end_dt (20380101)
await self._sim_recv_chan.send({
"aid": "rtn_data",
"data": [{
"_tqsdk_backtest": {
"start_dt": self._start_dt,
"current_dt": self._current_dt,
"end_dt": self._end_dt
}
}]
})
await self._api._wait_until_idle()
raise BacktestFinished(self._api) from None
return quotes
def _get_history_cont_quotes(self, dt):
quotes = {}
for symbol, df in self._continuous_table.items():
le = df['date'].le(dt)
if le.any():
quotes[symbol] = {"underlying_symbol": df[le].iloc[-1]['underlying']}
return quotes
async def _ensure_serial(self, ins, dur, chart_id=None):
if (ins, dur) not in self._serials:
            quote = self._quotes.setdefault(ins, {  # set min_duration here: a quote is generated automatically whenever klines are generated, so record the smallest duration for this instrument
                "min_duration": dur
            })
            quote["min_duration"] = min(quote["min_duration"], dur)
            self._serials[(ins, dur)] = {
                "chart_id_set": {chart_id} if chart_id else set()  # the chart_ids associated with this serial
}
self._generators[(ins, dur)] = self._gen_serial(ins, dur)
self._had_any_generator = True
await self._fetch_serial((ins, dur))
elif chart_id:
self._serials[(ins, dur)]["chart_id_set"].add(chart_id)
async def _ensure_query(self, pack):
"""一定收到了对应 query 返回的包"""
query_pack = {"query": pack["query"]}
if query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
return
async with TqChan(self._api, last_only=True) as update_chan:
self._data["_listener"].add(update_chan)
while not query_pack.items() <= self._data.get("symbols", {}).get(pack["query_id"], {}).items():
await update_chan.recv()
async def _ensure_quote(self, ins):
        # With the new contract info server, contract information is queried while the program is running, so the contract can no longer be assumed to exist here; a default quote value must be added
quote = _get_obj(self._data, ["quotes", ins], BtQuote(self._api))
if math.isnan(quote.get("price_tick")):
query_pack = _query_for_quote(ins)
await self._md_send_chan.send(query_pack)
async with TqChan(self._api, last_only=True) as update_chan:
quote["_listener"].add(update_chan)
while math.isnan(quote.get("price_tick")):
await update_chan.recv()
if ins not in self._quotes or self._quotes[ins]["min_duration"] > 60000000000:
await self._ensure_serial(ins, 60000000000)
async def _fetch_serial(self, key):
s = self._serials[key]
try:
s["timestamp"], s["diff"], s["quotes"] = await self._generators[key].__anext__()
except StopAsyncIteration:
            del self._generators[key]  # remove a generator whose market time has passed the end time
async def _gen_serial(self, ins, dur):
"""k线/tick 序列的 async generator, yield 出来的行情数据带有时间戳, 因此 _send_diff 可以据此归并"""
# 先定位左端点, focus_datetime 是 lower_bound ,这里需要的是 upper_bound
# 因此将 view_width 和 focus_position 设置成一样,这样 focus_datetime 所对应的 k线刚好位于屏幕外
# 使用两个长度为 8964 的 chart,去缓存/回收下游需要的数据
chart_id_a = _generate_uuid("PYSDK_backtest")
chart_id_b = _generate_uuid("PYSDK_backtest")
chart_info = {
"aid": "set_chart",
"chart_id": chart_id_a,
"ins_list": ins,
"duration": dur,
"view_width": 8964, # 设为8964原因:可满足用户所有的订阅长度,并在backtest中将所有的 相同合约及周期 的K线用同一个serial存储
"focus_datetime": int(self._current_dt),
"focus_position": 8964,
}
chart_a = _get_obj(self._data, ["charts", chart_id_a])
chart_b = _get_obj(self._data, ["charts", chart_id_b])
symbol_list = ins.split(',')
        current_id = None  # pointer to the current record
if dur == 0:
serials = [_get_obj(self._data, ["ticks", symbol_list[0]])]
else:
serials = [_get_obj(self._data, ["klines", s, str(dur)]) for s in symbol_list]
async with TqChan(self._api, last_only=True) as update_chan:
for serial in serials:
serial["_listener"].add(update_chan)
chart_a["_listener"].add(update_chan)
chart_b["_listener"].add(update_chan)
await self._md_send_chan.send(chart_info.copy())
try:
async for _ in update_chan:
chart = _get_obj(self._data, ["charts", chart_info["chart_id"]])
if not (chart_info.items() <= _get_obj(chart, ["state"]).items()):
                        # the response to the current request is incomplete, do not continue processing
                        continue
                    left_id = chart.get("left_id", -1)
                    right_id = chart.get("right_id", -1)
                    if (left_id == -1 and right_id == -1) or chart.get("more_data", True):
                        continue  # the positioning info has not arrived yet / the data is not fully received
                    last_id = serials[0].get("last_id", -1)
                    if last_id == -1:
                        continue  # the data serial has not arrived yet
if self._data.get("mdhis_more_data", True):
self._data["_listener"].add(update_chan)
continue
else:
self._data["_listener"].discard(update_chan)
if current_id is None:
current_id = max(left_id, 0)
                    # request the next chart window of 8964 klines
chart_info["chart_id"] = chart_id_b if chart_info["chart_id"] == chart_id_a else chart_id_a
chart_info["left_kline_id"] = right_id
chart_info.pop("focus_datetime", None)
chart_info.pop("focus_position", None)
await self._md_send_chan.send(chart_info.copy())
while True:
if current_id > last_id:
                            # the current id has passed last_id
                            return
                        # once every record in the subscribed 8964-wide window has been processed, leave this loop and re-enter to handle the next window
if current_id > right_id:
break
item = {k: v for k, v in serials[0]["data"].get(str(current_id), {}).items()}
if dur == 0:
diff = {
"ticks": {
ins: {
"last_id": current_id,
"data": {
str(current_id): item,
str(current_id - 8964): None,
}
}
}
}
if item["datetime"] > self._end_dt: # 超过结束时间
return
yield item["datetime"], diff, self._get_quotes_from_tick(item)
else:
timestamp = item["datetime"] if dur < 86400000000000 else _get_trading_day_start_time(
item["datetime"])
                            if timestamp > self._end_dt:  # past the end time
return
binding = serials[0].get("binding", {})
diff = {
"klines": {
symbol_list[0]: {
str(dur): {
"last_id": current_id,
"data": {
str(current_id): {
"datetime": item["datetime"],
"open": item["open"],
"high": item["open"],
"low": item["open"],
"close": item["open"],
"volume": 0,
"open_oi": item["open_oi"],
"close_oi": item["open_oi"],
}
}
}
}
}
}
for chart_id in self._serials[(ins, dur)]["chart_id_set"]:
diff["charts"] = {
chart_id: {
"right_id": current_id # api 中处理多合约 kline 需要 right_id 信息
}
}
for i, symbol in enumerate(symbol_list):
if i == 0:
diff_binding = diff["klines"][symbol_list[0]][str(dur)].setdefault("binding", {})
continue
other_id = binding.get(symbol, {}).get(str(current_id), -1)
if other_id >= 0:
diff_binding[symbol] = {str(current_id): str(other_id)}
other_item = serials[i]["data"].get(str(other_id), {})
diff["klines"][symbol] = {
str(dur): {
"last_id": other_id,
"data": {
str(other_id): {
"datetime": other_item["datetime"],
"open": other_item["open"],
"high": other_item["open"],
"low": other_item["open"],
"close": other_item["open"],
"volume": 0,
"open_oi": other_item["open_oi"],
"close_oi": other_item["open_oi"],
}
}
}
}
yield timestamp, diff, self._get_quotes_from_kline_open(
self._data["quotes"][symbol_list[0]],
timestamp,
                                item)  # when a kline is first created, all its prices equal the open price
timestamp = item["datetime"] + dur - 1000 \
if dur < 86400000000000 else _get_trading_day_start_time(item["datetime"] + dur) - 1000
                            if timestamp > self._end_dt:  # past the end time
return
diff = {
"klines": {
symbol_list[0]: {
str(dur): {
"data": {
str(current_id): item,
}
}
}
}
}
for i, symbol in enumerate(symbol_list):
if i == 0:
continue
other_id = binding.get(symbol, {}).get(str(current_id), -1)
if other_id >= 0:
diff["klines"][symbol] = {
str(dur): {
"data": {
str(other_id): {k: v for k, v in
serials[i]["data"].get(str(other_id), {}).items()}
}
}
}
                            yield timestamp, diff, self._get_quotes_from_kline(
                                self._data["quotes"][symbol_list[0]],
                                timestamp,
                                item)  # generate the quote data when the kline closes
current_id += 1
finally:
                # release the chart resources
chart_info["ins_list"] = ""
await self._md_send_chan.send(chart_info.copy())
chart_info["chart_id"] = chart_id_b if chart_info["chart_id"] == chart_id_a else chart_id_a
await self._md_send_chan.send(chart_info.copy())
def _gc_data(self):
        # build the diff of data that the api should delete
need_rangeset = {}
for ins, dur in self._serials:
            if dur == 0:  # tick memory is already reclaimed while the data is being sent
continue
symbol_list = ins.split(',')
for s in symbol_list:
need_rangeset.setdefault((s, dur), [])
main_serial = _get_obj(self._data, ["klines", symbol_list[0], str(dur)])
            main_serial_rangeset = self._sended_to_api.get((symbol_list[0], dur), [])  # [] when no data for this request has been sent to the api yet
if not main_serial_rangeset:
continue
last_id = main_serial_rangeset[-1][-1] - 1
assert last_id > -1
need_rangeset[(symbol_list[0], dur)] = _rangeset_range_union(need_rangeset[(symbol_list[0], dur)],
(last_id - 8963, last_id + 1))
for symbol in symbol_list[1:]:
symbol_need_rangeset = []
symbol_binding = main_serial.get("binding", {}).get(symbol, {})
if symbol_binding:
for i in range(last_id - 8963, last_id + 1):
other_id = symbol_binding.get(str(i))
if other_id:
symbol_need_rangeset = _rangeset_range_union(symbol_need_rangeset, (other_id, other_id + 1))
if symbol_need_rangeset:
need_rangeset[(symbol, dur)] = _rangeset_union(need_rangeset[(symbol, dur)], symbol_need_rangeset)
gc_rangeset = {}
for key, rs in self._sended_to_api.items():
gc_rangeset[key] = _rangeset_difference(rs, need_rangeset.get(key, []))
        # update self._sended_to_api
for key, rs in gc_rangeset.items():
self._sended_to_api[key] = _rangeset_difference(self._sended_to_api[key], rs)
gc_klines_diff = {}
for (symbol, dur), rs in gc_rangeset.items():
gc_klines_diff.setdefault(symbol, {})
gc_klines_diff[symbol][str(dur)] = {"data": {}}
serial = _get_obj(self._data, ["klines", symbol, str(dur)])
serial_binding = serial.get("binding", None)
if serial_binding:
gc_klines_diff[symbol][str(dur)]["binding"] = {s: {} for s in serial_binding.keys()}
for start_id, end_id in rs:
for i in range(start_id, end_id):
gc_klines_diff[symbol][str(dur)]["data"][str(i)] = None
if serial_binding:
for s, s_binding in serial_binding.items():
gc_klines_diff[symbol][str(dur)]["binding"][s][str(i)] = None
return {"klines": gc_klines_diff}
@staticmethod
def _get_quotes_from_tick(tick):
quote = {k: v for k, v in tick.items()}
quote["datetime"] = datetime.fromtimestamp(tick["datetime"] / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f")
return [quote]
@staticmethod
def _get_quotes_from_kline_open(info, timestamp, kline):
return [
            {  # when a kline is first created, all its prices equal the open price
"datetime": datetime.fromtimestamp(timestamp / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f"),
"ask_price1": kline["open"] + info["price_tick"],
"ask_volume1": 1,
"bid_price1": kline["open"] - info["price_tick"],
"bid_volume1": 1,
"last_price": kline["open"],
"highest": float("nan"),
"lowest": float("nan"),
"average": float("nan"),
"volume": 0,
"amount": float("nan"),
"open_interest": kline["open_oi"],
},
]
@staticmethod
def _get_quotes_from_kline(info, timestamp, kline):
"""
        The quote is sent downstream as three packets:
        1. Under the diff protocol this makes no difference to the final result the user receives.
        2. TqSim matching processes the packets in order, checking the bid/ask derived from the high, low and close prices in turn.
        3. TqSim matching only uses the bid/ask prices, so last_price is produced only once, from close, instead of being sent three times.
"""
return [
{
"datetime": datetime.fromtimestamp(timestamp / 1e9).strftime("%Y-%m-%d %H:%M:%S.%f"),
"ask_price1": kline["high"] + info["price_tick"],
"ask_volume1": 1,
"bid_price1": kline["high"] - info["price_tick"],
"bid_volume1": 1,
"last_price": kline["close"],
"highest": float("nan"),
"lowest": float("nan"),
"average": float("nan"),
"volume": 0,
"amount": float("nan"),
"open_interest": kline["close_oi"],
},
{
"ask_price1": kline["low"] + info["price_tick"],
"bid_price1": kline["low"] - info["price_tick"],
},
{
"ask_price1": kline["close"] + info["price_tick"],
"bid_price1": kline["close"] - info["price_tick"],
}
]
class TqReplay(object):
"""天勤复盘类"""
def __init__(self, replay_dt: date):
"""
        Besides the traditional backtest mode, TqSdk offers a distinctive replay mode, which differs from backtest mode as follows:

        1. Replay mode is time-driven, backtest mode is event-driven.
           In replay mode you can pick any trading day and the backend market data server streams all historical market data of the subscribed instruments for that day, replaying the day's market; in backtest mode, data is pushed according to the instrument durations the user subscribed to.
           Therefore K-lines in replay mode update exactly as in live trading, whereas in backtest mode, even with tick data subscribed, the close and other fields of the last K-line of any duration do not follow the ticks; they are only set when the K-line is created and once more when it closes.

        2. Market data speed of replay vs. backtest.
           Because the driving mechanisms differ, backtest is faster, but replay mode provides a speed control which can be combined with web_gui.

        3. Replay currently supports a single day only.
           Because replay streams the complete historical market data of the chosen day for the subscribed instruments, it puts considerable load on the backend servers, so only a single trading day can be selected per replay.

        Args:
            replay_dt (date): the trading day to replay
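
        A minimal usage sketch follows (illustrative only; it assumes the public tqsdk
        entry points TqApi and TqAuth and uses placeholder credentials and a placeholder
        contract symbol):

        Example::

            from datetime import date
            from tqsdk import TqApi, TqAuth, TqReplay

            # replay the 2020-09-10 trading day
            api = TqApi(backtest=TqReplay(date(2020, 9, 10)),
                        auth=TqAuth("信易账户", "账户密码"))
            quote = api.get_quote("SHFE.cu2012")
            while True:
                api.wait_update()
                if api.is_changing(quote):
                    print(quote.datetime, quote.last_price)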
"""
if isinstance(replay_dt, date):
self._replay_dt = replay_dt
else:
raise Exception("复盘时间(dt)类型 %s 错误, 请检查 dt 数据类型是否填写正确" % (type(replay_dt)))
if self._replay_dt.weekday() >= 5:
            # weekday() is 0~6; raise early if the date falls on a weekend [5, 6]
            raise Exception("Unable to create the replay server, please check the replay date and try again.")
self._default_speed = 1
self._api = None
def _create_server(self, api):
self._api = api
self._logger = api._logger.getChild("TqReplay") # 调试信息输出
self._logger.debug('replay prepare', replay_dt=self._replay_dt)
session = self._prepare_session()
self._session_url = "http://%s:%d/t/rmd/replay/session/%s" % (
session["ip"], session["session_port"], session["session"])
self._ins_url = "http://%s:%d/t/rmd/replay/session/%s/symbol" % (
session["ip"], session["session_port"], session["session"])
self._md_url = "ws://%s:%d/t/rmd/front/mobile" % (session["ip"], session["gateway_web_port"])
self._server_status = None
self._server_status = self._wait_server_status("running", 60)
if self._server_status == "running":
            self._logger.debug('replay start succeeded', replay_dt=self._replay_dt)
return self._ins_url, self._md_url
else:
self._logger.debug('replay start failed', replay_dt=self._replay_dt)
raise Exception("无法创建复盘服务器,请检查复盘日期后重试。")
async def _run(self):
try:
self._send_chan = TqChan(self._api)
self._send_chan.send_nowait({"aid": "ratio", "speed": self._default_speed})
_senddata_task = self._api.create_task(self._senddata_handler())
while True:
await self._send_chan.send({"aid": "heartbeat"})
await asyncio.sleep(30)
finally:
await self._send_chan.close()
_senddata_task.cancel()
await asyncio.gather(_senddata_task, return_exceptions=True)
def _prepare_session(self):
create_session_url = "http://replay.api.shinnytech.com/t/rmd/replay/create_session"
response = requests.post(create_session_url,
headers=self._api._base_headers,
data=json.dumps({'dt': self._replay_dt.strftime("%Y%m%d")}),
timeout=5)
if response.status_code == 200:
return json.loads(response.content)
else:
raise Exception("创建复盘服务器失败,请检查复盘日期后重试。")
def _wait_server_status(self, target_status, timeout):
"""等服务器状态为 target_status,超时时间 timeout 秒"""
deadline = time.time() + timeout
server_status = self._get_server_status()
while deadline > time.time():
if target_status == server_status:
break
else:
time.sleep(1)
server_status = self._get_server_status()
return server_status
def _get_server_status(self):
try:
response = requests.get(self._session_url,
headers=self._api._base_headers,
timeout=5)
if response.status_code == 200:
return json.loads(response.content)["status"]
else:
raise Exception("无法创建复盘服务器,请检查复盘日期后重试。")
except requests.exceptions.ConnectionError as e:
            # _session_url may not be reachable yet right after startup
return None
async def _senddata_handler(self):
try:
session = aiohttp.ClientSession(headers=self._api._base_headers)
async for data in self._send_chan:
await session.post(self._session_url, data=json.dumps(data))
finally:
await session.post(self._session_url, data=json.dumps({"aid": "terminate"}))
await session.close()
def set_replay_speed(self, speed: float = 10.0) -> None:
"""
        Adjust the speed at which the replay server advances market data

        Args:
            speed (float): market data advance speed of the replay server, default 10.0
Example::
from datetime import date
from tqsdk import TqApi, TqAuth, TqReplay
replay = TqReplay(date(2020, 9, 10))
            api = TqApi(backtest=replay, auth=TqAuth("信易账户", "账户密码"))
replay.set_replay_speed(3.0)
quote = api.get_quote("SHFE.cu2012")
while True:
api.wait_update()
if api.is_changing(quote):
print("最新价", quote.datetime, quote.last_price)
"""
if self._api:
self._send_chan.send_nowait({"aid": "ratio", "speed": speed})
else:
            # _api is not initialized yet; only record the speed set by the user and send the request once the replay server has finished starting
self._default_speed = speed