Skip to content

Commit

Permalink
chore: renamed renderers to plotters, updated tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
chrootlogin committed Aug 25, 2024
1 parent eb9128c commit 6a21e66
Show file tree
Hide file tree
Showing 17 changed files with 319 additions and 420 deletions.
2 changes: 1 addition & 1 deletion docs/source/examples/renderers_and_plotly_chart.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ agent.train(n_episodes=2, n_steps=200, render_interval=10)
Create PlotlyTradingChart and FileLogger renderers. Configuring renderers is optional as they can be used with their default settings.

```python
from tensortrade.env.renderers.abstract import PlotlyTradingChart, FileLogger
from tensortrade.env.plotters.abstract import PlotlyTradingChart, FileLogger

chart_renderer = PlotlyTradingChart(
display=True, # show the chart on screen (default)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/tutorial.ipynb
Git LFS file not shown
157 changes: 157 additions & 0 deletions examples/simple_training_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""
This creates a simple training environment.
"""

import pandas as pd
import numpy as np
import ta.trend
import ta.momentum

from stable_baselines3 import PPO

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

from tensortrade.env.stoppers import MaxLossStopper
from tensortrade.feed import Stream
from tensortrade.oms.instruments import Instrument
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.wallets import Portfolio, Wallet

from tensortrade.env import TradingEnv
from tensortrade.env.observers import SimpleObserver
from tensortrade.env.actions import BSH
from tensortrade.env.rewards import SimpleProfit
from tensortrade.env.plotters import PlotlyTradingChart
from tensortrade.feed import DataFeed


"""
Loading data
"""

df = pd.read_csv('data/BTC_USDT_5m_20240601-20240731.csv').set_index('timestamp')

"""
Create TA features
"""

# Simple Moving Averages (SMA)
df['SMA_20'] = df['close'].rolling(window=20).mean()
df['SMA_50'] = df['close'].rolling(window=50).mean()

# Exponential Moving Averages (EMA)
df['EMA_10'] = df['close'].ewm(span=10, adjust=False).mean()
df['EMA_20'] = df['close'].ewm(span=20, adjust=False).mean()

# Relative Strength Index (RSI)
df['RSI_14'] = ta.momentum.RSIIndicator(df['close'], window=14).rsi()

# Moving Average Convergence Divergence (MACD)
macd = ta.trend.MACD(df['close'])
df['MACD'] = macd.macd()
df['MACD_signal'] = macd.macd_signal()
df['MACD_diff'] = macd.macd_diff()

# Price change
df['price_change'] = df['close'].pct_change()

"""
Copy OHLCV data
"""

df['raw-open'] = df['open']
df['raw-high'] = df['high']
df['raw-low'] = df['low']
df['raw-close'] = df['close']
df['raw-volume'] = df['volume']

"""
Clean up and prepare data for learning
"""

# Remove empty values
df.dropna(inplace=True)

# Split in raw data and feature data
feature_columns = [feature for feature in list(df) if not feature.startswith('raw-')]
df_features = df[feature_columns]
df_raw = df.drop(columns=feature_columns)

# Normalize feature data
scaler = StandardScaler()
df_features = scaler.fit_transform(df_features)
df_features = pd.DataFrame(df_features, columns=feature_columns, index=df.index)

# Concat dataframe again
df = pd.concat([df_features, df_raw], axis=1)

# Convert dataframe to float32
df = df.astype(np.float32)

# Last but not least we split the data frame into training and testing data
train_df, test_df = train_test_split(df, test_size=0.3, shuffle=False)

"""
Create Portfolio
"""

# Prepare trading instruments
USDT = Instrument('USDT', 2, 'US Dollar Tether')
BTC = Instrument('BTC', 6, 'Bitcoin')

# prepare exchange
prices_stream = Stream.source(df['raw-close'], dtype='float').rename('USDT/BTC')
exchange = Exchange('dummy', service=execute_order)(prices_stream)

# prepare wallets
usdt_wallet = Wallet(exchange, 1000 * USDT)
btc_wallet = Wallet(exchange, 0 * BTC)

# prepare portfolio
portfolio = Portfolio(USDT, [
usdt_wallet,
btc_wallet
])

"""
Train Agent / Build TensorTrade-NG environment
"""

raw_data = ['raw-open', 'raw-high', 'raw-low', 'raw-close', 'raw-volume']

# prepare features
features = [Stream.source(train_df[f], dtype="float").rename(f) for f in feature_columns]

# prepare schemes
action_scheme = BSH(cash=usdt_wallet, asset=btc_wallet) # BSH Action Sheme
reward_scheme = SimpleProfit() # Simple Profit Reward Scheme
observer = SimpleObserver()
stopper = MaxLossStopper(max_allowed_loss=0.5)

# prepare meta feed
meta = [Stream.source(train_df.index).rename('date')]
meta += [Stream.source(train_df[f], dtype="float").rename(f[4:]) for f in raw_data]
meta += [Stream.sensor(action_scheme, lambda s: s.action, dtype="float").rename("action")]

feed = DataFeed([
Stream.group(features).rename('features'),
Stream.group(meta).rename('meta')
])

# create the tensortrade environment
env = TradingEnv(
portfolio=portfolio,
feed=feed,
action_scheme=action_scheme,
reward_scheme=reward_scheme,
observer=observer,
stopper=stopper,
plotter=[PlotlyTradingChart(save_format='html')]
)

# Last but not least create our model and learn it
model = PPO('MlpPolicy', env, verbose=1).learn(10_000)

env.plot()
106 changes: 1 addition & 105 deletions src/tensortrade/env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,108 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from __future__ import annotations

import typing
from typing import List

from tensortrade.env.environment import TradingEnv

from tensortrade.feed import DataFeed, Stream

if typing.TYPE_CHECKING:
from typing import Optional, Union

from tensortrade.env.interfaces import (
AbstractActionScheme,
AbstractRewardScheme,
AbstractInformer,
AbstractRenderer,
AbstractStopper,
AbstractObserver
)

from tensortrade.oms.wallets import Portfolio


def create(portfolio: Portfolio,
           feed: DataFeed,
           action_scheme: AbstractActionScheme,
           reward_scheme: AbstractRewardScheme,
           informer: Optional[AbstractInformer] = None,
           observer: Optional[AbstractObserver] = None,
           renderer: Optional[Union[List[AbstractRenderer], AbstractRenderer]] = None,
           renderer_feed: Optional[DataFeed] = None,
           stopper: Optional[AbstractStopper] = None,
           window_size: int = 1,
           min_periods: Optional[int] = None,
           random_start_pct: float = 0.00,
           max_allowed_loss: float = 0.5) -> TradingEnv:
    """Creates a default ``TradingEnv`` to be used by a RL agent of your choice.

    Missing components (informer, observer, stopper) are filled in with sensible
    defaults; a list of renderers is wrapped into a single aggregate renderer.

    :param portfolio: Portfolio: The portfolio that the RL agent will be interacting with.
    :param feed: DataFeed: The data feed for the look back window with the ohlcv and feature data.
    :param action_scheme: AbstractActionScheme: The action scheme used by the TradingEnv.
    :param reward_scheme: AbstractRewardScheme: The reward scheme applied to the RL agent.
    :param informer: Optional[AbstractInformer]: The information logger which runs on every episode. (Default value = None)
    :param observer: Optional[AbstractObserver]: The observer which will create the observation for the RL agent. If ``None``, the default observer will be used. (Default value = None)
    :param renderer: Optional[Union[List[AbstractRenderer], AbstractRenderer]]: A renderer which will be used for rendering the environment. Like for creating charts. Will be executed when :code:`env.render()` is called. You can insert a list if you want to use more than one renderer. (Default value = None)
    :param renderer_feed: Optional[DataFeed]: An optional feed for the renderer, mostly with the actual prices used for rendering. (Default value = None)
    :param stopper: Optional[AbstractStopper]: The stopper which resets the environment under the defined circumstances. If ``None``, the MaxLossStopper will be used which resets the environment on :code:`max_allowed_loss`. (Default value = None)
    :param window_size: int: The window size which will be used by the default observer. Actually the timerange of your data that the agent sees. (Default value = 1)
    :param min_periods: Optional[int]: The amount of steps needed to warm up the :code:`feed`. So actually when the first episode starts. (Default value = None)
    :param random_start_pct: float: If the agent should randomly start after this percent of data. Can be used to prevent overfitting. (Default value = 0.00)
    :param max_allowed_loss: float: When using the default stopper this is the max loss the agent is allowed to have before it gets reset. (Default value = 0.5)
    :rtype: TradingEnv
    :returns: A training environment you can use for training a reinforcement learning agent.
    """
    # Local import to avoid a circular import at module load time.
    # NOTE(review): actions/rewards/renderers look unused here, but importing the
    # sub-packages may register components -- kept as-is, confirm before trimming.
    from tensortrade.env import actions, informers, observers, renderers, rewards, stoppers
    from tensortrade.env.renderers.utils import AggregateRenderer

    # set portfolio of action scheme
    action_scheme.portfolio = portfolio

    # prepare observer
    if observer is None:
        observer = observers.WindowObserver(
            window_size=window_size,
        )

    # prepare stopper
    if stopper is None:
        stopper = stoppers.MaxLossStopper(
            max_allowed_loss=max_allowed_loss
        )

    # prepare informer
    if informer is None:
        informer = informers.SimpleInformer()

    # prepare renderer: several renderers are combined into one aggregate.
    # FIX: isinstance() checks against the builtin ``list``, not ``typing.List``
    # (the typing alias is deprecated for runtime isinstance checks).
    if isinstance(renderer, list):
        renderer = AggregateRenderer(renderer)

    # the renderer gets its own stream group so observation features stay separate
    if renderer_feed is not None:
        env_feed = DataFeed([
            Stream.group(feed.inputs).rename("features"),
            Stream.group(renderer_feed.inputs).rename("data")
        ])
    else:
        env_feed = DataFeed([
            Stream.group(feed.inputs).rename("features")
        ])

    # create env
    return TradingEnv(
        portfolio=portfolio,
        feed=env_feed,
        action_scheme=action_scheme,
        reward_scheme=reward_scheme,
        observer=observer,
        data_feed=renderer_feed,
        stopper=stopper,
        informer=informer,
        renderer=renderer,
        min_periods=min_periods,
        random_start_pct=random_start_pct,
    )
from tensortrade.env.environment import TradingEnv
Loading

0 comments on commit 6a21e66

Please sign in to comment.