Skip to content

Commit

Permalink
Fixes #13 Fix threading issue in local inject (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
ncipollo authored Oct 17, 2024
1 parent b511d50 commit 09105e3
Show file tree
Hide file tree
Showing 15 changed files with 350 additions and 119 deletions.
61 changes: 35 additions & 26 deletions Sources/WhoopDIKit/Container/Container.swift
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import Foundation
public final class Container {
private let localDependencyGraph: ThreadSafeDependencyGraph
private var isLocalInjectActive: Bool = false
private let options: WhoopDIOptionProvider

private let serviceDict = ServiceDictionary<DependencyDefinition>()
private var localServiceDict: ServiceDictionary<DependencyDefinition>? = nil

public init() {}
public init(options: WhoopDIOptionProvider = defaultWhoopDIOptions()) {
self.options = options
localDependencyGraph = ThreadSafeDependencyGraph(options: options)
}

/// Registers a list of modules with the DI system.
/// Typically you will create a `DependencyModule` for your feature, then add it to the module list provided to this method.
Expand Down Expand Up @@ -54,28 +60,29 @@ public final class Container {
public func inject<T>(_ name: String? = nil,
params: Any? = nil,
_ localDefinition: (DependencyModule) -> Void) -> T {
guard localServiceDict == nil else {
fatalError("Nesting WhoopDI.inject with local definitions is not currently supported")
}
// We need to maintain a reference to the local service dictionary because transient dependencies may also
// need to reference dependencies from it.
// ----
// This is a little dangerous since we are mutating a static variable but it should be fine as long as you
// don't use `inject { }` within the scope of another `inject { }`.
let serviceDict = ServiceDictionary<DependencyDefinition>()
localServiceDict = serviceDict
defer {
localServiceDict = nil
}

let localModule = DependencyModule()
localDefinition(localModule)
localModule.addToServiceDictionary(serviceDict: serviceDict)

do {
return try get(name, params)
} catch {
fatalError("WhoopDI inject failed with error: \(error)")
return localDependencyGraph.acquireDependencyGraph { localServiceDict in
// Nested local injects are not currently supported. Fail fast here.
guard !isLocalInjectActive else {
fatalError("Nesting WhoopDI.inject with local definitions is not currently supported")
}

isLocalInjectActive = true
defer {
isLocalInjectActive = false
localDependencyGraph.resetDependencyGraph()
}

let localModule = DependencyModule()
localDefinition(localModule)
localModule.addToServiceDictionary(serviceDict: localServiceDict)

do {
return try get(name, params)
} catch {
print("Inject failed with stack trace:")
Thread.callStackSymbols.forEach { print($0) }
fatalError("WhoopDI inject failed with error: \(error)")
}
}
}

Expand All @@ -89,7 +96,7 @@ public final class Container {
} else if let injectable = T.self as? any Injectable.Type {
return try injectable.inject(container: self) as! T
} else {
throw DependencyError.missingDependecy(ServiceKey(T.self, name: name))
throw DependencyError.missingDependency(ServiceKey(T.self, name: name))
}
}

Expand All @@ -106,7 +113,9 @@ public final class Container {
}

private func getDefinition(_ serviceKey: ServiceKey) -> DependencyDefinition? {
return localServiceDict?[serviceKey] ?? serviceDict[serviceKey]
localDependencyGraph.acquireDependencyGraph { localServiceDict in
return localServiceDict[serviceKey] ?? serviceDict[serviceKey]
}
}

public func removeAllDependencies() {
Expand Down
35 changes: 35 additions & 0 deletions Sources/WhoopDIKit/Container/ThreadSafeDependencyGraph.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import Foundation

final class ThreadSafeDependencyGraph: @unchecked Sendable {
private let lock = NSRecursiveLock()
private let serviceDict: ServiceDictionary<DependencyDefinition> = .init()
private let options: WhoopDIOptionProvider

init(options: WhoopDIOptionProvider) {
self.options = options
}

func acquireDependencyGraph<T>(block: (ServiceDictionary<DependencyDefinition>) -> T) -> T {
let threadSafe = options.isOptionEnabled(.threadSafeLocalInject)
if threadSafe {
lock.lock()
}
let result = block(serviceDict)
if threadSafe {
lock.unlock()
}
return result
}

func resetDependencyGraph() {
let threadSafe = options.isOptionEnabled(.threadSafeLocalInject)
if threadSafe {
lock.lock()
}
serviceDict.removeAll()
if threadSafe {
lock.unlock()
}
}

}
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
enum DependencyError: Error, CustomStringConvertible, Equatable {
case badParams(ServiceKey)
case missingDependecy(ServiceKey)
case missingDependency(ServiceKey)
case nilDependency(ServiceKey)

var description: String {
switch self {
case .badParams(let serviceKey):
return "Bad parameters provided for \(serviceKey.type) with name: \(serviceKey.name ?? "<no name>")"
case .missingDependecy(let serviceKey):
case .missingDependency(let serviceKey):
return "Missing dependency for \(serviceKey.type) with name: \(serviceKey.name ?? "<no name>")"
case .nilDependency(let serviceKey):
return "Nil dependency for \(serviceKey.type) with name: \(serviceKey.name ?? "<no name>")"
Expand Down
9 changes: 9 additions & 0 deletions Sources/WhoopDIKit/Options/DefaultOptionProvider.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
struct DefaultOptionProvider: WhoopDIOptionProvider {
func isOptionEnabled(_ option: WhoopDIOption) -> Bool {
false
}
}

public func defaultWhoopDIOptions() -> WhoopDIOptionProvider {
DefaultOptionProvider()
}
4 changes: 4 additions & 0 deletions Sources/WhoopDIKit/Options/WhoopDIOption.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/// Options for WhoopDI. These are typically experimental features which may be enabled or disabled.
public enum WhoopDIOption: Sendable {
case threadSafeLocalInject
}
4 changes: 4 additions & 0 deletions Sources/WhoopDIKit/Options/WhoopDIOptionProvider.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
/// Implement this protocol and pass it into WhoopDI via `WhoopDI.setOptions` to enable and disable various options for WhoopDI.
public protocol WhoopDIOptionProvider: Sendable {
func isOptionEnabled(_ option: WhoopDIOption) -> Bool
}
23 changes: 15 additions & 8 deletions Sources/WhoopDIKit/WhoopDI.swift
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
import Foundation
public final class WhoopDI: DependencyRegister {
nonisolated(unsafe) private static let appContainer = Container()

nonisolated(unsafe) private static var appContainer = Container()

/// Setup WhoopDI with the supplied options.
/// This should only be called once when your application launches (and before WhoopDI is used).
/// By default all options are disabled if you do not call this method.
public static func setup(options: WhoopDIOptionProvider) {
appContainer = Container(options: options)
}

/// Registers a list of modules with the DI system.
/// Typically you will create a `DependencyModule` for your feature, then add it to the module list provided to this method.
public static func registerModules(modules: [DependencyModule]) {
appContainer.registerModules(modules: modules)
}

/// Injects a dependecy into your code.
/// Injects a dependency into your code.
///
/// The injected dependecy will have all of it's sub-dependencies provided by the object graph defined in WhoopDI.
/// The injected dependency will have all of it's sub-dependencies provided by the object graph defined in WhoopDI.
/// Typically this should be called from your top level UI object (ViewController, etc). Intermediate components should rely upon constructor injection (i.e providing dependencies via the constructor)
public static func inject<T>(_ name: String? = nil, _ params: Any? = nil) -> T {
appContainer.inject(name, params)
}

/// Injects a dependency into your code, overlaying local dependencies on top of the object graph.
///
/// The injected dependecy will have all of it's sub-dependencies provided by the object graph defined in WhoopDI.
/// The injected dependency will have all of it's sub-dependencies provided by the object graph defined in WhoopDI.
/// Typically this should be called from your top level UI object (ViewController, etc). Intermediate components should rely
/// upon constructor injection (i.e providing dependencies via the constructor).
///
Expand All @@ -36,12 +43,12 @@ public final class WhoopDI: DependencyRegister {
/// - name: An optional name for the dependency. This can help disambiguate between dependencies of the same type.
/// - params: Optional parameters which will be provided to dependencies which require them (i.e dependencies using defintiions such as
/// (factoryWithParams, etc).
/// - localDefiniton: A local module definition which can be used to supply local dependencies to the object graph prior to injection.
/// - localDefinition: A local module definition which can be used to supply local dependencies to the object graph prior to injection.
/// - Returns: The requested dependency.
public static func inject<T>(_ name: String? = nil,
params: Any? = nil,
_ localDefiniton: (DependencyModule) -> Void) -> T {
appContainer.inject(name, params: params, localDefiniton)
_ localDefinition: (DependencyModule) -> Void) -> T {
appContainer.inject(name, params: params, localDefinition)
}

/// Used internally by the DependencyModule get to loop up a sub-dependency in the object graph.
Expand Down
73 changes: 54 additions & 19 deletions Tests/WhoopDIKitTests/Container/ContainerTests.swift
Original file line number Diff line number Diff line change
@@ -1,60 +1,95 @@
import XCTest
import Testing
@testable import WhoopDIKit

class ContainerTests: XCTestCase {
private let container = Container()
// This is unchecked Sendable so we can run our local inject concurrency test
class ContainerTests: @unchecked Sendable {
private let container: Container

init() {
let options = MockOptionProvider(options: [.threadSafeLocalInject: true])
container = Container(options: options)
}

func test_inject() {
@Test
func inject() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory", "param")
XCTAssertTrue(dependency is DependencyC)
#expect(dependency is DependencyC)
}

func test_inject_generic_integer() {
@Test
func inject_generic_integer() {
container.registerModules(modules: [GoodTestModule()])
let dependency: GenericDependency<Int> = container.inject()
XCTAssertEqual(42, dependency.value)
#expect(42 == dependency.value)
}

func test_inject_generic_string() {
@Test
func inject_generic_string() {
container.registerModules(modules: [GoodTestModule()])
let dependency: GenericDependency<String> = container.inject()
XCTAssertEqual("string", dependency.value)
#expect("string" == dependency.value)
}

func test_inject_localDefinition() {
@Test
func inject_localDefinition() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory") { module in
// Typically you'd override or provide a transient dependency. I'm using the top level dependency here
// for the sake of simplicity.
module.factory(name: "C_Factory") { DependencyA() as Dependency }
}
XCTAssertTrue(dependency is DependencyA)
#expect(dependency is DependencyA)
}

@Test(.bug("https://github.com/WhoopInc/WhoopDI/issues/13"))
func inject_localDefinition_concurrency() async {
container.registerModules(modules: [GoodTestModule()])
// Run many times to try and capture race condition
for _ in 0..<500 {
let taskA = Task.detached {
let _: Dependency = self.container.inject("C_Factory") { module in
module.factory(name: "C_Factory") { DependencyA() as Dependency }
}
}

let taskB = Task.detached {
let _: DependencyA = self.container.inject()
}

for task in [taskA, taskB] {
let _ = await task.result
}
}
}

func test_inject_localDefinition_noOverride() {
@Test
func inject_localDefinition_noOverride() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory", params: "params") { _ in }
XCTAssertTrue(dependency is DependencyC)
#expect(dependency is DependencyC)
}

func test_inject_localDefinition_withParams() {
@Test
func inject_localDefinition_withParams() {
container.registerModules(modules: [GoodTestModule()])
let dependency: Dependency = container.inject("C_Factory", params: "params") { module in
module.factoryWithParams(name: "C_Factory") { params in DependencyB(params) as Dependency }
}
XCTAssertTrue(dependency is DependencyB)
#expect(dependency is DependencyB)
}

func test_injectableWithDependency() throws {
@Test
func injectableWithDependency() throws {
container.registerModules(modules: [FakeTestModuleForInjecting()])
let testInjecting: InjectableWithDependency = container.inject()
XCTAssertEqual(testInjecting, InjectableWithDependency(dependency: DependencyA()))
#expect(testInjecting == InjectableWithDependency(dependency: DependencyA()))
}

func test_injectableWithNamedDependency() throws {
@Test
func injectableWithNamedDependency() throws {
container.registerModules(modules: [FakeTestModuleForInjecting()])
let testInjecting: InjectableWithNamedDependency = container.inject()
XCTAssertEqual(testInjecting, InjectableWithNamedDependency(name: 1))
#expect(testInjecting == InjectableWithNamedDependency(name: 1))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import Testing
@testable import WhoopDIKit

struct ThreadSafeDependencyGraphTests {
@Test(arguments: [false, true])
func acquireDependencyGraph_notThreadSafe(threadsafe: Bool) {
let options = MockOptionProvider(options: [.threadSafeLocalInject: threadsafe])
let graph = ThreadSafeDependencyGraph(options: options)

graph.acquireDependencyGraph { serviceDict in
serviceDict[DependencyA.self] = FactoryDefinition(name: nil) { _ in DependencyA() }
}
graph.acquireDependencyGraph { serviceDict in
let dependency = serviceDict[DependencyA.self]
#expect(dependency != nil)
}

graph.resetDependencyGraph()

graph.acquireDependencyGraph { serviceDict in
let dependency = serviceDict[DependencyA.self]
#expect(dependency == nil)
}
}

@Test
func acquireDependencyGraph_recursive() {
let options = MockOptionProvider(options: [.threadSafeLocalInject: true])
let graph = ThreadSafeDependencyGraph(options: options)

graph.acquireDependencyGraph { outer in
graph.acquireDependencyGraph { serviceDict in
serviceDict[DependencyA.self] = FactoryDefinition(name: nil) { _ in DependencyA() }
}
let dependency = outer[DependencyA.self]
#expect(dependency != nil)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ class DependencyErrorTests: XCTestCase {
XCTAssertEqual(expected, error.description)
}

func test_description_missingDependecy_noServiceKeyName() {
func test_description_missingDependency_noServiceKeyName() {
let expected = "Missing dependency for String with name: <no name>"
let error = DependencyError.missingDependecy(serviceKey)
let error = DependencyError.missingDependency(serviceKey)
XCTAssertEqual(expected, error.description)
}

func test_description_missingDependecy_withServiceKeyName() {
func test_description_missingDependency_withServiceKeyName() {
let expected = "Missing dependency for String with name: name"
let error = DependencyError.missingDependecy(serviceKeyWithName)
let error = DependencyError.missingDependency(serviceKeyWithName)
XCTAssertEqual(expected, error.description)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class DependencyDefinitionTests: XCTestCase {
}

func test_singleton_get_recoversFromThrow() {
let expectedError = DependencyError.missingDependecy(ServiceKey(String.self))
let expectedError = DependencyError.missingDependency(ServiceKey(String.self))
var callCount = 0
let definition = SingletonDefinition(name: nil) { _ -> Int in
callCount += 1
Expand Down
Loading

0 comments on commit 09105e3

Please sign in to comment.