Skip to content

Commit

Permalink
Merge pull request #2 from dancost/update_websockets
Browse files Browse the repository at this point in the history
Update websockets
  • Loading branch information
dancost authored Jul 23, 2024
2 parents 40cfc1a + 7bc8abf commit e7410c9
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 79 deletions.
29 changes: 0 additions & 29 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,3 @@ The test report will be generated and can be found at:
- `tests/test_results/report.html`
- sample test report can be found in github actions artifacts: https://github.com/dancost/trading_platform_sim/actions/runs/10039954131/artifacts/1725525112

## Sample WebSocket Client

You can use the following Python code to connect to the server via WebSockets and see the messages while tests are running:

```python
import asyncio
import json
import websockets
async def main():
uri = "ws://127.0.0.1:8000/ws"
async with websockets.connect(uri) as websocket:
# Send subscription message for all orders
subscribe_message = json.dumps({"action": "subscribe"})
await websocket.send(subscribe_message)
print("Subscribed to all orders")
while True:
try:
message = await websocket.recv()
print("Received:", json.loads(message))
except websockets.ConnectionClosed:
print("WebSocket connection closed")
break
if __name__ == "__main__":
asyncio.run(main())
14 changes: 9 additions & 5 deletions routers/orders.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def execute_order(order_id: str, delay: int = 10):
if order["id"] == order_id and order["status"] == "PENDING":
order["status"] = "EXECUTED"
logger.info(f"Order executed: {order_id}")
await websocket_manager.broadcast({"action": "order_executed", "data": order})
await websocket_manager.broadcast({"action": "order_executed", "data": order}, order_id=order_id)


@router.get("/orders", response_model=List[OrderOutput], status_code=status.HTTP_200_OK)
Expand All @@ -48,7 +48,7 @@ async def post_order(order_input: OrderInput):
new_order["status"] = "PENDING"
orders_db.append(new_order)
logger.info(f"New order created: {new_order['id']}")
await websocket_manager.broadcast({"action": "new_order", "data": new_order})
await websocket_manager.broadcast({"action": "new_order", "data": new_order}, order_id=new_order["id"])

asyncio.create_task(execute_order(new_order["id"]))

Expand All @@ -74,7 +74,7 @@ async def cancel_an_order(orderId: str):
if order["id"] == orderId and order["status"] == "PENDING":
order["status"] = "CANCELED"
logger.info(f"Order canceled: {orderId}")
await websocket_manager.broadcast({"action": "order_cancelled", "data": order})
await websocket_manager.broadcast({"action": "order_cancelled", "data": order}, order_id=orderId)
return
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -88,7 +88,11 @@ async def websocket_connection(websocket: WebSocket):
try:
while True:
data = await websocket.receive_text()
if json.loads(data).get("action") == "subscribe":
websocket_manager.order_subscribers["all"].add(websocket)
message = json.loads(data)
action = message.get("action")
order_id = message.get("order_id")
if action == "subscribe" and order_id:
websocket_manager.order_subscribers[order_id].add(websocket)
logger.info(f"WebSocket subscribed to order ID: {order_id}")
except WebSocketDisconnect:
websocket_manager.disconnect(websocket)
1 change: 0 additions & 1 deletion tests/endpoint_tests/cancel_orders_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import json
import uuid
import time

Expand Down
2 changes: 0 additions & 2 deletions tests/endpoint_tests/create_orders_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import pytest
import json



@pytest.mark.smoke
Expand Down
1 change: 0 additions & 1 deletion tests/endpoint_tests/get_orders_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
import json


@pytest.fixture(scope='function')
Expand Down
31 changes: 17 additions & 14 deletions tests/ws_tests/performance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,43 @@ async def test_performance(forex_api_session):
order_data = {"stocks": "EURUSD", "quantity": 10}

async with aiohttp.ClientSession() as client:
# open ws connection and set ping interval to None (only way I could stop the server from disconnecting me)
# Open WebSocket connection
async with connect(uri, ping_interval=None) as websocket:
# place 100 orders simultaneously
# Place 100 orders simultaneously
tasks = [place_order(client, base_url, order_data) for _ in range(100)]
start_time = time.time()
responses = await asyncio.gather(*tasks)
end_time = time.time()
print(f"Time to place 100 orders: {end_time - start_time:.2f} seconds")

# check response
# Check response and collect order IDs
order_ids = []
for response in responses:
assert 'id' in response, f"Response missing 'id': {response}"
order_ids.append(response['id'])

print(f"Placed 100 orders successfully.")

# read ws messages
pending_timestamps = {}
# Subscribe to specific order IDs
for order_id in order_ids:
subscribe_message = json.dumps({"action": "subscribe", "order_id": order_id})
await websocket.send(subscribe_message)

# Read WebSocket messages
executed_timestamps = {}
for _ in range(200): # expect 2 msg per order (pending and executed)
message = await websocket.recv()
for _ in range(100): # Expect only the EXECUTED messages
message = await asyncio.wait_for(websocket.recv(), timeout=20)
message_data = json.loads(message)
order_id = message_data['data']['id']
if message_data['data']['status'] == 'PENDING':
pending_timestamps[order_id] = time.time()
elif message_data['data']['status'] == 'EXECUTED':
if message_data['data']['status'] == 'EXECUTED':
executed_timestamps[order_id] = time.time()
print(f"Executed message for order {order_id}: {message_data}")

# compute delays
# Compute delays
execution_delays = []
for order_id in order_ids:
if order_id in pending_timestamps and order_id in executed_timestamps:
delay = executed_timestamps[order_id] - pending_timestamps[order_id]
if order_id in executed_timestamps:
delay = executed_timestamps[order_id] - start_time
execution_delays.append(delay)
else:
print(f"Missing timestamps for order {order_id}")
Expand All @@ -67,4 +70,4 @@ async def test_performance(forex_api_session):
print(f"Standard Deviation of Delay: {stddev_delay:.2f} seconds")
else:
print("No execution delays were recorded.")
print(f"Total Time: {end_time - start_time:.2f} seconds")
print(f"Total Time: {end_time - start_time:.2f} seconds")
79 changes: 60 additions & 19 deletions tests/ws_tests/websockets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_websocket_order_status(forex_api_session):

async with aiohttp.ClientSession() as client:
async with connect(uri) as websocket:
# place an order with a valid currency pair
# place order with a valid currency pair
order_data = {"stocks": "EURUSD", "quantity": 10}
order_response = await place_order(client, base_url, order_data)
print(f"Order response after placing: {order_response}")
Expand All @@ -30,16 +30,16 @@ async def test_websocket_order_status(forex_api_session):
print(f"Order ID: {order_id}")
assert order_id is not None, "Order ID should not be None"

# start listening ws messages
pending_message = await websocket.recv()
pending_data = json.loads(pending_message)
print(f"Pending message: {pending_data}")
# check the PENDING status from the REST API response
assert order_response['status'] == 'PENDING', "Initial status is not PENDING"

assert pending_data["data"]["id"] == order_id
assert pending_data["data"]["status"] == "PENDING"
# sub to the specific order ID
subscribe_message = json.dumps({"action": "subscribe", "order_id": order_id})
await websocket.send(subscribe_message)
print(f"Subscribed to order ID: {order_id}")

# wait for EXECUTED status
executed_message = await websocket.recv()
# check the "EXECUTED" status is received
executed_message = await asyncio.wait_for(websocket.recv(), timeout=12)
executed_data = json.loads(executed_message)
print(f"Executed message: {executed_data}")

Expand All @@ -57,7 +57,7 @@ async def test_no_messages_after_cancelled(forex_api_session):

async with aiohttp.ClientSession() as client:
async with connect(uri) as websocket:
# place an order with a valid currency pair
# place order with a valid currency pair
order_data = {"stocks": "EURUSD", "quantity": 10}
order_response = await place_order(client, base_url, order_data)
print(f"Order response after placing: {order_response}")
Expand All @@ -66,30 +66,71 @@ async def test_no_messages_after_cancelled(forex_api_session):
print(f"Order ID: {order_id}")
assert order_id is not None, "Order ID should not be None"

# listen for the first message (PENDING status)
pending_message = await websocket.recv()
pending_data = json.loads(pending_message)
print(f"Pending message: {pending_data}")
# check the PENDING status from the REST API response
assert order_response['status'] == 'PENDING', "Initial status is not PENDING"

assert pending_data["data"]["id"] == order_id
assert pending_data["data"]["status"] == "PENDING"
# sub to the specific order ID
subscribe_message = json.dumps({"action": "subscribe", "order_id": order_id})
await websocket.send(subscribe_message)
print(f"Subscribed to order ID: {order_id}")

# cancel the order
cancel_response = await client.delete(f"{base_url}/orders/{order_id}")
cancel_text = await cancel_response.text()
print(f"Cancel response: {cancel_text}")
assert cancel_response.status == 204, f"Failed to cancel order: {cancel_text}"

# wait for the second message (CANCELLED status)
cancelled_message = await websocket.recv()
# wait for CANCELLED status
cancelled_message = await asyncio.wait_for(websocket.recv(), timeout=15)
cancelled_data = json.loads(cancelled_message)
print(f"Cancelled message: {cancelled_data}")

assert cancelled_data["data"]["id"] == order_id
assert cancelled_data["data"]["status"] == "CANCELED"

# ensure no further messages are received within a reasonable time frame
# ensure no more messages are received within a reasonable time frame
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(websocket.recv(), timeout=12)

await websocket.close()


@pytest.mark.ws
@pytest.mark.asyncio
async def test_multiple_users_notified(forex_api_session):
base_url = forex_api_session.base_url
uri = f"ws://{base_url.split('//')[1]}/ws"

async with aiohttp.ClientSession() as client:
async with connect(uri) as websocket1, connect(uri) as websocket2:
order_data = {"stocks": "EURUSD", "quantity": 10}
order_response = await place_order(client, base_url, order_data)
print(f"Order response after placing: {order_response}")

order_id = order_response.get('id')
print(f"Order ID: {order_id}")
assert order_id is not None, "Order ID should not be None"

assert order_response['status'] == 'PENDING', "Initial status is not PENDING"

subscribe_message = json.dumps({"action": "subscribe", "order_id": order_id})
await websocket1.send(subscribe_message)
await websocket2.send(subscribe_message)
print(f"Subscribed to order ID: {order_id} with both WebSocket clients")

executed_message1 = await asyncio.wait_for(websocket1.recv(), timeout=12)
executed_data1 = json.loads(executed_message1)
print(f"Executed message from websocket1: {executed_data1}")

executed_message2 = await asyncio.wait_for(websocket2.recv(), timeout=12)
executed_data2 = json.loads(executed_message2)
print(f"Executed message from websocket2: {executed_data2}")

assert executed_data1["data"]["id"] == order_id
assert executed_data1["data"]["status"] == "EXECUTED"

assert executed_data2["data"]["id"] == order_id
assert executed_data2["data"]["status"] == "EXECUTED"

await websocket1.close()
await websocket2.close()
16 changes: 8 additions & 8 deletions websocket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,38 @@
import json
from collections import defaultdict
from fastapi import WebSocket
from typing import Set
from typing import Set, Dict
import logging

logger = logging.getLogger(__name__)


class ConnectionManager:
# store active ws connections and subscribers
def __init__(self):
# initialize and store all active connections
self.active_connections: Set[WebSocket] = set()
self.order_subscribers: defaultdict[str, Set[WebSocket]] = defaultdict(set)
# store subscribers and order id
self.order_subscribers: Dict[str, Set[WebSocket]] = defaultdict(set)

async def connect(self, websocket: WebSocket):
# add connection and add it to active ws connections
# accept new connection and add it to active connections
await websocket.accept()
self.active_connections.add(websocket)
logger.info("WebSocket connected")

def disconnect(self, websocket: WebSocket):
# remove ws connection from active connections and any subscription
# remove websockets connection from active connections and subscriptions
self.active_connections.remove(websocket)
for subscribers in self.order_subscribers.values():
if websocket in subscribers:
subscribers.remove(websocket)
logger.info("WebSocket disconnected")

async def broadcast(self, message: dict, order_id: str = None):
# if an order ID is provided, send the message only to subscribed clients for that order
# send messages to subscribed clients
if order_id:
subscribers = self.order_subscribers[order_id]
logger.info(
f"Broadcasting message to {len(subscribers)} subscriber(s) for order ID: {order_id}")
logger.info(f"Broadcasting message to {len(subscribers)} subscriber(s) for order ID: {order_id}")
for connection in subscribers:
await connection.send_text(json.dumps(message))
else:
Expand Down

0 comments on commit e7410c9

Please sign in to comment.