Skip to content

Commit

Permalink
Modified nim-lang#3472 to make its API more idiomatic.
Browse files Browse the repository at this point in the history
  • Loading branch information
dom96 committed Jun 3, 2016
1 parent c170646 commit 5390c25
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 35 deletions.
26 changes: 19 additions & 7 deletions examples/ssl/extradata.nim
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# Stores extra data inside the SSL context.
import net

let ctx = newContext()

# Our unique index for storing foos
let fooIndex = getSslContextExtraDataIndex()
let fooIndex = ctx.getExtraDataIndex()
# And another unique index for storing foos
let barIndex = getSslContextExtraDataIndex()
let barIndex = ctx.getExtraDataIndex()
echo "got indexes ", fooIndex, " ", barIndex

let ctx = newContext()
assert ctx.getExtraData(fooIndex) == nil
let foo: int = 5
ctx.setExtraData(fooIndex, cast[pointer](foo))
assert cast[int](ctx.getExtraData(fooIndex)) == foo
try:
discard ctx.getExtraData(fooIndex)
assert false
except IndexError:
echo("Success")

type
FooRef = ref object of RootRef
foo: int

let foo = FooRef(foo: 5)
ctx.setExtraData(fooIndex, foo)
doAssert ctx.getExtraData(fooIndex).FooRef == foo

ctx.destroyContext()
80 changes: 52 additions & 28 deletions lib/pure/net.nim
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
##

{.deadCodeElim: on.}
import nativesockets, os, strutils, parseutils, times
import nativesockets, os, strutils, parseutils, times, sets
export Port, `$`, `==`
export Domain, SockType, Protocol

Expand All @@ -88,7 +88,10 @@ when defineSsl:
SslProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23

SslContext* = distinct SslCtx
SslContext* = ref object
context: SslCtx
extraInternalIndex: int
referencedData: HashSet[int]

SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
Expand Down Expand Up @@ -229,9 +232,10 @@ when defineSsl:
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()

type SslContextExtraInternal = ref object
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc
type
SslContextExtraInternal = ref object of RootRef
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc

proc raiseSSLError*(s = "") =
## Raises a new SSL error.
Expand All @@ -245,21 +249,33 @@ when defineSsl:
var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr)

proc getSslContextExtraDataIndex*(): cint =
proc getExtraDataIndex*(ctx: SSLContext): int =
## Retrieves unique index for storing extra data in SSLContext.
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
if result < 0:
raiseSSLError()

proc getExtraData*(ctx: SSLContext, index: int): RootRef =
## Retrieves arbitrary data stored inside SSLContext.
if index notin ctx.referencedData:
raise newException(IndexError, "No data with that index.")
let res = ctx.context.SSL_CTX_get_ex_data(index.cint)
if cast[int](res) == 0:
raiseSSLError()
return cast[RootRef](res)

proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) =
## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex.
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
raiseSSLError()
if index in ctx.referencedData:
GC_unref(getExtraData(ctx, index))

proc getExtraData*(ctx: SSLContext, index: cint): pointer =
## Retrieves arbitrary data stored inside SSLContext.
return SslCtx(ctx).SSL_CTX_get_ex_data(index)
if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1:
raiseSSLError()

let extraInternalIndex = getSslContextExtraDataIndex()
if index notin ctx.referencedData:
ctx.referencedData.incl(index)
GC_ref(data)

# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
Expand Down Expand Up @@ -323,34 +339,41 @@ when defineSsl:
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile)

result = SSLContext(newCTX)
result = SSLContext(context: newCTX, extraInternalIndex: 0,
referencedData: initSet[int]())
result.extraInternalIndex = getExtraDataIndex(result)
# The PSK callback functions assume the internal index is 0.
assert result.extraInternalIndex == 0

let extraInternal = new(SslContextExtraInternal)
GC_ref(extraInternal)
result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
result.setExtraData(result.extraInternalIndex, extraInternal)

proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))

proc destroyContext*(ctx: SSLContext) =
## Free memory referenced by SSLContext.
let extraInternal = ctx.getExtraInternal()
if extraInternal != nil:
GC_unref(extraInternal)
SSLCTX(ctx).SSL_CTX_free()

# We assume here that OpenSSL's internal indexes increase by 1 each time.
# That means we can assume that the next internal index is the length of
# extra data indexes.
for i in ctx.referencedData:
GC_unref(getExtraData(ctx, i).RootRef)
ctx.context.SSL_CTX_free()

proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
## Sets the identity hint passed to server.
##
## Only used in PSK ciphersuites.
if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0:
raiseSSLError()

proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
return ctx.getExtraInternal().clientGetPskFunc

proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len:
Expand All @@ -369,13 +392,14 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
ctx.context.SSL_CTX_set_psk_client_callback(
if fun == nil: nil else: pskClientCallback)

proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
return ctx.getExtraInternal().serverGetPskFunc

proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len:
return 0
Expand All @@ -388,7 +412,7 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().serverGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil
else: pskServerCallback)

proc getPskIdentity*(socket: Socket): string =
Expand All @@ -409,7 +433,7 @@ when defineSsl:
assert (not socket.isSSL)
socket.isSSL = true
socket.sslContext = ctx
socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
socket.sslHandle = SSLNew(socket.sslContext.context)
socket.sslNoHandshake = false
socket.sslHasPeekChar = false
if socket.sslHandle == nil:
Expand Down

0 comments on commit 5390c25

Please sign in to comment.