Skip to content

Commit

Permalink
Merge pull request #45 from shivam096/cindy-backend
Browse files Browse the repository at this point in the history
Cindy backend
  • Loading branch information
lvxinyi2000 authored Mar 13, 2024
2 parents a17bea8 + 684b0a5 commit a25aec9
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 44 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ jobs:
# Next step: run pylint. Anything less than 10/10 will fail.
- name: Lint with pylint
run: |
pylint dinero/**/*.py || exit 0
pylint dinero/**/*.py
# Next step: run the unit tests with code coverage.
- name: Unit tests
Expand Down
69 changes: 47 additions & 22 deletions dinero/backend/stock_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
3. update_stock_data()
4. get_stock_data(ticker_symbol)
5. get_filtered_stock_data(ticker_symbol, start_date='', end_date='')
6. get_last_n_days(stock_data, n_days)
"""
import os

Expand Down Expand Up @@ -35,6 +36,10 @@ def download_stock_data(ticker_symbol, period_str='5y'):
Save non-empty data to 'data/{ticker_symbol}.csv', do nothing otherwise.
Print progress messages (complate or fail) to terminal.
Exceptions:
TypeError if ticker_symbol or period_str is not string
ValueError if period_str has invalid format or ticker data not found.
Example:
>> download_stock_data("AAPL", "max")
=================================================================================
Expand All @@ -49,16 +54,21 @@ def download_stock_data(ticker_symbol, period_str='5y'):
1 Failed download:
['MFT']: Exception('%ticker%: No data found, symbol may be delisted')
"""
if not (isinstance(ticker_symbol, str) and isinstance(period_str, str)):
raise TypeError("Arguments must be strings.")
period_str = period_str.lower()
ticker_symbol = ticker_symbol.upper()
if not (period_str == 'max' or period_str[-1] in ['d','y']
or period_str[-2:] in ['wk','mo']):
raise ValueError("period format: 'max', 'd', 'wk', 'mo', 'y' (case insensitive).")

data = yf.download(ticker_symbol, period=period_str)

if len(data) == 0:
return len(data)
raise ValueError("Fail to download: no data, ticker symbol may be delisted ")

file_path = os.path.join(DEFAULT_DATABASE_PATH, f'{ticker_symbol}.csv')
data.to_csv(file_path)

return len(data)

