diff --git a/src/peer_connection.c b/src/peer_connection.c index 56ec655..608f672 100644 --- a/src/peer_connection.c +++ b/src/peer_connection.c @@ -208,7 +208,7 @@ PeerConnection* peer_connection_create(PeerConfiguration* config) { void peer_connection_destroy(PeerConnection* pc) { if (pc) { - sctp_destroy_socket(&pc->sctp); + sctp_destroy_association(&pc->sctp); dtls_srtp_deinit(&pc->dtls_srtp); agent_destroy(&pc->agent); buffer_free(pc->data_rb); @@ -386,7 +386,7 @@ int peer_connection_loop(PeerConnection* pc) { if (pc->config.datachannel) { LOGI("SCTP create socket"); - sctp_create_socket(&pc->sctp, &pc->dtls_srtp); + sctp_create_association(&pc->sctp, &pc->dtls_srtp); pc->sctp.userdata = pc->config.user_data; } diff --git a/src/sctp.c b/src/sctp.c index 76ebeaa..a30c364 100644 --- a/src/sctp.c +++ b/src/sctp.c @@ -1,21 +1,13 @@ #include #include +#include "dtls_srtp.h" #include "sctp.h" +#include "utils.h" #if CONFIG_USE_USRSCTP #include #endif -#include "dtls_srtp.h" -#include "utils.h" - -#define DATA_CHANNEL_PPID_CONTROL 50 -#define DATA_CHANNEL_PPID_DOMSTRING 51 -#define DATA_CHANNEL_PPID_BINARY_PARTIAL 52 -#define DATA_CHANNEL_PPID_BINARY 53 -#define DATA_CHANNEL_PPID_DOMSTRING_PARTIAL 54 -#define DATA_CHANNEL_OPEN 0x03 - static const uint32_t crc32c_table[256] = { 0x00000000L, 0xF26B8303L, 0xE13B70F7L, 0x1350F3F4L, 0xC79A971FL, 0x35F1141CL, 0x26A1E7E8L, 0xD4CA64EBL, @@ -98,7 +90,6 @@ static int sctp_outgoing_data_cb(void* userdata, void* buf, size_t len, uint8_t Sctp* sctp = (Sctp*)userdata; dtls_srtp_write(sctp->dtls_srtp, buf, len); - return 0; } @@ -114,8 +105,9 @@ int sctp_outgoing_data(Sctp* sctp, char* buf, size_t len, SctpDataPpid ppid, uin spa.sendv_sndinfo.snd_ppid = htonl(ppid); res = usrsctp_sendv(sctp->sock, buf, len, NULL, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0); - if (res < 0) - LOGE("sctp sendv error %d %s", errno, strerror(errno)); + if (res < 0) { + LOGE("sctp sendv error %d: %s", errno, strerror(errno)); + } return res; #else size_t padding_len = 0; @@ -201,6 +193,8 @@ void sctp_parse_data_channel_open(Sctp* sctp, uint16_t sid, char* data, size_t l // Add stream mapping sctp_add_stream_mapping(sctp, label_str, sid); + char ack = DATA_CHANNEL_ACK; + sctp_outgoing_data(sctp, &ack, 1, DATA_CHANNEL_PPID_CONTROL, sid); } } @@ -229,8 +223,6 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) { size_t length = 0; size_t pos = sizeof(SctpHeader); SctpChunkCommon* chunk_common; - SctpDataChunk* data_chunk; - SctpSackChunk* sack; SctpPacket* in_packet = (SctpPacket*)buf; SctpPacket* out_packet = (SctpPacket*)sctp->buf; @@ -252,39 +244,42 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) { // prepare outgoing packet memset(sctp->buf, 0, sizeof(sctp->buf)); - - // chunks while ((4 * (pos + 3) / 4) < len) { chunk_common = (SctpChunkCommon*)(buf + pos); switch (chunk_common->type) { - case SCTP_DATA: - - data_chunk = (SctpDataChunk*)(buf + pos); - LOGD("SCTP_DATA. ppid = %ld", ntohl(data_chunk->ppid)); - -// XXX: not check DATA_CHANNEL_OPEN? -#if 0 + case SCTP_DATA: { + SctpDataChunk* data_chunk = (SctpDataChunk*)(buf + pos); + SctpSackChunk* sack_chunk = (SctpSackChunk*)out_packet->chunks; + + sack_chunk->common.type = SCTP_SACK; + sack_chunk->common.flags = 0x00; + sack_chunk->common.length = htons(16); + sack_chunk->cumulative_tsn_ack = data_chunk->tsn; + sack_chunk->a_rwnd = htonl(0x02); + length = ntohs(sack_chunk->common.length) + sizeof(SctpHeader); + + LOGD("SCTP_DATA. ppid = %ld, data = %d", ntohl(data_chunk->ppid), data_chunk->data[0]); if (ntohl(data_chunk->ppid) == DATA_CHANNEL_PPID_CONTROL && data_chunk->data[0] == DATA_CHANNEL_OPEN) { - + data_chunk = (SctpDataChunk*)sack_chunk->blocks; + data_chunk->type = SCTP_DATA; + data_chunk->iube = 0x03; + data_chunk->tsn = htonl(sctp->tsn++); + data_chunk->sid = htons(0); + data_chunk->sqn = htons(0); + data_chunk->ppid = htonl(DATA_CHANNEL_PPID_CONTROL); + data_chunk->length = htons(1 + sizeof(SctpDataChunk)); + data_chunk->data[0] = DATA_CHANNEL_ACK; + length += ntohs(data_chunk->length); } else if (ntohl(data_chunk->ppid) == DATA_CHANNEL_PPID_DOMSTRING) { -#endif - if (ntohl(data_chunk->ppid) == DATA_CHANNEL_PPID_DOMSTRING) { if (sctp->onmessage) { - sctp->onmessage((char*)data_chunk->data, ntohs(data_chunk->length) - sizeof(SctpDataChunk), sctp->userdata, ntohs(data_chunk->sid)); + sctp->onmessage((char*)data_chunk->data, ntohs(data_chunk->length) - sizeof(SctpDataChunk), + sctp->userdata, ntohs(data_chunk->sid)); } } - - sack = (SctpSackChunk*)out_packet->chunks; - sack->common.type = SCTP_SACK; - sack->common.flags = 0x00; - sack->common.length = htons(16); - sack->cumulative_tsn_ack = data_chunk->tsn; - sack->a_rwnd = htonl(0x02); - length = ntohs(sack->common.length) + sizeof(SctpHeader); pos = len; // Do not handle other msg - break; - case SCTP_INIT: + } break; + case SCTP_INIT: { LOGD("SCTP_INIT"); SctpInitChunk* init_chunk; @@ -304,11 +299,36 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) { SctpChunkParam* param = init_ack->param; param->type = htons(SCTP_PARAM_STATE_COOKIE); - param->length = htons(0x08); - uint32_t value = htonl(0x02); - memcpy(¶m->value, &value, 4); + param->length = htons(8); + *(uint32_t*)¶m->value = htonl(0x02); length = ntohs(init_ack->common.length) + sizeof(SctpHeader); - break; + } break; + case SCTP_INIT_ACK: { + SctpInitChunk* init_ack = (SctpInitChunk*)in_packet->chunks; + SctpCookieEchoChunk* cookie_echo = (SctpCookieEchoChunk*)out_packet->chunks; + SctpChunkParam* param; + sctp->verification_tag = init_ack->initiate_tag; + int type; + // find cookie + uint8_t* cookie = NULL; + cookie = (uint8_t*)&init_ack->param[0]; + for (int i = 0; i < init_ack->common.length - 20; i += 2) { + type = ntohs(*(uint16_t*)&cookie[i]); + // find cookie param + if (type == 0x07) { + param = (SctpChunkParam*)&cookie[i]; + break; + } + } + + cookie_echo->common.type = SCTP_COOKIE_ECHO; + cookie_echo->common.flags = 0x00; + // cookie echo: type + flag + length (4 bytes) + cookie + cookie_echo->common.length = htons(ntohs(param->length)); + // param: type + length (4 bytes) + cookie + memcpy(cookie_echo->cookie, param->value, ntohs(param->length) - 4); + length = ntohs(cookie_echo->common.length) + sizeof(SctpHeader); + } break; case SCTP_SACK: #if 0 LOGD("SCTP_SACK"); @@ -341,7 +361,7 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) { } #endif break; - case SCTP_COOKIE_ECHO: + case SCTP_COOKIE_ECHO: { LOGD("SCTP_COOKIE_ECHO"); SctpChunkCommon* common = (SctpChunkCommon*)out_packet->chunks; common->type = SCTP_COOKIE_ACK; @@ -356,7 +376,7 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) { sctp->onopen(sctp->userdata); } } - break; + } break; case SCTP_ABORT: sctp->connected = 0; if (sctp->onclose) { @@ -383,7 +403,6 @@ void sctp_incoming_data(Sctp* sctp, char* buf, size_t len) { } pos += ntohs(chunk_common->length); } - #endif } @@ -464,7 +483,7 @@ static int sctp_incoming_data_cb(struct socket* sock, union sctp_sockstore addr, } #endif -int sctp_create_socket(Sctp* sctp, DtlsSrtp* dtls_srtp) { +int sctp_create_association(Sctp* sctp, DtlsSrtp* dtls_srtp) { sctp->dtls_srtp = dtls_srtp; sctp->local_port = 5000; sctp->remote_port = 5000; @@ -561,17 +580,41 @@ int sctp_create_socket(Sctp* sctp, DtlsSrtp* dtls_srtp) { } while (0); if (ret < 0) { - sctp_destroy_socket(sctp); + sctp_destroy_association(sctp); return -1; } sctp->sock = sock; +#else + // send SCTP_INIT + int length = 0; + SctpInitChunk* init_chunk; + SctpHeader* header; + SctpPacket* out_packet = (SctpPacket*)sctp->buf; + header = &out_packet->header; + init_chunk = (SctpInitChunk*)out_packet->chunks; + + header->source_port = htons(sctp->local_port); + header->destination_port = htons(sctp->remote_port); + header->verification_tag = 0x0; + init_chunk->common.type = SCTP_INIT; + init_chunk->common.flags = 0x00; + init_chunk->common.length = htons(20); + init_chunk->initiate_tag = htonl(0x12345678); + init_chunk->a_rwnd = htonl(0x100000); + init_chunk->number_of_outbound_streams = 0xffff; + init_chunk->number_of_inbound_streams = 0xffff; + init_chunk->initial_tsn = htonl(sctp->tsn); + length = ntohs(init_chunk->common.length) + sizeof(SctpHeader); + length = (4 * ((length + 3) / 4)); + header->checksum = sctp_get_checksum(sctp, sctp->buf, length); + dtls_srtp_write(sctp->dtls_srtp, sctp->buf, length); #endif return 0; } -void sctp_destroy_socket(Sctp* sctp) { +void sctp_destroy_association(Sctp* sctp) { #if CONFIG_USE_USRSCTP if (sctp && sctp->sock) { usrsctp_shutdown(sctp->sock, SHUT_RDWR); diff --git a/src/sctp.h b/src/sctp.h index c7f70c7..50e6464 100644 --- a/src/sctp.h +++ b/src/sctp.h @@ -6,8 +6,6 @@ #include "dtls_srtp.h" #include "utils.h" -#if !CONFIG_USE_USRSCTP - typedef enum DecpMsgType { DATA_CHANNEL_OPEN = 0x03, @@ -15,6 +13,18 @@ typedef enum DecpMsgType { } DecpMsgType; +typedef enum DataChannelPpid { + + DATA_CHANNEL_PPID_CONTROL = 50, + DATA_CHANNEL_PPID_DOMSTRING = 51, + DATA_CHANNEL_PPID_BINARY_PARTIAL = 52, + DATA_CHANNEL_PPID_BINARY = 53, + DATA_CHANNEL_PPID_DOMSTRING_PARTIAL = 54 + +} DataChannelPpid; + +#if !CONFIG_USE_USRSCTP + typedef struct SctpChunkParam { uint16_t type; uint16_t length; @@ -114,6 +124,11 @@ typedef struct SctpInitChunk { } SctpInitChunk; +typedef struct SctpCookieEchoChunk { + SctpChunkCommon common; + uint8_t cookie[0]; +} SctpCookieEchoChunk; + #endif typedef enum SctpDataPpid { @@ -155,9 +170,9 @@ typedef struct Sctp { uint8_t buf[CONFIG_MTU]; } Sctp; -int sctp_create_socket(Sctp* sctp, DtlsSrtp* dtls_srtp); +int sctp_create_association(Sctp* sctp, DtlsSrtp* dtls_srtp); -void sctp_destroy_socket(Sctp* sctp); +void sctp_destroy_association(Sctp* sctp); int sctp_is_connected(Sctp* sctp);