Skip to content

Commit

Permalink
Avoid some copies in TCP transport and reduce number of send calls
Browse files Browse the repository at this point in the history
Note, breaks TCP protocol compatibility
Fix a memory leak
  • Loading branch information
matt-attack committed Oct 27, 2024
1 parent aef549c commit 893b6d9
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 139 deletions.
6 changes: 3 additions & 3 deletions include/pubsub/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ struct ps_endpoint_t;
struct ps_client_t;
struct ps_subscribe_req_t;
struct ps_allocator_t;
typedef void(*ps_transport_fn_pub_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, const void* message, uint32_t length);
struct ps_msg_ref_t;
typedef void(*ps_transport_fn_pub_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, struct ps_msg_ref_t* message);
typedef int(*ps_transport_fn_spin_t)(struct ps_transport_t* transport, struct ps_node_t* node);
typedef void(*ps_transport_fn_add_publisher_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher);
typedef void(*ps_transport_fn_remove_publisher_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher);
Expand Down Expand Up @@ -129,10 +130,9 @@ struct ps_msg_info_t
struct ps_msg_header
{
uint8_t pid;//packet type id
uint32_t length;//message length
uint32_t id;//stream id
uint16_t seq;//sequence number
uint8_t index;
uint8_t count;
};
#pragma pack(pop)

Expand Down
2 changes: 1 addition & 1 deletion include/pubsub/Publisher.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct ps_pub_t
bool latched;// todo make this an enum of options if we add more
uint8_t recommended_transport;

struct ps_msg_t last_message;//only used if latched
struct ps_msg_ref_t* last_message;//only used if latched
unsigned int sequence_number;
};

Expand Down
11 changes: 10 additions & 1 deletion include/pubsub/Serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,22 @@ extern "C"
int field;// the field this is associated with in the message
};


// encoded message
struct ps_msg_t
{
void* data;
unsigned int len;
};

struct ps_msg_ref_t
{
void* data;
unsigned int len;
unsigned int refcount;
};

void ps_msg_ref_add(struct ps_msg_ref_t* msg);
void ps_msg_ref_free(struct ps_msg_ref_t* msg);

struct ps_allocator_t;
typedef struct ps_msg_t(*ps_fn_encode_t)(struct ps_allocator_t* allocator, const void* msg);
Expand Down
147 changes: 54 additions & 93 deletions include/pubsub/TCPTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <pubsub/Publisher.h>
#include <pubsub/Subscriber.h>
#include <pubsub/System.h>
#include <pubsub/UDPTransport.h>
//#include <pubsub/Net.h>

#include <stdlib.h>
Expand All @@ -20,6 +21,12 @@

#define PUBSUB_TCP_TRANSPORT 1

enum
{
PS_TCP_PROTOCOL_DATA = PS_UDP_PROTOCOL_DATA,
PS_TCP_PROTOCOL_MESSAGE_DEFINITION = 0x03,
};

/*
typedef void(*ps_transport_fn_pub_t)(struct ps_transport_t* transport, struct ps_pub_t* publisher, void* message);
Expand Down Expand Up @@ -62,8 +69,7 @@ struct ps_tcp_transport_connection

struct ps_tcp_client_queued_message_t
{
char* data;
int32_t length;
struct ps_msg_ref_t* msg;
};

struct ps_tcp_client_t
Expand All @@ -77,7 +83,7 @@ struct ps_tcp_client_t
int32_t desired_packet_size;
char* packet_data;

char* queued_message;
struct ps_msg_ref_t* queued_message;
int32_t queued_message_length;
int32_t queued_message_written;

Expand Down Expand Up @@ -124,15 +130,15 @@ void remove_client_socket(struct ps_tcp_transport_impl* transport, int socket, s

if (transport->clients[i].queued_message)
{
free(transport->clients[i].queued_message);
ps_msg_ref_free(transport->clients[i].queued_message);
}

// free queued messages
if (transport->clients[i].num_queued_messages)
{
for (int j = 0; j < transport->clients[i].num_queued_messages; j++)
{
free(transport->clients[i].queued_messages[j].data);
ps_msg_ref_free(transport->clients[i].queued_messages[j].msg);
}
free(transport->clients[i].queued_messages);
}
Expand Down Expand Up @@ -292,16 +298,16 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no
if (client->queued_message_written == client->queued_message_length)
{
//printf("Message sent.\n");
free(client->queued_message);
ps_msg_ref_free(client->queued_message);
client->queued_message = 0;

// we finished! check if there are more to send
if (client->num_queued_messages > 0)
{
// grab a message from the front of our message queue
client->queued_message = client->queued_messages[0].data;
client->queued_message = client->queued_messages[0].msg;
client->queued_message_written = 0;
client->queued_message_length = client->queued_messages[0].length;
client->queued_message_length = client->queued_messages[0].msg->len + sizeof(struct ps_msg_header);

client->num_queued_messages -= 1;
if (client->num_queued_messages == 0)
Expand Down Expand Up @@ -338,9 +344,9 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no
// if we havent gotten a header yet, just check for that
if (client->desired_packet_size == 0)
{
const int header_size = 5;
const int header_size = sizeof(struct ps_msg_header);
int len = recv(client->socket, buf, header_size, MSG_PEEK);
//printf("recv %i desired size 0\n", len);
//printf("peek %i desired size %i\n", len, header_size);
if (len == 0)
{
client->needs_removal = true;
Expand Down Expand Up @@ -406,12 +412,13 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no
impl->clients[i].publisher = pub;

// send the client the acknowledgement and message definition
int8_t packet_type = 0x03;//message definition
send(impl->clients[i].socket, (char*)&packet_type, 1, 0);

char buf[1500];
int32_t length = ps_serialize_message_definition((void*)buf, pub->message_definition);
send(impl->clients[i].socket, (char*)&length, 4, 0);
struct ps_msg_header hdr;
hdr.pid = PS_TCP_PROTOCOL_MESSAGE_DEFINITION;// message definition
hdr.length = length;
hdr.id = hdr.seq = 0;
send(impl->clients[i].socket, (char*)&hdr, sizeof(hdr), 0);
send(impl->clients[i].socket, buf, length, 0);

#ifdef PUBSUB_VERBOSE
Expand Down Expand Up @@ -464,28 +471,27 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no

// make the subscribe request in a "packet"
// a packet is an int length followed by data
int8_t packet_type = 0x01;//subscribe
send(connection->socket, (char*)&packet_type, 1, 0);

int32_t length = strlen(connection->subscriber->topic) + 1 + 4;
send(connection->socket, (char*)&length, 4, 0);

struct ps_msg_header hdr;
hdr.pid = 0x01;
hdr.length = length;
hdr.id = hdr.seq = 0;
send(connection->socket, (char*)&hdr, sizeof(hdr), 0);

// make the request
char buffer[500];
strcpy(buffer, connection->subscriber->topic);
uint32_t skip = connection->subscriber->skip;
send(connection->socket, (char*)&skip, 4, 0);
send(connection->socket, buffer, length - 4, 0);
send(connection->socket, connection->subscriber->topic, length - 4, 0);

connection->connecting = false;
}
}
// if we havent gotten a header yet, just check for that
else if (connection->waiting_for_header)
{
const int header_size = 5;
const int header_size = sizeof(struct ps_msg_header);
int len = recv(connection->socket, buf, header_size, MSG_PEEK);
//printf("len %i\n", len);
//printf("peek got: %i\n", len);
if (len == 0)
{
Expand Down Expand Up @@ -533,7 +539,7 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no
if (connection->current_size == connection->packet_size)
{
//printf("message finished type %x\n", connection->packet_type);
if (connection->packet_type == 0x3)
if (connection->packet_type == PS_TCP_PROTOCOL_MESSAGE_DEFINITION)
{
//printf("Was message definition\n");
if (connection->subscriber->type == 0)
Expand All @@ -553,7 +559,7 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no

free(connection->packet_data);
}
else if (connection->packet_type == 0x2)
else if (connection->packet_type == PS_TCP_PROTOCOL_DATA)
{
//printf("added to queue\n");
// decode and add it to the queue
Expand Down Expand Up @@ -595,8 +601,11 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no
return message_count;
}

void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, const void* message, uint32_t length)
void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* publisher, struct ps_client_t* client, struct ps_msg_ref_t* msg)
{
// todo dont
int length = msg->len;
void* message = msg->data;
struct ps_tcp_transport_impl* impl = (struct ps_tcp_transport_impl*)transport->impl;

// the client packs the socket id in the addr
Expand All @@ -619,21 +628,17 @@ void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* pub
{
// check if we have queue space left

// for now hardcode max queue size
// for now hardcode max queue size
const int max_queue_size = 10;

// copy the message to put it in the queue
// todo remove this copy
char* data = (char*)malloc(length + 4 + 1);
data[0] = 0x02;
*((uint32_t*)&data[1]) = length;
memcpy(&data[5], ps_get_msg_start(message), length);
// add a reference to the message and queue it up
ps_msg_ref_add(msg);

// this if statement is unnecessary, but I added it for the sake of testing/completeness
if (tclient->queued_message == 0)
{
tclient->queued_message = data;
tclient->queued_message_length = length + 5;
tclient->queued_message = msg;
tclient->queued_message_length = length + sizeof(struct ps_msg_header);
tclient->queued_message_written = 0;
}
else if (tclient->num_queued_messages >= max_queue_size)
Expand All @@ -644,8 +649,7 @@ void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* pub
{
tclient->queued_messages[i] = tclient->queued_messages[i - 1];
}
tclient->queued_messages[0].data = data;
tclient->queued_messages[0].length = length + 5;
tclient->queued_messages[0].msg = msg;
printf("dropped message on topic '%s'\n", publisher->topic);
return;// drop it, we are out of queue space
}
Expand All @@ -657,8 +661,7 @@ void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* pub
tclient->num_queued_messages += 1;
struct ps_tcp_client_queued_message_t* msgs = (struct ps_tcp_client_queued_message_t*)malloc(tclient->num_queued_messages * sizeof(struct ps_tcp_client_queued_message_t));

msgs[0].data = data;
msgs[0].length = length + 5;
msgs[0].msg = msg;
for (int i = 0; i < tclient->num_queued_messages - 1; i++)
{
msgs[i + 1] = tclient->queued_messages[i];
Expand All @@ -671,11 +674,14 @@ void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* pub
}
//printf("started writing\n");
// try and write, if any of these fail, make a copy
uint8_t packet_type = 0x02;
int c = send(socket, (char*)&packet_type, 1, 0);
if (c == 0)

// the message header is already filled out with the packet id and length

int32_t desired_len = sizeof(struct ps_msg_header) + length;
int32_t c = send(socket, message, desired_len, 0);
if (c < desired_len && c >= 0)
{
tclient->queued_message_written = 0;
tclient->queued_message_written = c;
goto FAILCOPY;
}
if (c < 0)
Expand All @@ -693,48 +699,6 @@ void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* pub
goto FAILDISCONNECT;
}

c = send(socket, (char*)&length, 4, 0);
if (c < 4 && c >= 0)
{
tclient->queued_message_written = c + 1;
goto FAILCOPY;
}
if (c < 0)
{
#ifdef WIN32
int error = WSAGetLastError();
if (error == WSAEWOULDBLOCK)
#else
if (errno == EAGAIN || errno == EWOULDBLOCK)
#endif
{
tclient->queued_message_written = 1;
goto FAILCOPY;
}
goto FAILDISCONNECT;
}

//printf("sending %i bytes\n", length + 4 + 1);
c = send(socket, (char*)ps_get_msg_start(message), length, 0);
if (c < length && c >= 0)
{
tclient->queued_message_written = c + 5;
goto FAILCOPY;
}
if (c < 0)
{
#ifdef WIN32
int error = WSAGetLastError();
if (error == WSAEWOULDBLOCK)
#else
if (errno == EAGAIN || errno == EWOULDBLOCK)
#endif
{
tclient->queued_message_written = 5;
goto FAILCOPY;
}
goto FAILDISCONNECT;
}
//printf("wrote all\n");
return;

Expand All @@ -745,14 +709,11 @@ void ps_tcp_transport_pub(struct ps_transport_t* transport, struct ps_pub_t* pub
return;

FAILCOPY:
// todo remove this copy
data = (char*)malloc(length + 4 + 1);
data[0] = 0x02;
*((uint32_t*)&data[1]) = length;
memcpy(&data[5], ps_get_msg_start(message), length);

tclient->queued_message = data;
tclient->queued_message_length = length + 5;
// add a reference count and put it in our queue
ps_msg_ref_add(msg);

tclient->queued_message = msg;
tclient->queued_message_length = length + sizeof(struct ps_msg_header);
ps_event_set_add_socket_write(&publisher->node->events, socket);
return;
}
Expand Down
3 changes: 2 additions & 1 deletion include/pubsub/UDPTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ struct ps_endpoint_t;
struct ps_sub_t;
struct ps_pub_t;
struct ps_msg_t;
struct ps_msg_ref_t;
struct ps_client_t;

void ps_udp_subscribe(struct ps_sub_t* sub, const struct ps_endpoint_t* ep);

void ps_udp_unsubscribe(struct ps_sub_t* sub);

void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_t* msg);
void ps_udp_publish(struct ps_pub_t* pub, struct ps_client_t* client, struct ps_msg_ref_t* msg);

#endif
3 changes: 1 addition & 2 deletions src/Node.c
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,7 @@ void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, cons
pub->topic = topic;
pub->node = node;
pub->latched = latched;
pub->last_message.data = 0;
pub->last_message.len = 0;
pub->last_message = 0;
pub->sequence_number = 0;
pub->recommended_transport = recommended_transport;

Expand Down
Loading

0 comments on commit 893b6d9

Please sign in to comment.