def get_existing_tickers():
Expand All @@ -69,7 +79,8 @@ def get_existing_tickers():
list: A list of existing ticker symbol strings.
Return empty list if data folder is empty.
"""
return [file[:-4] for file in os.listdir('data') if file.endswith('.csv')]
return [file[:-4] for file in os.listdir(DEFAULT_DATABASE_PATH)
if file.endswith('.csv')]

def update_stock_data():
"""
Expand Down Expand Up @@ -103,14 +114,19 @@ def get_stock_data(ticker_symbol):
Returns:
pd.DataFrame: A DataFrame containing the stock data.
Exceptions:
TypeError if ticker_symbol is not a string
ValueError if ticker not in database
"""
if not isinstance(ticker_symbol, str):
raise TypeError("ticker symbol must be strings.")
ticker_symbol = ticker_symbol.upper()
file_path = os.path.join(DEFAULT_DATABASE_PATH, f'{ticker_symbol}.csv')
if not os.path.exists(file_path):
raise ValueError("No such database. Please download initial data first.")
return pd.read_csv(file_path)


def get_filtered_stock_data(ticker_symbol, start_date='1900-01-01', end_date=''):
"""
Function to fetch stock data of the ticker within given timeframe from statistic database.
Expand All @@ -125,54 +141,63 @@ def get_filtered_stock_data(ticker_symbol, start_date='1900-01-01', end_date='')
Returns:
pd.DataFrame: A DataFrame containing the filtered stock data.
Raises:
Exceptions:
TypeError:
If start_date or end_data is not string
ValueError:
If start_date is later than end_date.
If start_date or end_date has invalid date format
or start_date is later than end_date.
"""
stock_data = get_stock_data(ticker_symbol)

start_date = pd.to_datetime(start_date, format='%Y-%m-%d')
if not (isinstance(start_date, str) and isinstance(end_date, str)):
raise TypeError("Dates must be strings in form of 'yyyy-mm-dd'.")

if len(end_date) == 0:
if len(start_date.strip()) == 0:
start_date = '1900-01-01'
start_date = pd.to_datetime(start_date, format='%Y-%m-%d', errors='coerce')
if pd.isnull(start_date):
raise ValueError("Valid Format of start date should be 'yyyy-mm-dd'.")
if len(end_date.strip()) == 0:
end_date = pd.Timestamp.today() + pd.DateOffset(1)
end_date = pd.to_datetime(end_date, format='%Y-%m-%d')
end_date = pd.to_datetime(end_date, format='%Y-%m-%d', errors='coerce')
if pd.isnull(end_date):
raise ValueError("Valid Format of end date should be 'yyyy-mm-dd'.")

if start_date > end_date:
raise ValueError("Start date after end date.")

return stock_data.query(f"'{start_date}' <= Date <= '{end_date}'")

def get_last_n_days(df, n_days):
def get_last_n_days(stock_data, n_days):
"""
Extracts the last n_days rows from the DataFrame.
Parameters:
df (DataFrame): Input DataFrame.
stock_data (DataFrame): Input DataFrame.
n_days (int): Number of days from the end of the DataFrame.
Returns:
DataFrame: DataFrame containing the last n_days rows, or the entire
DataFrame if n_days exceeds the number of rows.
Raises:
ValueError:
If the number of days entered is a negative integer or 0.
TypeError:
If the number of days enetered is not an integer
Exceptions:
ValueError if the number of days entered is a negative integer or 0.
TypeError if stock_data is not pd.DataFrame or n_days is not an integer
"""

# Ensure n_days is a positive integer
# Ensure stock_data is a dataframe and n_days is a positive integer
if not isinstance(stock_data, pd.DataFrame):
raise TypeError("data must be pandas dataframe")
if not isinstance(n_days, int):
raise TypeError("Number of Days must be an Integer")
if n_days <= 0:
raise ValueError("Number of days must be a positive integer")

# Check if n_days exceeds the number of rows in the DataFrame
if n_days >= len(df):
return df # Return the entire DataFrame
if n_days >= len(stock_data):
return stock_data # Return the entire DataFrame

# Get the last n_days rows from the DataFrame
last_n_days_df = df.iloc[-n_days:]
last_n_days_df = stock_data.iloc[-n_days:]

return last_n_days_df
2 changes: 1 addition & 1 deletion dinero/backend/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def plot_stock_price(ticker_symbol):
name='Candlestick'))

# Update layout for candlestick chart
fig_candlestick.update_layout(title=f'Candlestick Chart',
fig_candlestick.update_layout(title=f'{ticker_symbol} Candlestick Chart',
xaxis_title="Date", yaxis_title="Price")
fig_candlestick.update_layout(hovermode="x unified")
time_buttons = [
Expand Down
85 changes: 71 additions & 14 deletions dinero/tests/test_stock_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
update_stock_data,
get_stock_data,
get_filtered_stock_data,
get_last_n_days,
DEFAULT_DATABASE_PATH
)

Expand All @@ -30,45 +31,74 @@ class TestStockDataManager(unittest.TestCase):
3. update_stock_data()
4. get_stock_data(ticker_symbol)
5. get_filtered_stock_data(ticker_symbol, start_date='', end_date='')
6. get_last_n_days(stock_data, n_days)
Notes:
`get_existing_tickers` is TRIVIAL and is called by update_stock_data(),
2.`get_existing_tickers` is TRIVIAL and is called by update_stock_data(),
so it has no explicit testing.
"""

@mock.patch('backend.stock_data_manager.yf.download')
def test_default_download_stock_data(self, mock_download):
"""
Test download with (valid) default period
(using mock to avoid changing database)
(using mock to simulate yfinance download)
"""
sample_data = f'{DEFAULT_DATABASE_PATH}/MSFT.csv'
mock_download.return_value = pd.read_csv(sample_data)
num = download_stock_data('valid_ticker')
mock_download.return_value = pd.DataFrame({"Close": [100, 200, 300]})
num = download_stock_data('test_ticker')
self.assertIsInstance(num, int)
file_path = f'{DEFAULT_DATABASE_PATH}/VALID_TICKER.csv'
file_path = f'{DEFAULT_DATABASE_PATH}/TEST_TICKER.csv'
self.assertTrue(os.path.exists(file_path))
os.remove(file_path)

def test_download_stock_data_invalid_input(self):
@mock.patch('backend.stock_data_manager.yf.download')
def test_download_stock_data_invalid_input_type(self, mock_download):
"""
Test download with invalid input type
(using mock to simulate yfinance download)
"""
mock_download.return_value = pd.DataFrame({"Close": [100, 200, 300]})
with self.assertRaises(TypeError):
download_stock_data('test_ticker', 1)
file_path = f'{DEFAULT_DATABASE_PATH}/TEST_TICKER.csv'
self.assertFalse(os.path.exists(file_path))

@mock.patch('backend.stock_data_manager.yf.download')
def test_download_stock_data_invalid_period(self, mock_download):
"""
Test download with invalid input
Test download with invalid period input
(using mock to simulate yfinance download)
"""
mock_download.return_value = pd.DataFrame({"Close": [100, 200, 300]})
with self.assertRaises(ValueError):
download_stock_data('test_ticker', '1')
file_path = f'{DEFAULT_DATABASE_PATH}/TEST_TICKER.csv'
self.assertFalse(os.path.exists(file_path))

