Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add onEnter and onExit events to states #50

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 65 additions & 18 deletions Swift/Sources/StateMachine/StateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
public struct Valid: CustomDebugStringConvertible {

public var debugDescription: String {
guard let sideEffect: SideEffect = sideEffect
else { return "fromState: \(fromState), event: \(event), toState: \(toState), sideEffect: nil" }
return "fromState: \(fromState), event: \(event), toState: \(toState), sideEffect: \(sideEffect)"
return "fromState: \(fromState), event: \(event), toState: \(toState), sideEffects: \(sideEffects)"
}

public let fromState: State
public let event: Event
public let toState: State
public let sideEffect: SideEffect?
public let sideEffects: [SideEffect]
}

public struct Invalid: Error, Equatable {}
Expand Down Expand Up @@ -56,16 +54,33 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
private let states: States
private var observers: [Observer] = []

private typealias EnterExitAction = (State) throws -> Void

private var onEnterActions: [State.HashableIdentifier: EnterExitAction]
private var onExitActions: [State.HashableIdentifier: EnterExitAction]

private var isNotifying: Bool = false

public init(@DefinitionBuilder build: () -> Definition) {
let definition: Definition = build()
state = definition.initialState.state
states = definition.states.reduce(into: States()) {
$0[$1.state] = $1.events.reduce(into: Events()) {
$0[$1.event] = $1.action
var enterActions: [State.HashableIdentifier: EnterExitAction] = [:]
var exitActions: [State.HashableIdentifier: EnterExitAction] = [:]
states = definition.states.reduce(into: States()) { result, tuple in
let (state, events) = tuple
result[state] = events.reduce(into: Events()) {
switch $1.eventType {
case .onEnter(let action):
enterActions[state] = action
case .onExit(let action):
exitActions[state] = action
case .normal(let event, let action):
$0[event] = action
}
}
}
onEnterActions = enterActions
onExitActions = exitActions
observers = definition.callbacks.map {
Observer(object: self, callback: $0)
}
Expand Down Expand Up @@ -105,11 +120,19 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
let transition: Transition.Valid = .init(fromState: state,
event: event,
toState: action.toState ?? state,
sideEffect: action.sideEffect)
sideEffects: action.sideEffects)
let fromState = state
if let toState: State = action.toState {
state = toState
}

result = .success(transition)

// if not `dontTransition`
if action.toState != nil {
try? onExitActions[stateIdentifier]?(fromState)
try? onEnterActions[state.hashableIdentifier]?(state)
}
} else {
result = .failure(Transition.Invalid())
}
Expand Down Expand Up @@ -174,38 +197,54 @@ extension StateMachineBuilder {
.state(state: state, events: build())
}

public static func onEnter(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onEnter(perform))]
}

public static func onExit(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onExit(perform))]
}

public static func onEnter(_ perform: @escaping () throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onEnter({ _ in try perform() }))]
}

