From 5390c25b60e79f87aca339f7428575066b0b2d08 Mon Sep 17 00:00:00 2001 From: Dominik Picheta Date: Fri, 3 Jun 2016 13:22:18 +0100 Subject: [PATCH] Modified #3472 to make its API more idiomatic. --- examples/ssl/extradata.nim | 26 +++++++++---- lib/pure/net.nim | 80 +++++++++++++++++++++++++------------- 2 files changed, 71 insertions(+), 35 deletions(-) diff --git a/examples/ssl/extradata.nim b/examples/ssl/extradata.nim index f86dc57f26d07..1e3b89b02168c 100644 --- a/examples/ssl/extradata.nim +++ b/examples/ssl/extradata.nim @@ -1,14 +1,26 @@ # Stores extra data inside the SSL context. import net +let ctx = newContext() + # Our unique index for storing foos -let fooIndex = getSslContextExtraDataIndex() +let fooIndex = ctx.getExtraDataIndex() # And another unique index for storing foos -let barIndex = getSslContextExtraDataIndex() +let barIndex = ctx.getExtraDataIndex() echo "got indexes ", fooIndex, " ", barIndex -let ctx = newContext() -assert ctx.getExtraData(fooIndex) == nil -let foo: int = 5 -ctx.setExtraData(fooIndex, cast[pointer](foo)) -assert cast[int](ctx.getExtraData(fooIndex)) == foo +try: + discard ctx.getExtraData(fooIndex) + assert false +except IndexError: + echo("Success") + +type + FooRef = ref object of RootRef + foo: int + +let foo = FooRef(foo: 5) +ctx.setExtraData(fooIndex, foo) +doAssert ctx.getExtraData(fooIndex).FooRef == foo + +ctx.destroyContext() diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 85d4245b2762c..d6ec314810cca 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -66,7 +66,7 @@ ## {.deadCodeElim: on.} -import nativesockets, os, strutils, parseutils, times +import nativesockets, os, strutils, parseutils, times, sets export Port, `$`, `==` export Domain, SockType, Protocol @@ -88,7 +88,10 @@ when defineSsl: SslProtVersion* = enum protSSLv2, protSSLv3, protTLSv1, protSSLv23 - SslContext* = distinct SslCtx + SslContext* = ref object + context: SslCtx + extraInternalIndex: int + referencedData: HashSet[int] SslAcceptResult* = enum AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess @@ -229,9 +232,10 @@ when defineSsl: ErrLoadBioStrings() OpenSSL_add_all_algorithms() - type SslContextExtraInternal = ref object - serverGetPskFunc: SslServerGetPskFunc - clientGetPskFunc: SslClientGetPskFunc + type + SslContextExtraInternal = ref object of RootRef + serverGetPskFunc: SslServerGetPskFunc + clientGetPskFunc: SslClientGetPskFunc proc raiseSSLError*(s = "") = ## Raises a new SSL error. @@ -245,21 +249,33 @@ when defineSsl: var errStr = ErrErrorString(err, nil) raise newException(SSLError, $errStr) - proc getSslContextExtraDataIndex*(): cint = + proc getExtraDataIndex*(ctx: SSLContext): int = ## Retrieves unique index for storing extra data in SSLContext. - return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil) + result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int + if result < 0: + raiseSSLError() + + proc getExtraData*(ctx: SSLContext, index: int): RootRef = + ## Retrieves arbitrary data stored inside SSLContext. + if index notin ctx.referencedData: + raise newException(IndexError, "No data with that index.") + let res = ctx.context.SSL_CTX_get_ex_data(index.cint) + if cast[int](res) == 0: + raiseSSLError() + return cast[RootRef](res) - proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) = + proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) = ## Stores arbitrary data inside SSLContext. The unique `index` ## should be retrieved using getSslContextExtraDataIndex. - if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1: - raiseSSLError() + if index in ctx.referencedData: + GC_unref(getExtraData(ctx, index)) - proc getExtraData*(ctx: SSLContext, index: cint): pointer = - ## Retrieves arbitrary data stored inside SSLContext. - return SslCtx(ctx).SSL_CTX_get_ex_data(index) + if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1: + raiseSSLError() - let extraInternalIndex = getSslContextExtraDataIndex() + if index notin ctx.referencedData: + ctx.referencedData.incl(index) + GC_ref(data) # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) = @@ -323,26 +339,33 @@ when defineSsl: discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) newCTX.loadCertificates(certFile, keyFile) - result = SSLContext(newCTX) + result = SSLContext(context: newCTX, extraInternalIndex: 0, + referencedData: initSet[int]()) + result.extraInternalIndex = getExtraDataIndex(result) + # The PSK callback functions assume the internal index is 0. + assert result.extraInternalIndex == 0 + let extraInternal = new(SslContextExtraInternal) - GC_ref(extraInternal) - result.setExtraData(extraInternalIndex, cast[pointer](extraInternal)) + result.setExtraData(result.extraInternalIndex, extraInternal) proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = - return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex)) + return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex)) proc destroyContext*(ctx: SSLContext) = ## Free memory referenced by SSLContext. - let extraInternal = ctx.getExtraInternal() - if extraInternal != nil: - GC_unref(extraInternal) - SSLCTX(ctx).SSL_CTX_free() + + # We assume here that OpenSSL's internal indexes increase by 1 each time. + # That means we can assume that the next internal index is the length of + # extra data indexes. + for i in ctx.referencedData: + GC_unref(getExtraData(ctx, i).RootRef) + ctx.context.SSL_CTX_free() proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = ## Sets the identity hint passed to server. ## ## Only used in PSK ciphersuites. - if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0: + if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0: raiseSSLError() proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc = @@ -350,7 +373,7 @@ when defineSsl: proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; max_psk_len: cuint): cuint {.cdecl.} = - let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) let hintString = if hint == nil: nil else: $hint let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) if psk.len.cuint > max_psk_len: @@ -369,13 +392,14 @@ when defineSsl: ## ## Only used in PSK ciphersuites. ctx.getExtraInternal().clientGetPskFunc = fun - SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback) + ctx.context.SSL_CTX_set_psk_client_callback( + if fun == nil: nil else: pskClientCallback) proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc = return ctx.getExtraInternal().serverGetPskFunc proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = - let ctx = SSLContext(ssl.SSL_get_SSL_CTX) + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) let pskString = (ctx.serverGetPskFunc)($identity) if psk.len.cint > max_psk_len: return 0 @@ -388,7 +412,7 @@ when defineSsl: ## ## Only used in PSK ciphersuites. ctx.getExtraInternal().serverGetPskFunc = fun - SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil + ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil else: pskServerCallback) proc getPskIdentity*(socket: Socket): string = @@ -409,7 +433,7 @@ when defineSsl: assert (not socket.isSSL) socket.isSSL = true socket.sslContext = ctx - socket.sslHandle = SSLNew(SSLCTX(socket.sslContext)) + socket.sslHandle = SSLNew(socket.sslContext.context) socket.sslNoHandshake = false socket.sslHasPeekChar = false if socket.sslHandle == nil: