Skip to content

Commit

Permalink
Integrate allocators into C++
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-attack committed Mar 1, 2025
1 parent 5c9215d commit 07b52d9
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 33 deletions.
2 changes: 1 addition & 1 deletion include/pubsub/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ void ps_node_init_ex(struct ps_node_t* node, const char* name, const char* ip, b

void ps_node_create_publisher(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched);

void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched, unsigned int recommended_transport);
void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched, unsigned int recommended_transport, struct ps_allocator_t* allocator);


typedef void(*ps_subscriber_fn_cb_t)(void* message, unsigned int size, void* data, const struct ps_msg_info_t* info);
Expand Down
18 changes: 10 additions & 8 deletions include/pubsub/Publisher.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ struct ps_message_definition_t;

struct ps_endpoint_t
{
unsigned short port;
int address;
unsigned short port;
int address;
//bool multicast;// this is probably unnecessary
};

// publisher client to network to
struct ps_client_t
{
struct ps_endpoint_t endpoint;
unsigned short sequence_number;// sequence of the networked packets, incremented with each one
unsigned long long last_keepalive;// timestamp of the last keepalive message, used to know when to deactiveate this connection
unsigned int stream_id;// user-unique identifier of what topic this came from
unsigned int modulo;
struct ps_transport_t* transport;
struct ps_endpoint_t endpoint;
unsigned short sequence_number;// sequence of the networked packets, incremented with each one
unsigned long long last_keepalive;// timestamp of the last keepalive message, used to know when to deactiveate this connection
unsigned int stream_id;// user-unique identifier of what topic this came from
unsigned int modulo;
struct ps_transport_t* transport;
};

struct ps_pub_t
Expand All @@ -39,6 +39,8 @@ struct ps_pub_t
struct ps_node_t* node;
unsigned int num_clients;
struct ps_client_t* clients;

struct ps_allocator_t* allocator;

bool latched;// todo make this an enum of options if we add more
uint8_t recommended_transport;
Expand Down
4 changes: 2 additions & 2 deletions include/pubsub/Serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ extern "C"
};

void ps_msg_ref_add(struct ps_msg_ref_t* msg);
void ps_msg_ref_free(struct ps_msg_ref_t* msg);
void ps_msg_ref_free(struct ps_msg_ref_t* msg, struct ps_allocator_t* allocator);

struct ps_allocator_t;
typedef struct ps_msg_t(*ps_fn_encode_t)(const void* msg, struct ps_allocator_t* allocator);
Expand Down Expand Up @@ -150,7 +150,7 @@ extern "C"

// Makes a copy of a given serialized message
// Returns: The new copy
struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg);
struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg, struct ps_allocator_t* allocator);