public static func onExit(_ perform: @escaping () throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onExit({ _ in try perform() }))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping (State, Event) throws -> Action
) -> [EventHandler] {
[EventHandler(event: event, action: perform)]
[EventHandler(eventType: .normal(event, perform))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping (State) throws -> Action
) -> [EventHandler] {
[EventHandler(event: event) { state, _ in try perform(state) }]
[EventHandler(eventType: .normal(event, { state, _ in try perform(state) }))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping () throws -> Action
) -> [EventHandler] {
[EventHandler(event: event) { _, _ in try perform() }]
[EventHandler(eventType: .normal(event, { _, _ in try perform() }))]
}

public static func transition(
to state: State,
emit sideEffect: SideEffect? = nil
emit sideEffect: SideEffect...
) -> Action {
Action(toState: state, sideEffect: sideEffect)
Action(toState: state, sideEffects: sideEffect)
}

public static func dontTransition(
emit sideEffect: SideEffect? = nil
emit sideEffect: SideEffect...
) -> Action {
Action(toState: nil, sideEffect: sideEffect)
Action(toState: nil, sideEffects: sideEffect)
}

public static func onTransition(
Expand Down Expand Up @@ -279,16 +318,24 @@ public enum StateMachineTypes {

public struct EventHandler<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate let event: Event.HashableIdentifier
fileprivate let action: Action<State, Event, SideEffect>.Factory
fileprivate var eventType: EventType<State, Event, SideEffect>

fileprivate enum EventType<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate typealias EnterExitAction = (State) throws -> Void

case normal(Event.HashableIdentifier, Action<State, Event, SideEffect>.Factory)
case onEnter(EnterExitAction)
case onExit(EnterExitAction)
}
}

public struct Action<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate typealias Factory = (State, Event) throws -> Self

fileprivate let toState: State?
fileprivate let sideEffect: SideEffect?
fileprivate let sideEffects: [SideEffect]
}

public struct IncorrectTypeError: Error, CustomDebugStringConvertible {
Expand Down
42 changes: 32 additions & 10 deletions Swift/Tests/StateMachineTests/StateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {

enum State: StateMachineHashable {

case stateOne, stateTwo
case stateOne, stateTwo, stateThree
}

enum Event: StateMachineHashable {

case eventOne, eventTwo
case eventOne, eventTwo, eventThree
}

enum SideEffect {
enum SideEffect: Equatable {

case commandOne, commandTwo, commandThree
case commandOne, commandTwo, commandThree, commandFour(Int)
}

typealias TestStateMachine = StateMachine<State, Event, SideEffect>
Expand All @@ -33,7 +33,7 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {
initialState(_state)
state(.stateOne) {
on(.eventOne) {
dontTransition(emit: .commandOne)
dontTransition(emit: .commandOne, .commandTwo)
}
on(.eventTwo) {
transition(to: .stateTwo, emit: .commandTwo)
Expand All @@ -43,7 +43,11 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {
on(.eventTwo) {
dontTransition(emit: .commandThree)
}
on(.eventThree) { _, event in
transition(to: .stateThree, emit: .commandFour(try event.string()))
}
}
state(.stateThree)
}
}

Expand All @@ -66,7 +70,7 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {
expect(transition).to(equal(ValidTransition(fromState: .stateOne,
event: .eventOne,
toState: .stateOne,
sideEffect: .commandOne)))
sideEffects: [.commandOne, .commandTwo])))
}

func testTransition() throws {
Expand All @@ -82,7 +86,7 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {
expect(transition).to(equal(ValidTransition(fromState: .stateOne,
event: .eventTwo,
toState: .stateTwo,
sideEffect: .commandTwo)))
sideEffects: [.commandTwo])))
}

func testInvalidTransition() throws {
Expand Down Expand Up @@ -131,16 +135,16 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {
.success(ValidTransition(fromState: .stateOne,
event: .eventOne,
toState: .stateOne,
sideEffect: .commandOne)),
sideEffects: [.commandOne, .commandTwo])),
.success(ValidTransition(fromState: .stateOne,
event: .eventTwo,
toState: .stateTwo,
sideEffect: .commandTwo)),
sideEffects: [.commandTwo])),
.failure(InvalidTransition()),
.success(ValidTransition(fromState: .stateTwo,
event: .eventTwo,
toState: .stateTwo,
sideEffect: .commandThree))
sideEffects: [.commandThree]))
]))
}

Expand Down Expand Up @@ -191,6 +195,14 @@ final class StateMachineTests: XCTestCase, StateMachineBuilder {
// Then
expect(error).to(equal(.recursionDetected))
}

func testGettingNonExistingValue() throws {
// Given
let stateMachine: TestStateMachine = givenState(is: .stateTwo)

// Then
XCTAssertThrowsError(try stateMachine.transition(.eventThree))
}
}

final class Logger {
Expand All @@ -212,3 +224,13 @@ func log(_ expectedMessages: String...) -> Predicate<Logger> {
return PredicateResult(bool: actualMessages == expectedMessages, message: message)
}
}

func noLog() -> Predicate<Logger> {
return Predicate {
let actualMessages: [String]? = try $0.evaluate()?.messages
let actualString: String = stringify(actualMessages?.joined(separator: "\\n"))
let message: ExpectationMessage = .expectedCustomValueTo("no logs",
actual: "<\(actualString)>")
return PredicateResult(bool: actualString.count == 0, message: message)
}
}
66 changes: 46 additions & 20 deletions Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,41 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
typealias ValidTransition = MatterStateMachine.Transition.Valid
typealias InvalidTransition = MatterStateMachine.Transition.Invalid

