Skip to content

Commit 651f0ee

Browse files
committed
feat: advance cover price set
fix: onclose cover issue refactor: abstract strategy
1 parent 8fc5efa commit 651f0ee

11 files changed

+147
-86
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -132,4 +132,7 @@ dmypy.json
132132

133133
# misc
134134
login.json
135-
position.txt
135+
position.txt
136+
137+
# sjtrade stratages
138+
sjtrade/stratages

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ dynamic = ["version", "description"]
1010
dependencies = [
1111
"shioaji>=0.3.4.dev7",
1212
"loguru",
13+
"rs2py",
1314
]
1415
requires-python = ">=3.6"
1516
classifiers = [

sjtrade/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
trading with shioaji
33
"""
44

5-
__version__ = "0.3.1"
5+
__version__ = "0.4.0"
66

77
def inject_env():
88
import os

sjtrade/data.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
5+
class Snapshot:
6+
price: float
7+
bid: float = 0
8+
ask: float = 0

sjtrade/position.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import shioaji as sj
2-
3-
from typing import Callable, Dict, List, Optional, Union
2+
from typing import List
43
from threading import Lock
54
from dataclasses import dataclass, field
65
from shioaji.constant import (
@@ -23,7 +22,7 @@ class PositionCond:
2322
entry_price: List[PriceSet]
2423
stop_loss_price: List[PriceSet]
2524
stop_profit_price: List[PriceSet]
26-
cover_price: List[PriceSet]
25+
cover_price: List[PriceSet] = field(default_factory=list)
2726

2827

2928
@dataclass

sjtrade/simulation_shioaji.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99
import shioaji as sj
1010
from shioaji.constant import OrderState, Exchange, Action, TFTStockPriceType
1111

12-
13-
@dataclass
14-
class Snapshot:
15-
price: float
16-
bid: float = 0
17-
ask: float = 0
12+
from .data import Snapshot
1813

1914

2015
class SimulationShioaji:

sjtrade/stratage.py

+34-64
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,46 @@
11
import shioaji as sj
2-
from typing import Dict
2+
from typing import Dict, Optional
33
from loguru import logger
4-
from .io.file import read_position
5-
from .utils import price_round
6-
from .position import Position, PriceSet
74
from shioaji.constant import TFTStockPriceType
85

6+
from .io.file import read_csv_position, read_position
7+
from .utils import price_move, price_round, quantity_num_split
8+
from .position import Position, PriceSet
9+
from .data import Snapshot
10+
911

1012
class StratageBase:
1113
name: str
1214

1315
def entry_positions(self):
1416
raise NotImplementedError()
1517

16-
def cover_price_set(self, position: Position):
18+
def cover_price_set(self, position: Position, snapshot: Optional[Snapshot] = None):
1719
raise NotImplementedError()
1820

19-
def cover_positions(self, positions: Dict[str, Position]):
21+
def cover_positions(
22+
self, positions: Dict[str, Position], snapshots: Dict[str, Snapshot] = dict()
23+
):
2024
raise NotImplementedError()
2125

26+
def cover_price_set_onclose(self, position: Position):
27+
if position.status.open_quantity == 0:
28+
return []
29+
return [
30+
PriceSet(
31+
price=position.contract.limit_up
32+
if position.status.open_quantity > 0
33+
else position.contract.limit_down,
34+
quantity=position.cond.quantity * -1,
35+
price_type=TFTStockPriceType.LMT,
36+
)
37+
]
38+
39+
def cover_positions_onclose(self, positions: Dict[str, Position]):
40+
for code, position in positions.items():
41+
position.cond.cover_price = self.cover_price_set_onclose(position)
42+
return positions
43+
2244

2345
class StratageBasic(StratageBase):
2446
def __init__(
@@ -83,63 +105,11 @@ def entry_positions(self):
83105
)
84106
return entry_args
85107

86-
def cover_price_set(self, position: Position):
87-
return [
88-
PriceSet(
89-
price=43.3,
90-
quantity=position.cond.quantity * -1,
91-
price_type=TFTStockPriceType.LMT,
92-
)
93-
]
94-
95-
def cover_positions(self, positions: Dict[str, Position]):
96-
for code, position in positions.items():
97-
if not position.cond.cover_price:
98-
position.cond.cover_price = self.cover_price_set(position)
99-
return positions
108+
def cover_price_set(self, position: Position, snapshot: Optional[Snapshot] = None):
109+
return self.cover_price_set_onclose(position)
100110

111+
def cover_positions(
112+
self, positions: Dict[str, Position], snapshots: Dict[str, Snapshot] = dict()
113+
):
114+
return self.cover_positions_onclose()
101115

102-
class StratageAdvance(StratageBase):
103-
def __init__(
104-
self,
105-
entry_pct: float = -0.1,
106-
stop_loss_pct: float = 0.09,
107-
stop_profit_pct: float = 0.09,
108-
position_filepath: str = "position.txt",
109-
contracts: sj.contracts.Contracts = sj.contracts.Contracts(),
110-
) -> None:
111-
self.position_filepath = position_filepath
112-
self.entry_pct = entry_pct
113-
self.stop_loss_pct = stop_loss_pct
114-
self.stop_profit_pct = stop_profit_pct
115-
self.contracts = contracts
116-
self.name = "dt1"
117-
self.read_position_func = read_position
118-
119-
def entry_positions(self):
120-
positions = self.read_position_func(self.position_filepath)
121-
entry_args = []
122-
for code, pos in positions.items():
123-
contract = self.contracts.Stocks[code]
124-
if not contract:
125-
logger.warning(f"Code: {code} not exist in TW Stock.")
126-
continue
127-
stop_loss_price = contract.reference * (
128-
1 + (-1 if pos > 0 else 1) * (self.stop_loss_pct)
129-
)
130-
stop_profit_price = contract.reference * (
131-
1 + (1 if pos > 0 else -1) * (self.stop_profit_pct)
132-
)
133-
entry_price = contract.reference * (
134-
1 + (-1 if pos > 0 else 1) * self.entry_pct
135-
)
136-
entry_args.append(
137-
{
138-
"code": code,
139-
"pos": pos,
140-
"entry_price": {price_round(entry_price, pos > 0): pos},
141-
"stop_profit_price": {price_round(stop_profit_price, pos < 0): pos},
142-
"stop_loss_price": {price_round(stop_loss_price, pos > 0): pos},
143-
}
144-
)
145-
return entry_args

sjtrade/trader.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import shioaji as sj
66

77
from .utils import quantity_split, sleep_until
8+
from .data import Snapshot
89
from .simulation_shioaji import SimulationShioaji
910
from .stratage import StratageBasic
1011
from .position import Position, PositionCond, PriceSet
@@ -24,6 +25,7 @@ class SJTrader:
2425
def __init__(self, api: sj.Shioaji, simulation: bool = False):
2526
self.api = api
2627
self.positions: Dict[str, Position] = {}
28+
self.snapshots: Dict[str, Snapshot] = {}
2729
self._stop_loss_pct = 0.09
2830
self._stop_profit_pct = 0.09
2931
self._entry_pct = 0.05
@@ -47,6 +49,7 @@ def start(
4749
intraday_handler_time: datetime.time = datetime.time(8, 59, 55),
4850
cover_time: datetime.time = datetime.time(13, 25, 59),
4951
):
52+
self.set_on_tick_handler(self.update_snapshot)
5053
entry_future = self.executor_on_time(entry_time, self.place_entry_positions)
5154
self.executor_on_time(
5255
cancel_preorder_time,
@@ -137,6 +140,7 @@ def place_entry_order(
137140
cover_price=[],
138141
),
139142
)
143+
self.snapshots[code] = Snapshot(price=0.0)
140144
self.api.quote.subscribe(contract, version=QuoteVersion.v1)
141145
for price_set in self.positions[code].cond.entry_price:
142146
price, price_quantity = price_set.price, price_set.quantity
@@ -168,6 +172,9 @@ def place_entry_positions(self) -> Dict[str, Position]:
168172
api.update_status()
169173
return self.positions
170174

175+
def update_snapshot(self, exchange: Exchange, tick: sj.TickSTKv1):
176+
self.snapshots[tick.code].price = tick.close
177+
171178
def cancel_preorder_handler(self, exchange: Exchange, tick: sj.TickSTKv1):
172179
position = self.positions[tick.code]
173180
if self.simulation:
@@ -229,6 +236,7 @@ def intraday_handler(self, exchange: Exchange, tick: sj.TickSTKv1):
229236
self.simulation_api.quote_callback(exchange, tick)
230237
position = self.positions[tick.code]
231238
self.re_entry_order(position, tick)
239+
self.update_snapshot(exchange, tick)
232240
# 9:00 -> 13:24:49 stop loss stop profit
233241
self.stop_loss(position, tick)
234242
self.stop_profit(position, tick)
@@ -243,7 +251,7 @@ def stop_profit(self, position: Position, tick: sj.TickSTKv1):
243251
cross = "under"
244252
for price_set in position.cond.stop_profit_price:
245253
if op(tick.close, price_set.price):
246-
self.place_cover_order(position, price_set)
254+
self.place_cover_order(position, [price_set])
247255
logger.info(
248256
f"{position.contract.code} | price: {tick.close} cross {cross} {price_set.price} "
249257
f"cover quantity: {price_set.quantity}"
@@ -259,14 +267,14 @@ def stop_loss(self, position: Position, tick: sj.TickSTKv1):
259267
cross = "over"
260268
for price_set in position.cond.stop_profit_price:
261269
if op(tick.close, price_set.price):
262-
self.place_cover_order(position, price_set)
270+
self.place_cover_order(position, [price_set])
263271
logger.info(
264272
f"{position.contract.code} | price: {tick.close} cross {cross} {price_set.price} "
265273
f"cover quantity: {price_set.quantity}"
266274
)
267275

268276
def place_cover_order(
269-
self, position: Position, price_set: Optional[PriceSet] = None
277+
self, position: Position, price_sets: List[PriceSet] = []
270278
): # TODO with price quantity
271279
if self.simulation:
272280
api = self.simulation_api
@@ -275,16 +283,19 @@ def place_cover_order(
275283
cover_quantity = (
276284
position.status.open_quantity + position.status.cover_order_quantity
277285
)
278-
if price_set is None:
279-
price_set = self.stratage.cover_price_set(position)[0]
280-
skip = (cover_quantity == 0)
281-
if cover_quantity < 0:
286+
if not price_sets:
287+
price_sets = self.stratage.cover_price_set(
288+
position, self.snapshots[position.contract.code]
289+
)
290+
if cover_quantity == 0:
291+
return
292+
elif cover_quantity < 0:
282293
action = Action.Buy
283294
op = max
284-
elif cover_quantity > 0:
295+
else:
285296
action = Action.Sell
286297
op = min
287-
if not skip:
298+
for price_set in price_sets:
288299
cover_quantity_set = op(cover_quantity, price_set.quantity)
289300
if cover_quantity_set:
290301
quantity_s = quantity_split(cover_quantity_set, threshold=499)
@@ -305,7 +316,7 @@ def place_cover_order(
305316
position.cover_trades.append(trade)
306317
api.update_status(trade=trade)
307318

308-
def open_position_cover(self):
319+
def open_position_cover(self, onclose: bool = True):
309320
if self.simulation:
310321
api = self.simulation_api
311322
else:
@@ -334,7 +345,12 @@ def open_position_cover(self):
334345
]:
335346
api.cancel_order(trade, timeout=0)
336347
# event wait cancel
337-
self.positions = self.stratage.cover_positions(self.positions)
348+
if onclose:
349+
self.positions = self.stratage.cover_positions_onclose(self.positions)
350+
else:
351+
self.positions = self.stratage.cover_positions(
352+
self.positions, self.snapshots
353+
)
338354
for code, position in self.positions.items():
339355
if position.status.open_quantity:
340356
self.place_cover_order(position, position.cond.cover_price)

sjtrade/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import math
22
import time
3+
import rs2py
34
import datetime
5+
from decimal import Decimal
46
from typing import List, Union
57

68

@@ -28,6 +30,24 @@ def price_round(price: float, up: bool = False):
2830
)
2931

3032

33+
def price_move(price: float, tick: int) -> float:
34+
return rs2py.get_price_tick_move(price, tick)
35+
36+
37+
def quantity_num_split(quantity: int, num: int) -> List[int]:
38+
neg = 1 if quantity > 0 else -1
39+
quantity_remain = abs(quantity)
40+
quantity_split_res = [(quantity_remain // num) * neg] * num
41+
quantity_remain -= abs(sum(quantity_split_res))
42+
for idx, _ in enumerate(quantity_split_res):
43+
qadd = math.ceil(quantity_remain / (num - idx))
44+
quantity_remain -= qadd
45+
quantity_split_res[idx] += qadd * neg
46+
if not quantity_remain:
47+
break
48+
return quantity_split_res
49+
50+
3151
def quantity_split(quantity: float, threshold: int) -> List[int]:
3252
return [threshold if quantity > 0 else -threshold] * (
3353
abs(quantity) // threshold

tests/test_stratage.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
3+
from sjtrade.stratage import StratageBase
4+
5+
6+
def test_stratage_base():
7+
stratage = StratageBase()
8+
with pytest.raises(NotImplementedError):
9+
stratage.entry_positions()
10+
11+
with pytest.raises(NotImplementedError):
12+
stratage.cover_price_set(None)
13+
14+
with pytest.raises(NotImplementedError):
15+
stratage.cover_positions(None)
16+
17+
18+

0 commit comments

Comments
 (0)