#ifdef __cplusplus
}
Expand Down
6 changes: 3 additions & 3 deletions include/pubsub/TCPTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,15 @@ void remove_client_socket(struct ps_tcp_transport_impl* transport, int socket, s

if (transport->clients[i].queued_message)
{
ps_msg_ref_free(transport->clients[i].queued_message);
ps_msg_ref_free(transport->clients[i].queued_message, transport->clients[i].publisher->allocator);
}

// free queued messages
if (transport->clients[i].num_queued_messages)
{
for (int j = 0; j < transport->clients[i].num_queued_messages; j++)
{
ps_msg_ref_free(transport->clients[i].queued_messages[j].msg);
ps_msg_ref_free(transport->clients[i].queued_messages[j].msg, transport->clients[i].publisher->allocator);
}
free(transport->clients[i].queued_messages);
}
Expand Down Expand Up @@ -301,7 +301,7 @@ int ps_tcp_transport_spin(struct ps_transport_t* transport, struct ps_node_t* no
if (client->queued_message_written == client->queued_message_length)
{
//printf("Message sent.\n");
ps_msg_ref_free(client->queued_message);
ps_msg_ref_free(client->queued_message, client->publisher->allocator);
client->queued_message = 0;

// we finished! check if there are more to send
Expand Down
6 changes: 3 additions & 3 deletions include/pubsub_cpp/Node.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <pubsub/Publisher.h>
#include <pubsub/Subscriber.h>
#include <pubsub/System.h>

#include <pubsub_cpp/allocator.h>

#include <vector>
#include <thread>
Expand Down Expand Up @@ -310,7 +310,7 @@ class Publisher: public PublisherBase
remapped_topic_ = handle_remap(real_topic, node.getNamespace());

node.lock_.lock();
ps_node_create_publisher_ex(node.getNode(), remapped_topic_.c_str(), T::GetDefinition(), &publisher_, latched, preferred_transport);
ps_node_create_publisher_ex(node.getNode(), remapped_topic_.c_str(), T::GetDefinition(), &publisher_, latched, preferred_transport, T::Allocator::allocator());
node.lock_.unlock();

//add me to the publisher list
Expand Down Expand Up @@ -576,7 +576,7 @@ class Subscriber: public SubscriberBase
options.skip = skip;
options.cb = cb2;
options.cb_data = this;
options.allocator = 0;
options.allocator = T::Allocator::allocator();
options.ignore_local = true;
options.preferred_transport = preferred_transport;

Expand Down
5 changes: 3 additions & 2 deletions src/Node.c
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ void ps_node_advertise(struct ps_pub_t* pub)
int sent_bytes = sendto(pub->node->socket, (const char*)data, off, 0, (struct sockaddr*)&address, sizeof(struct sockaddr_in));
}

void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched, unsigned int recommended_transport)
void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched, unsigned int recommended_transport, struct ps_allocator_t* allocator)
{
node->num_pubs++;
struct ps_pub_t** old_pubs = node->pubs;
Expand All @@ -164,13 +164,14 @@ void ps_node_create_publisher_ex(struct ps_node_t* node, const char* topic, cons
pub->last_message = 0;
pub->sequence_number = 0;
pub->recommended_transport = recommended_transport;
pub->allocator = allocator ? allocator : &ps_default_allocator;

ps_node_advertise(pub);
}

void ps_node_create_publisher(struct ps_node_t* node, const char* topic, const struct ps_message_definition_t* type, struct ps_pub_t* pub, bool latched)
{
ps_node_create_publisher_ex(node, topic, type, pub, latched, 0);
ps_node_create_publisher_ex(node, topic, type, pub, latched, 0, 0);
}

// Setup Control-C handlers
Expand Down
6 changes: 3 additions & 3 deletions src/Publisher.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,13 @@ void ps_pub_publish(struct ps_pub_t* pub, struct ps_msg_t* msg)
if (pub->last_message)
{
//free the old and add the new
ps_msg_ref_free(pub->last_message);// todo use allocator
ps_msg_ref_free(pub->last_message, pub->allocator);
}
pub->last_message = ref;
}
else
{
ps_msg_ref_free(ref);// todo use allocator
ps_msg_ref_free(ref, pub->allocator);
}
}

Expand Down Expand Up @@ -229,7 +229,7 @@ void ps_pub_destroy(struct ps_pub_t* pub)
// free my latched message
if (pub->last_message)
{
ps_msg_ref_free(pub->last_message);// todo use allocator
ps_msg_ref_free(pub->last_message, pub->allocator);
}

pub->clients = 0;
Expand Down
12 changes: 5 additions & 7 deletions src/Serialization.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ struct enumeration
};
#pragma pack(pop)



int ps_serialize_message_definition(void* start, const struct ps_message_definition_t* definition)
{
//ok, write out number of fields
Expand Down Expand Up @@ -665,21 +663,21 @@ void ps_msg_ref_add(struct ps_msg_ref_t* msg)
msg->refcount++;
}

void ps_msg_ref_free(struct ps_msg_ref_t* msg)
void ps_msg_ref_free(struct ps_msg_ref_t* msg, struct ps_allocator_t* allocator)
{
msg->refcount--;
if (msg->refcount == 0)
{
free(msg->data);
free(msg);
allocator->free(msg->data, allocator->context);
allocator->free(msg, allocator->context);
}
}


struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg)
struct ps_msg_t ps_msg_cpy(const struct ps_msg_t* msg, struct ps_allocator_t* allocator)
{
struct ps_msg_t out;
ps_msg_alloc(msg->len, 0, &out);
ps_msg_alloc(msg->len, allocator, &out);
memcpy(ps_get_msg_start(out.data), ps_get_msg_start(msg->data), msg->len);
return out;
}
71 changes: 71 additions & 0 deletions tests/test_pubsub_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,75 @@ TEST(test_publish_subscribe_cpp, []() {
EXPECT(got_message);
});

// tracking allocator usage
static int allocated;
static int freed;
static std::map<void*, uint32_t> sizes;
static ps_allocator_t alloc;
struct TestAllocator
{
static ps_allocator_t* allocator() {
return &alloc;
}

static void* Allocate(uint32_t size, void* context)
{
auto ptr = malloc(size);
sizes[ptr] = size;
allocated += size;
printf("allocate %i\n", allocated);
return ptr;
}

static void Free(void* ptr, void* context)
{
freed += sizes[ptr];
printf("free %i\n", freed);
free(ptr);
}

static void Setup()
{
allocated = 0;
freed = 0;
alloc.context = 0;
alloc.free = Free;
alloc.alloc = Allocate;
}
};

TEST(test_publish_subscribe_allocator_cpp, []() {
TestAllocator::Setup();
// test that allocators are used with C++
pubsub::Node node("simple_publisher");
bool got_message = false;
{
pubsub::Publisher<pubsub::msg::String_<TestAllocator>> string_pub(node, "/data");

pubsub::msg::String_<TestAllocator> omsg;
omsg.value = "Hello";

pubsub::BlockingSpinnerWithTimers spinner;
spinner.setNode(node);

pubsub::Subscriber<pubsub::msg::String_<TestAllocator>> subscriber(node, "/data", [&](const pubsub::msg::String_<TestAllocator>::SharedPtr& msg) {
printf("Got message %s in sub1\n", msg->value.c_str());
EXPECT(omsg.value == msg->value);
spinner.stop();
got_message = true;
}, 10);

spinner.addTimer(0.1, [&]()
{
string_pub.publish(omsg);
});

spinner.run();
}
EXPECT(got_message);

EXPECT(allocated == 20);
EXPECT(allocated == freed);
});

CREATE_MAIN_ENTRY_POINT();
5 changes: 3 additions & 2 deletions tools/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,9 +891,10 @@ std::string generate(const char* definition, const char* name)
output += "namespace " + ns + "\n{\n";
output += "namespace msg\n{\n";
output += "#pragma pack(push, 1)\n";
output += "template <class Allocator = pubsub::DefaultAllocator>\n";
output += "template <class AllocatorT = pubsub::DefaultAllocator>\n";
output += "struct " + raw_name + "_\n{\n";
output += " typedef std::shared_ptr<" + raw_name + "_<Allocator>> SharedPtr;\n";
output += " typedef std::shared_ptr<" + raw_name + "_<AllocatorT>> SharedPtr;\n";
output += " typedef AllocatorT Allocator;\n";
// generate internal structs
for (auto& type: types)
{
Expand Down
4 changes: 2 additions & 2 deletions tools/pubsub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ int topic_pub(int num_args, char** args, ps_node_t* node)
}

// do initial publish
ps_msg_t cpy = ps_msg_cpy(&msg);
ps_msg_t cpy = ps_msg_cpy(&msg, 0);
ps_pub_publish(&pub, &cpy);
break;
}
Expand All @@ -516,7 +516,7 @@ int topic_pub(int num_args, char** args, ps_node_t* node)
ps_node_spin(node);
if (rate != 0 && remaining < pubsub::Duration(0.0))
{
ps_msg_t cpy = ps_msg_cpy(&msg);
ps_msg_t cpy = ps_msg_cpy(&msg, 0);
ps_pub_publish(&pub, &cpy);
next = next + pubsub::Duration(1.0/rate);
}
Expand Down

0 comments on commit 07b52d9

Please sign in to comment.