enum Message {

static let melted: String = "I melted"
static let frozen: String = "I froze"
static let vaporized: String = "I vaporized"
static let condensed: String = "I condensed"
enum Message: String {

case melted = "I melted"
case frozen = "I froze"
case vaporized = "I vaporized"
case condensed = "I condensed"
case enteredSolid
case exitedSolid
case enteredLiquid
case exitedLiquid
case enteredGas
case exitedGas
}

static func matterStateMachine(withInitialState _state: State, logger: Logger) -> MatterStateMachine {
MatterStateMachine {
initialState(_state)
state(.solid) {
onEnter { _ in
logger.log(Message.enteredSolid.rawValue)
}
onExit { _ in
logger.log(Message.exitedSolid.rawValue)
}
on(.melt) {
transition(to: .liquid, emit: .logMelted)
}
}
state(.liquid) {
onEnter { _ in
logger.log(Message.enteredLiquid.rawValue)
}
onExit { _ in
logger.log(Message.exitedLiquid.rawValue)
}
on(.freeze) {
transition(to: .solid, emit: .logFrozen)
}
Expand All @@ -53,17 +71,25 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
}
}
state(.gas) {
onEnter { _ in
logger.log(Message.enteredGas.rawValue)
}
onExit { _ in
logger.log(Message.exitedGas.rawValue)
}
on(.condense) {
transition(to: .liquid, emit: .logCondensed)
}
}
onTransition {
guard case let .success(transition) = $0, let sideEffect = transition.sideEffect else { return }
switch sideEffect {
case .logMelted: logger.log(Message.melted)
case .logFrozen: logger.log(Message.frozen)
case .logVaporized: logger.log(Message.vaporized)
case .logCondensed: logger.log(Message.condensed)
guard case let .success(transition) = $0 else { return }
transition.sideEffects.forEach { sideEffect in
switch sideEffect {
case .logMelted: logger.log(Message.melted.rawValue)
case .logFrozen: logger.log(Message.frozen.rawValue)
case .logVaporized: logger.log(Message.vaporized.rawValue)
case .logCondensed: logger.log(Message.condensed.rawValue)
}
}
}
}
Expand Down Expand Up @@ -100,8 +126,8 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
expect(transition).to(equal(ValidTransition(fromState: .solid,
event: .melt,
toState: .liquid,
sideEffect: .logMelted)))
expect(self.logger).to(log(Message.melted))
sideEffects: [.logMelted])))
expect(self.logger).to(log(Message.exitedSolid.rawValue, Message.enteredLiquid.rawValue, Message.melted.rawValue))
}

func test_givenStateIsSolid_whenFrozen_shouldThrowInvalidTransitionError() throws {
Expand Down Expand Up @@ -133,8 +159,8 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
expect(transition).to(equal(ValidTransition(fromState: .liquid,
event: .freeze,
toState: .solid,
sideEffect: .logFrozen)))
expect(self.logger).to(log(Message.frozen))
sideEffects: [.logFrozen])))
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredSolid.rawValue, Message.frozen.rawValue))
}

func test_givenStateIsLiquid_whenVaporized_shouldTransitionToGasState() throws {
Expand All @@ -150,8 +176,8 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
expect(transition).to(equal(ValidTransition(fromState: .liquid,
event: .vaporize,
toState: .gas,
sideEffect: .logVaporized)))
expect(self.logger).to(log(Message.vaporized))
sideEffects: [.logVaporized])))
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredGas.rawValue, Message.vaporized.rawValue))
}

func test_givenStateIsGas_whenCondensed_shouldTransitionToLiquidState() throws {
Expand All @@ -167,7 +193,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
expect(transition).to(equal(ValidTransition(fromState: .gas,
event: .condense,
toState: .liquid,
sideEffect: .logCondensed)))
expect(self.logger).to(log(Message.condensed))
sideEffects: [.logCondensed])))
expect(self.logger).to(log(Message.exitedGas.rawValue, Message.enteredLiquid.rawValue, Message.condensed.rawValue))
}
}
Loading