@mock.patch('backend.stock_data_manager.yf.download')
def test_download_stock_data_invalid_ticker(self, mock_download):
"""
num = download_stock_data('invalid', '1d')
self.assertEqual(num, 0)
Test download with invalid ticker input
(using mock to simulate yfinance download)
"""
mock_download.return_value = pd.DataFrame()
with self.assertRaises(ValueError):
download_stock_data('invalid', '1d')
file_path = f'{DEFAULT_DATABASE_PATH}/INVALID.csv'
self.assertFalse(os.path.exists(file_path))

@mock.patch('backend.stock_data_manager.yf.download')
def test_update_stock_data(self, mock_download):
"""
Test update stock database
(using mock to avoid changing database)
(using mock to simulate yfinance download)
"""
mock_download.return_value = pd.DataFrame()
tikers = update_stock_data()
self.assertIsInstance(tikers, list)

tickers = update_stock_data()
self.assertIsInstance(tickers, list)
for ticker in tickers:
file_path = f'{DEFAULT_DATABASE_PATH}/{ticker}.csv'
self.assertTrue(os.path.exists(file_path))

@mock.patch("backend.stock_data_manager.pd.read_csv")
def test_get_stock_data(self, mock_read_csv):
Expand All @@ -86,6 +116,8 @@ def test_get_stock_data_invalid_ticker(self):
"""
Test fetch non-existing stock data from database
"""
with self.assertRaises(TypeError):
get_stock_data(1)
with self.assertRaises(ValueError):
get_stock_data('NON_EXISTENT_SYMBOL')

Expand All @@ -109,9 +141,34 @@ def test_get_filtered_stock_data_with_invalid_input(self):
"""
Test fetch filtered stock data from database with invalid input
"""
with self.assertRaises(TypeError):
get_filtered_stock_data('MSFT', 20240101, '2023/01/0')
with self.assertRaises(ValueError):
get_filtered_stock_data('MSFT', '2024/01/01', '2023/01/0')

def test_get_last_n_days_valid(self):
"""
Test get data of n last days from dataframe with valid input
(num of days > len(dataframe))
"""
test_df = pd.DataFrame({"Date": ['2023/01/0', '2023/10/01',
'2024/11/01','2024/01/01']})
test_result = get_last_n_days(test_df,6)
self.assertIsInstance(test_result, pd.DataFrame)
self.assertEqual(len(test_result), 4)

def test_get_last_n_days_invalid(self):
"""
Test get data of n last days from dataframe with invalid input
"""
test_df = pd.DataFrame({"Date": ['2023/01/0', '2023/10/01',
'2024/11/01','2024/01/01']})
with self.assertRaises(TypeError):
get_last_n_days(['2023/01/0'],2)
with self.assertRaises(TypeError):
get_last_n_days(test_df,"1")
with self.assertRaises(ValueError):
get_last_n_days(test_df,-1)

if __name__ == '__main__':
unittest.main()
14 changes: 8 additions & 6 deletions dinero/tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,19 @@ def test_plot_stock_price_layout(self):
"""
Test the layout of the stock price candlestick chart.
Checks if the layout of the candlestick chart is as expected.
Checks if the layout of the candlestick chart is as expected:
chart type and tittles
tooltips
"""
stock_price_fig = plot_stock_price('MSFT')
self.assertIsInstance(stock_price_fig, go.Figure)
self.assertEqual(len(stock_price_fig.data), 1)
self.assertEqual(stock_price_fig.data[0].type, 'candlestick')
self.assertEqual(stock_price_fig.layout.hovermode, 'x unified')
self.assertEqual(stock_price_fig.layout.title['text'],
'Candlestick Chart')
'MSFT Candlestick Chart')
self.assertEqual(stock_price_fig.layout.xaxis.title['text'], 'Date')
self.assertEqual(stock_price_fig.layout.yaxis.title['text'], 'Price')
self.assertEqual(stock_price_fig.layout.hovermode, 'x unified')

def test_plot_stock_price_selector_view_adjustment(self):
"""
Expand All @@ -63,9 +65,9 @@ def test_plot_stock_price_selector_view_adjustment(self):

def test_plot_kpi_ma(self):
"""
Test plotting Moving Average (MA) on the stock price chart.
Test technical indicaters plotting using Moving Average (MA).
Checks if Moving Average is correctly added to the stock price candlestick chart.
Checks if MA trace is correctly added to the stock price candlestick chart.
"""
stock_price_fig = plot_stock_price('MSFT')
kpi_fig = plot_kpis(stock_price_fig, 'MSFT', 50, 'MA')
Expand All @@ -80,7 +82,7 @@ def test_plot_kpi_other(self):
Test plotting other technical indicators on the stock price chart.
Checks if other technical indicators are correctly added to
the stock price candlestick chart.
the plot groups with the stock price chart.
"""
stock_price_fig = plot_stock_price('MSFT')
kpi_fig = plot_kpis(stock_price_fig, 'MSFT', 50, 'ROC')
Expand Down

0 comments on commit a25aec9

Please sign in to comment.