Skip to content

Commit

Permalink
Write callback with buffer access
Browse files Browse the repository at this point in the history
  • Loading branch information
lukepalmer committed Feb 1, 2025
1 parent 4148275 commit e93e384
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 36 deletions.
116 changes: 89 additions & 27 deletions curl-helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <caml/unixsupport.h>
#include <caml/custom.h>
#include <caml/threads.h>
#include <caml/bigarray.h>

#ifndef CAMLdrop
#define CAMLdrop caml_local_roots = caml__frame
Expand Down Expand Up @@ -106,10 +107,10 @@ typedef enum OcamlValues
Ocaml_IOCTLFUNCTION,
Ocaml_SEEKFUNCTION,
Ocaml_OPENSOCKETFUNCTION,
/* Ocaml_CLOSESOCKETFUNCTION, */
Ocaml_SSH_KEYFUNCTION,

Ocaml_ERRORBUFFER,
Ocaml_CALLBACKBUFFER,
Ocaml_PRIVATE,

/* Not used, last for size */
Expand Down Expand Up @@ -642,14 +643,16 @@ static void resetOcamlValues(Connection* connection)
int i;

for (i = 0; i < OcamlValuesSize; i++)
Store_field(connection->ocamlValues, i, Val_unit);
if (i != Ocaml_CALLBACKBUFFER)
Store_field(connection->ocamlValues, i, Val_unit);
}

static Connection* allocConnection(CURL* h)
{
Connection* connection = (Connection *)malloc(sizeof(Connection));

connection->ocamlValues = caml_alloc(OcamlValuesSize, 0);
Store_field(connection->ocamlValues, Ocaml_CALLBACKBUFFER, caml_ba_alloc_dims(CAML_BA_UINT8 | CAML_BA_C_LAYOUT | CAML_BA_EXTERNAL, 1, NULL, 0));
resetOcamlValues(connection);
caml_register_global_root(&connection->ocamlValues);

Expand Down Expand Up @@ -866,6 +869,44 @@ static size_t cb_WRITEFUNCTION2(char *ptr, size_t size, size_t nmemb, void *data
return r;
}

static size_t cb_WRITEFUNCTION_BUF(char *ptr, size_t size, size_t nmemb, void *data)
{
caml_leave_blocking_section();

CAMLparam0();
CAMLlocal2(result, buf);
Connection *conn = (Connection *)data;

checkConnection(conn);

buf = Field(conn->ocamlValues, Ocaml_CALLBACKBUFFER);
struct caml_ba_array* ba = Caml_ba_array_val(buf);
ba->dim[0]=size*nmemb;
ba->data=ptr;
result = caml_callback_exn(Field(conn->ocamlValues, Ocaml_WRITEFUNCTION), buf);

size_t r = 0;

if (!Is_exception_result(result))
{
if (Is_block(result)) /* Proceed */
{
r = size * nmemb;
}
else
{
if (0 == Int_val(result)) /* Pause */
r = CURL_WRITEFUNC_PAUSE;
/* else 1 = Abort */
}
}

CAMLdrop;

caml_enter_blocking_section();
return r;
}

static size_t cb_READFUNCTION(void *ptr, size_t size, size_t nmemb, void *data)
{
caml_leave_blocking_section();
Expand Down Expand Up @@ -1217,29 +1258,6 @@ static int cb_OPENSOCKETFUNCTION(void *data,
return ((sock == -1) ? CURL_SOCKET_BAD : sock);
}

/*
static int cb_CLOSESOCKETFUNCTION(void *data,
curl_socket_t socket)
{
caml_leave_blocking_section();
CAMLparam0();
CAMLlocal1(camlResult);
Connection *conn = (Connection *)data;
int result = 0;
camlResult = caml_callback_exn(Field(conn->ocamlValues, Ocaml_CLOSESOCKETFUNCTION), Val_int(socket));
if (Is_exception_result(camlResult))
{
result = 1;
}
CAMLdrop;
caml_enter_blocking_section();
return result;
}
*/

static int cb_SSH_KEYFUNCTION(CURL *easy,
const struct curl_khkey *knownkey,
const struct curl_khkey *foundkey,
Expand Down Expand Up @@ -1592,9 +1610,11 @@ static void handle_##name##suffix(Connection *conn, value option) \

#define SETOPT_FUNCTION(name) SETOPT_FUNCTION_(name,FUNCTION)
#define SETOPT_FUNCTION2(name) SETOPT_FUNCTION_(name,FUNCTION2)
#define SETOPT_FUNCTION_BUF(name) SETOPT_FUNCTION_(name,FUNCTION_BUF)

SETOPT_FUNCTION( WRITE)
SETOPT_FUNCTION2( WRITE)
SETOPT_FUNCTION_BUF( WRITE)
SETOPT_FUNCTION( READ)
SETOPT_FUNCTION2( READ)
SETOPT_FUNCTION( HEADER)
Expand All @@ -1617,7 +1637,6 @@ SETOPT_FUNCTION( IOCTL)
#endif

SETOPT_FUNCTION( OPENSOCKET)
/* SETOPT_FUNCTION( CLOSESOCKET) */

static void handle_slist(Connection *conn, struct curl_slist** slist, CURLoption curl_option, value option)
{
Expand Down Expand Up @@ -2932,6 +2951,7 @@ static void handle_FTP_FILEMETHOD(Connection *conn, value option)
result = curl_easy_setopt(conn->handle,
CURLOPT_FTP_FILEMETHOD,
CURLFTPMETHOD_SINGLECWD);
break;

default:
caml_failwith("Invalid FTP_FILEMETHOD value");
Expand Down Expand Up @@ -3683,7 +3703,6 @@ CURLOptionMapping implementedOptionMap[] =
HAVENOT(AUTOREFERER),
#endif
HAVE(OPENSOCKETFUNCTION),
/*HAVE(CLOSESOCKETFUNCTION),*/
#if HAVE_DECL_CURLOPT_PROXYTYPE
HAVE(PROXYTYPE),
#else
Expand Down Expand Up @@ -3778,6 +3797,7 @@ CURLOptionMapping implementedOptionMap[] =
HAVENOT(PROXY_SSL_OPTIONS),
#endif
HAVE(WRITEFUNCTION2),
HAVE(WRITEFUNCTION_BUF),
HAVE(READFUNCTION2),
#if HAVE_DECL_CURLOPT_XFERINFOFUNCTION
HAVE(XFERINFOFUNCTION),
Expand Down Expand Up @@ -4597,6 +4617,7 @@ enum
{
curlmopt_socket_function,
curlmopt_timer_function,
curlmopt_closesocket_function,

/* last, not used */
multi_values_total
Expand Down Expand Up @@ -4879,13 +4900,44 @@ value caml_curl_multi_poll(value v_timeout_ms, value v_extra_fds, value v_multi)
CAMLreturn(Val_bool(numfds != 0));
}

static int curlm_closesocket_cb(void *data, curl_socket_t socket)
{
caml_leave_blocking_section();

CAMLparam0();
CAMLlocal1(camlResult);

ml_multi_handle* multi = (ml_multi_handle*)data;
int result = 0;
camlResult = caml_callback_exn(Field(multi->values, curlmopt_closesocket_function), Val_socket(socket));
if (Is_exception_result(camlResult))
{
result = 1;
}
CAMLdrop;

caml_enter_blocking_section();
return result;
}

value caml_curl_multi_add_handle(value v_multi, value v_easy)
{
CAMLparam2(v_multi,v_easy);
CURLMcode rc = CURLM_OK;
CURLM* multi = CURLM_val(v_multi);
Connection* conn = Connection_val(v_easy);

if (Field(Multi_val(v_multi)->values, curlmopt_closesocket_function) != Val_unit)
{
CURLcode result = CURLE_OK;
result = curl_easy_setopt(conn->handle, CURLOPT_CLOSESOCKETDATA, Multi_val(v_multi));
if (result != CURLE_OK)
raiseError(conn, result);
result = curl_easy_setopt(conn->handle, CURLOPT_CLOSESOCKETFUNCTION, curlm_closesocket_cb);
if (result != CURLE_OK)
raiseError(conn, result);
}

/* prevent collection of OCaml value while the easy handle is used
and may invoke callbacks registered on OCaml side */
conn->refcount++;
Expand Down Expand Up @@ -5082,6 +5134,16 @@ value caml_curl_multi_timerfunction(value v_multi, value v_cb)
CAMLreturn(Val_unit);
}

value caml_curl_multi_closesocketfunction(value v_multi, value v_cb)
{
CAMLparam2(v_multi, v_cb);
ml_multi_handle* multi = Multi_val(v_multi);

Store_field(multi->values, curlmopt_closesocket_function, v_cb);

CAMLreturn(Val_unit);
}

value caml_curl_multi_timeout(value v_multi)
{
CAMLparam1(v_multi);
Expand Down
7 changes: 7 additions & 0 deletions curl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ type 'a xfer_result = Proceed of 'a | Pause | Abort

type write_result = unit xfer_result
type read_result = string xfer_result
type bigstring = (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t

let proceed = Proceed ()

Expand Down Expand Up @@ -491,6 +492,7 @@ type curlOption =
| CURLOPT_SSL_OPTIONS of curlSslOption list
| CURLOPT_PROXY_SSL_OPTIONS of curlSslOption list
| CURLOPT_WRITEFUNCTION2 of (string -> write_result)
| CURLOPT_WRITEFUNCTION_BUF of (bigstring -> write_result)
| CURLOPT_READFUNCTION2 of (int -> read_result)
| CURLOPT_XFERINFOFUNCTION of (int64 -> int64 -> int64 -> int64 -> bool)
| CURLOPT_PREREQFUNCTION of (string -> string -> int -> int -> bool)
Expand Down Expand Up @@ -625,6 +627,7 @@ let errno = int_of_curlCode

type pauseOption = PAUSE_SEND | PAUSE_RECV | PAUSE_ALL


external pause : t -> pauseOption list -> unit = "caml_curl_pause"

let set_writefunction conn closure =
Expand All @@ -633,6 +636,9 @@ let set_writefunction conn closure =
let set_writefunction2 conn closure =
setopt conn (CURLOPT_WRITEFUNCTION2 closure)

let set_writefunction_buf conn closure =
setopt conn (CURLOPT_WRITEFUNCTION_BUF closure)

let set_readfunction conn closure =
setopt conn (CURLOPT_READFUNCTION closure)

Expand Down Expand Up @@ -1565,6 +1571,7 @@ module Multi = struct

external set_socket_function : mt -> (Unix.file_descr -> poll -> unit) -> unit = "caml_curl_multi_socketfunction"
external set_timer_function : mt -> (int -> unit) -> unit = "caml_curl_multi_timerfunction"
external set_closesocket_function : mt -> (Unix.file_descr -> unit) -> unit = "caml_curl_multi_closesocketfunction"
external action_all : mt -> int = "caml_curl_multi_socket_all"
external socket_action : mt -> Unix.file_descr option -> fd_status -> int = "caml_curl_multi_socket_action"

Expand Down
29 changes: 20 additions & 9 deletions curl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ type 'a xfer_result = Proceed of 'a | Pause | Abort

type write_result = unit xfer_result
type read_result = string xfer_result
type bigstring = (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t

val proceed : write_result

Expand Down Expand Up @@ -478,7 +479,6 @@ type curlOption =
| CURLOPT_SEEKFUNCTION of (int64 -> curlSeek -> curlSeekResult)
| CURLOPT_AUTOREFERER of bool
| CURLOPT_OPENSOCKETFUNCTION of (Unix.file_descr -> unit)
(* | CURLOPT_CLOSESOCKETFUNCTION of (Unix.file_descr -> unit) *)
| CURLOPT_PROXYTYPE of curlProxyType
| CURLOPT_PROTOCOLS of curlProto list
| CURLOPT_REDIR_PROTOCOLS of curlProto list
Expand All @@ -504,6 +504,7 @@ type curlOption =
| CURLOPT_SSL_OPTIONS of curlSslOption list
| CURLOPT_PROXY_SSL_OPTIONS of curlSslOption list
| CURLOPT_WRITEFUNCTION2 of (string -> write_result)
| CURLOPT_WRITEFUNCTION_BUF of (bigstring -> write_result)
| CURLOPT_READFUNCTION2 of (int -> read_result)
| CURLOPT_XFERINFOFUNCTION of (int64 -> int64 -> int64 -> int64 -> bool)
| CURLOPT_PREREQFUNCTION of (string -> string -> int -> int -> bool)
Expand Down Expand Up @@ -652,8 +653,8 @@ val pause : t -> pauseOption list -> unit
(** {2 Set transfer options}
All callback functions shouldn't raise exceptions.
Any exception raised in callback function will be silently caught and discared,
and transfer will be aborted. *)
Any exception raised in callback function will be silently caught and discarded,
and the transfer will be aborted. *)

val set_writefunction : t -> (string -> int) -> unit

Expand All @@ -662,6 +663,14 @@ val set_writefunction : t -> (string -> int) -> unit
do not try to call unpause from another thread, see libcurl documentation for details *)
val set_writefunction2 : t -> (string -> write_result) -> unit

(** A write callback that provides direct access to curl's receive buffer. The provided
buffer may only be read within the callback. It is illegal for the buffer to escape
the scope of the callback.
This function provides better performance than string-based variants by avoiding an
intermediate copy. *)
val set_writefunction_buf : t -> (bigstring -> write_result) -> unit

(** [readfunction n] should return string of length at most [n], otherwise
transfer will be aborted (as if with exception) *)
val set_readfunction : t -> (int -> string) -> unit
Expand Down Expand Up @@ -808,11 +817,6 @@ val set_opensocketfunction : t -> (Unix.file_descr -> unit) -> unit
val set_tcpkeepalive : t -> bool -> unit
val set_tcpkeepidle : t -> int -> unit
val set_tcpkeepintvl : t -> int -> unit

(** current implementation is faulty
ref https://github.com/ygrek/ocurl/issues/58
val set_closesocketfunction : t -> (Unix.file_descr -> unit) -> unit
*)
val set_proxytype : t -> curlProxyType -> unit
val set_protocols : t -> curlProto list -> unit
val set_redirprotocols : t -> curlProto list -> unit
Expand Down Expand Up @@ -1048,7 +1052,6 @@ class handle :
method set_seekfunction : (int64 -> curlSeek -> curlSeekResult) -> unit
method set_autoreferer : bool -> unit
method set_opensocketfunction : (Unix.file_descr -> unit) -> unit
(* method set_closesocketfunction : (Unix.file_descr -> unit) -> unit *)
method set_proxytype : curlProxyType -> unit
method set_resolve : (string * int * string) list -> (string * int) list -> unit
method set_dns_servers : string list -> unit
Expand Down Expand Up @@ -1198,6 +1201,14 @@ module Multi : sig
NB {!action_timeout} should be called when timeout occurs *)
val set_timer_function : mt -> (int -> unit) -> unit

(** set a function to be called to close a socket.
The underlying callback is a property of the easy handle, but cannot be stored there
in these bindings because libcurl's use of this callback may occur outside the
lifetime of the easy handle. This means that all easy handles added to a multi
handle must use the same closesocket function. *)
val set_closesocket_function : mt -> (Unix.file_descr -> unit) -> unit

(** perform pending data transfers (if any) on all handles currently in multi stack
(not recommended, {!action} should be used instead)
@return the number of handles that still transfer data
Expand Down

0 comments on commit e93e384

Please sign in to comment.