Skip to content

Commit 9c08824

Browse files
committed
fix: price round precision issue with decimal
1 parent fff5f5c commit 9c08824

File tree

4 files changed

+34
-11
lines changed

4 files changed

+34
-11
lines changed

sjtrade/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
trading with shioaji
33
"""
44

5-
__version__ = "0.4.4"
5+
__version__ = "0.4.5"
66

77
def inject_env():
88
import os

sjtrade/strategy.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from loguru import logger
44
from shioaji.constant import TFTStockPriceType
55

6-
from .io.file import read_csv_position, read_position
7-
from .utils import price_move, price_round, quantity_num_split
6+
from .io.file import read_position
7+
from .utils import price_round, price_limit
88
from .position import Position, PriceSet
99
from .data import Snapshot
1010

@@ -31,7 +31,7 @@ def cover_price_set_onclose(self, position: Position):
3131
price=position.contract.limit_down
3232
if position.status.open_quantity > 0
3333
else position.contract.limit_up,
34-
quantity=position.status.open_quantity*-1,
34+
quantity=position.status.open_quantity * -1,
3535
price_type=TFTStockPriceType.LMT,
3636
)
3737
]
@@ -70,33 +70,45 @@ def entry_positions(self):
7070
stop_loss_price = contract.reference * (
7171
1 + (-1 if pos > 0 else 1) * (self.stop_loss_pct)
7272
)
73+
stop_loss_price = price_round(stop_loss_price, pos > 0)
74+
stop_loss_price = price_limit(
75+
stop_loss_price, contract.limit_up, contract.limit_down
76+
)
7377
stop_profit_price = contract.reference * (
7478
1 + (1 if pos > 0 else -1) * (self.stop_profit_pct)
7579
)
80+
stop_profit_price = price_round(stop_profit_price, pos < 0)
81+
stop_profit_price = price_limit(
82+
stop_profit_price, contract.limit_up, contract.limit_down
83+
)
7684
entry_price = contract.reference * (
7785
1 + (-1 if pos > 0 else 1) * self.entry_pct
7886
)
87+
entry_price = price_round(entry_price, pos > 0)
88+
entry_price = price_limit(
89+
entry_price, contract.limit_up, contract.limit_down
90+
)
7991
entry_args.append(
8092
{
8193
"code": code,
8294
"pos": pos,
8395
"entry_price": [
8496
PriceSet(
85-
price=price_round(entry_price, pos > 0),
97+
price=entry_price,
8698
quantity=pos,
8799
price_type=TFTStockPriceType.LMT,
88100
)
89101
],
90102
"stop_profit_price": [
91103
PriceSet(
92-
price=price_round(stop_profit_price, pos < 0),
104+
price=stop_profit_price,
93105
quantity=pos,
94106
price_type=TFTStockPriceType.MKT,
95107
)
96108
],
97109
"stop_loss_price": [
98110
PriceSet(
99-
price=price_round(stop_loss_price, pos > 0),
111+
price=stop_loss_price,
100112
quantity=pos,
101113
price_type=TFTStockPriceType.MKT,
102114
)
@@ -112,4 +124,3 @@ def cover_positions(
112124
self, positions: Dict[str, Position], snapshots: Dict[str, Snapshot] = dict()
113125
):
114126
return self.cover_positions_onclose()
115-

sjtrade/utils.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,27 @@ def price_floor(price: float) -> float:
2020
return round(math.floor(price * n - ((price * n % 5) * (1 - quinary))) / n, 2)
2121

2222

23-
def price_round(price: float, up: bool = False):
23+
def price_round(price: float, up: bool = False) -> float:
2424
roudnf = math.ceil if up else math.floor
25+
price = Decimal(f"{price}")
2526
logp = math.floor(math.log10(price))
2627
quinary = ((price / 10**logp) // 5) if logp >= 1 else 1
2728
n = min(10 ** (3 - logp - quinary), 100)
28-
return round(
29-
roudnf(price * n + ((5 * int(up) - (price * n % 5)) * (1 - quinary))) / n, 2
29+
return float(
30+
round(
31+
roudnf(price * n + ((5 * int(up) - (price * n % 5)) * (1 - quinary))) / n, 2
32+
)
3033
)
3134

3235

36+
def price_limit(price: float, up: float, down: float):
37+
if price > up:
38+
return up
39+
elif price < down:
40+
return down
41+
return price
42+
43+
3344
def price_move(price: float, tick: int) -> float:
3445
return rs2py.get_price_tick_move(price, tick)
3546

tests/test_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_price_floor(input: float, expected: float):
6666
(50.01, True, 50.1),
6767
(100.01, True, 100.5),
6868
(500.01, True, 501),
69+
(34.05, False, 34.05),
6970
# TODO case 10% over limit up
7071
],
7172
)

0 commit comments

Comments
 (0)