Skip to content

Commit

Permalink
Fix aiortc datachannel not open issue
Browse files Browse the repository at this point in the history
  • Loading branch information
sepfy committed Nov 18, 2024
1 parent da7cfcb commit c75499a
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 54 deletions.
4 changes: 2 additions & 2 deletions src/peer_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ PeerConnection* peer_connection_create(PeerConfiguration* config) {

void peer_connection_destroy(PeerConnection* pc) {
if (pc) {
sctp_destroy_socket(&pc->sctp);
sctp_destroy_association(&pc->sctp);
dtls_srtp_deinit(&pc->dtls_srtp);
agent_destroy(&pc->agent);
buffer_free(pc->data_rb);
Expand Down Expand Up @@ -386,7 +386,7 @@ int peer_connection_loop(PeerConnection* pc) {

if (pc->config.datachannel) {
LOGI("SCTP create socket");
sctp_create_socket(&pc->sctp, &pc->dtls_srtp);
sctp_create_association(&pc->sctp, &pc->dtls_srtp);
pc->sctp.userdata = pc->config.user_data;
}

Expand Down
139 changes: 91 additions & 48 deletions src/sctp.c
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
#include <stdlib.h>
#include <string.h>

#include "dtls_srtp.h"
#include "sctp.h"
#include "utils.h"
#if CONFIG_USE_USRSCTP
#include <usrsctp.h>
#endif

#include "dtls_srtp.h"
#include "utils.h"

#define DATA_CHANNEL_PPID_CONTROL 50
#define DATA_CHANNEL_PPID_DOMSTRING 51
#define DATA_CHANNEL_PPID_BINARY_PARTIAL 52
#define DATA_CHANNEL_PPID_BINARY 53
#define DATA_CHANNEL_PPID_DOMSTRING_PARTIAL 54
#define DATA_CHANNEL_OPEN 0x03

static const uint32_t crc32c_table[256] = {
0x00000000L, 0xF26B8303L, 0xE13B70F7L, 0x1350F3F4L,
0xC79A971FL, 0x35F1141CL, 0x26A1E7E8L, 0xD4CA64EBL,
Expand Down Expand Up @@ -98,7 +90,6 @@ static int sctp_outgoing_data_cb(void* userdata, void* buf, size_t len, uint8_t
Sctp* sctp = (Sctp*)userdata;

dtls_srtp_write(sctp->dtls_srtp, buf, len);

return 0;
}

Expand All @@ -114,8 +105,9 @@ int sctp_outgoing_data(Sctp* sctp, char* buf, size_t len, SctpDataPpid ppid, uin
spa.sendv_sndinfo.snd_ppid = htonl(ppid);

res = usrsctp_sendv(sctp->sock, buf, len, NULL, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
if (res < 0)
LOGE("sctp sendv error %d %s", errno, strerror(errno));
if (res < 0) {
LOGE("sctp sendv error %d: %s", errno, strerror(errno));
}
return res;
#else
size_t padding_len = 0;
Expand Down Expand Up @@ -201,6 +193,8 @@ void sctp_parse_data_channel_open(Sctp* sctp, uint16_t sid, char* data, size_t l

// Add stream mapping
sctp_add_stream_mapping(sctp, label_str, sid);
char ack = DATA_CHANNEL_ACK;
sctp_outgoing_data(sctp, &ack, 1, DATA_CHANNEL_PPID_CONTROL, sid);
}
}

Expand Down Expand Up @@ -229,8 +223,6 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) {
size_t length = 0;
size_t pos = sizeof(SctpHeader);
SctpChunkCommon* chunk_common;
SctpDataChunk* data_chunk;
SctpSackChunk* sack;
SctpPacket* in_packet = (SctpPacket*)buf;
SctpPacket* out_packet = (SctpPacket*)sctp->buf;

Expand All @@ -252,39 +244,42 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) {

// prepare outgoing packet
memset(sctp->buf, 0, sizeof(sctp->buf));

// chunks
while ((4 * (pos + 3) / 4) < len) {
chunk_common = (SctpChunkCommon*)(buf + pos);

switch (chunk_common->type) {
case SCTP_DATA:

data_chunk = (SctpDataChunk*)(buf + pos);
LOGD("SCTP_DATA. ppid = %ld", ntohl(data_chunk->ppid));

// XXX: not check DATA_CHANNEL_OPEN?
#if 0
case SCTP_DATA: {
SctpDataChunk* data_chunk = (SctpDataChunk*)(buf + pos);
SctpSackChunk* sack_chunk = (SctpSackChunk*)out_packet->chunks;

sack_chunk->common.type = SCTP_SACK;
sack_chunk->common.flags = 0x00;
sack_chunk->common.length = htons(16);
sack_chunk->cumulative_tsn_ack = data_chunk->tsn;
sack_chunk->a_rwnd = htonl(0x02);
length = ntohs(sack_chunk->common.length) + sizeof(SctpHeader);

LOGD("SCTP_DATA. ppid = %ld, data = %d", ntohl(data_chunk->ppid), data_chunk->data[0]);
if (ntohl(data_chunk->ppid) == DATA_CHANNEL_PPID_CONTROL && data_chunk->data[0] == DATA_CHANNEL_OPEN) {

data_chunk = (SctpDataChunk*)sack_chunk->blocks;
data_chunk->type = SCTP_DATA;
data_chunk->iube = 0x03;
data_chunk->tsn = htonl(sctp->tsn++);
data_chunk->sid = htons(0);
data_chunk->sqn = htons(0);
data_chunk->ppid = htonl(DATA_CHANNEL_PPID_CONTROL);
data_chunk->length = htons(1 + sizeof(SctpDataChunk));
data_chunk->data[0] = DATA_CHANNEL_ACK;
length += ntohs(data_chunk->length);
} else if (ntohl(data_chunk->ppid) == DATA_CHANNEL_PPID_DOMSTRING) {
#endif
if (ntohl(data_chunk->ppid) == DATA_CHANNEL_PPID_DOMSTRING) {
if (sctp->onmessage) {
sctp->onmessage((char*)data_chunk->data, ntohs(data_chunk->length) - sizeof(SctpDataChunk), sctp->userdata, ntohs(data_chunk->sid));
sctp->onmessage((char*)data_chunk->data, ntohs(data_chunk->length) - sizeof(SctpDataChunk),
sctp->userdata, ntohs(data_chunk->sid));
}
}

sack = (SctpSackChunk*)out_packet->chunks;
sack->common.type = SCTP_SACK;
sack->common.flags = 0x00;
sack->common.length = htons(16);
sack->cumulative_tsn_ack = data_chunk->tsn;
sack->a_rwnd = htonl(0x02);
length = ntohs(sack->common.length) + sizeof(SctpHeader);
pos = len; // Do not handle other msg
break;
case SCTP_INIT:
} break;
case SCTP_INIT: {
LOGD("SCTP_INIT");

SctpInitChunk* init_chunk;
Expand All @@ -304,11 +299,36 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) {
SctpChunkParam* param = init_ack->param;

param->type = htons(SCTP_PARAM_STATE_COOKIE);
param->length = htons(0x08);
uint32_t value = htonl(0x02);
memcpy(&param->value, &value, 4);
param->length = htons(8);
*(uint32_t*)&param->value = htonl(0x02);
length = ntohs(init_ack->common.length) + sizeof(SctpHeader);
break;
} break;
case SCTP_INIT_ACK: {
SctpInitChunk* init_ack = (SctpInitChunk*)in_packet->chunks;
SctpCookieEchoChunk* cookie_echo = (SctpCookieEchoChunk*)out_packet->chunks;
SctpChunkParam* param;
sctp->verification_tag = init_ack->initiate_tag;
int type;
// find cookie
uint8_t* cookie = NULL;
cookie = (uint8_t*)&init_ack->param[0];
for (int i = 0; i < init_ack->common.length - 20; i += 2) {
type = ntohs(*(uint16_t*)&cookie[i]);
// find cookie param
if (type == 0x07) {
param = (SctpChunkParam*)&cookie[i];
break;
}
}

cookie_echo->common.type = SCTP_COOKIE_ECHO;
cookie_echo->common.flags = 0x00;
// cookie echo: type + flag + length (4 bytes) + cookie
cookie_echo->common.length = htons(ntohs(param->length));
// param: type + length (4 bytes) + cookie
memcpy(cookie_echo->cookie, param->value, ntohs(param->length) - 4);
length = ntohs(cookie_echo->common.length) + sizeof(SctpHeader);
} break;
case SCTP_SACK:
#if 0
LOGD("SCTP_SACK");
Expand Down Expand Up @@ -341,7 +361,7 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) {
}
#endif
break;
case SCTP_COOKIE_ECHO:
case SCTP_COOKIE_ECHO: {
LOGD("SCTP_COOKIE_ECHO");
SctpChunkCommon* common = (SctpChunkCommon*)out_packet->chunks;
common->type = SCTP_COOKIE_ACK;
Expand All @@ -356,7 +376,7 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) {
sctp->onopen(sctp->userdata);
}
}
break;
} break;
case SCTP_ABORT:
sctp->connected = 0;
if (sctp->onclose) {
Expand All @@ -383,7 +403,6 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) {
}
pos += ntohs(chunk_common->length);
}

#endif
}

Expand Down Expand Up @@ -464,7 +483,7 @@ static int sctp_incoming_data_cb(struct socket* sock, union sctp_sockstore addr,
}
#endif

int sctp_create_socket(Sctp* sctp, DtlsSrtp* dtls_srtp) {
int sctp_create_association(Sctp* sctp, DtlsSrtp* dtls_srtp) {
sctp->dtls_srtp = dtls_srtp;
sctp->local_port = 5000;
sctp->remote_port = 5000;
Expand Down Expand Up @@ -561,17 +580,41 @@ int sctp_create_socket(Sctp* sctp, DtlsSrtp* dtls_srtp) {
} while (0);

if (ret < 0) {
sctp_destroy_socket(sctp);
sctp_destroy_association(sctp);
return -1;
}

sctp->sock = sock;
#else
// send SCTP_INIT
int length = 0;
SctpInitChunk* init_chunk;
SctpHeader* header;
SctpPacket* out_packet = (SctpPacket*)sctp->buf;
header = &out_packet->header;
init_chunk = (SctpInitChunk*)out_packet->chunks;

header->source_port = htons(sctp->local_port);
header->destination_port = htons(sctp->remote_port);
header->verification_tag = 0x0;
init_chunk->common.type = SCTP_INIT;
init_chunk->common.flags = 0x00;
init_chunk->common.length = htons(20);
init_chunk->initiate_tag = htonl(0x12345678);
init_chunk->a_rwnd = htonl(0x100000);
init_chunk->number_of_outbound_streams = 0xffff;
init_chunk->number_of_inbound_streams = 0xffff;
init_chunk->initial_tsn = htonl(sctp->tsn);
length = ntohs(init_chunk->common.length) + sizeof(SctpHeader);
length = (4 * ((length + 3) / 4));
header->checksum = sctp_get_checksum(sctp, sctp->buf, length);
dtls_srtp_write(sctp->dtls_srtp, sctp->buf, length);
#endif

return 0;
}

void sctp_destroy_socket(Sctp* sctp) {
void sctp_destroy_association(Sctp* sctp) {
#if CONFIG_USE_USRSCTP
if (sctp && sctp->sock) {
usrsctp_shutdown(sctp->sock, SHUT_RDWR);
Expand Down
23 changes: 19 additions & 4 deletions src/sctp.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,25 @@
#include "dtls_srtp.h"
#include "utils.h"

#if !CONFIG_USE_USRSCTP

typedef enum DecpMsgType {

DATA_CHANNEL_OPEN = 0x03,
DATA_CHANNEL_ACK = 0x02,

} DecpMsgType;

typedef enum DataChannelPpid {

DATA_CHANNEL_PPID_CONTROL = 50,
DATA_CHANNEL_PPID_DOMSTRING = 51,
DATA_CHANNEL_PPID_BINARY_PARTIAL = 52,
DATA_CHANNEL_PPID_BINARY = 53,
DATA_CHANNEL_PPID_DOMSTRING_PARTIAL = 54

} DataChannelPpid;

#if !CONFIG_USE_USRSCTP

typedef struct SctpChunkParam {
uint16_t type;
uint16_t length;
Expand Down Expand Up @@ -114,6 +124,11 @@ typedef struct SctpInitChunk {

} SctpInitChunk;

typedef struct SctpCookieEchoChunk {
SctpChunkCommon common;
uint8_t cookie[0];
} SctpCookieEchoChunk;

#endif

typedef enum SctpDataPpid {
Expand Down Expand Up @@ -155,9 +170,9 @@ typedef struct Sctp {
uint8_t buf[CONFIG_MTU];
} Sctp;

int sctp_create_socket(Sctp* sctp, DtlsSrtp* dtls_srtp);
int sctp_create_association(Sctp* sctp, DtlsSrtp* dtls_srtp);

void sctp_destroy_socket(Sctp* sctp);
void sctp_destroy_association(Sctp* sctp);

int sctp_is_connected(Sctp* sctp);

Expand Down

0 comments on commit c75499a

Please sign in to comment.