Skip to content

Commit

Permalink
Add onEnter and onExit events to states
Browse files Browse the repository at this point in the history
  • Loading branch information
Gray-Wind committed Mar 8, 2022
1 parent d8c86b9 commit 578b1c9
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 22 deletions.
65 changes: 57 additions & 8 deletions Swift/Sources/StateMachine/StateMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,33 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
private let states: States
private var observers: [Observer] = []

private typealias EnterExitAction = (State) throws -> Void

private var onEnterActions: [State.HashableIdentifier: EnterExitAction]
private var onExitActions: [State.HashableIdentifier: EnterExitAction]

private var isNotifying: Bool = false

public init(@DefinitionBuilder build: () -> Definition) {
let definition: Definition = build()
state = definition.initialState.state
states = definition.states.reduce(into: States()) {
$0[$1.state] = $1.events.reduce(into: Events()) {
$0[$1.event] = $1.action
var enterActions: [State.HashableIdentifier: EnterExitAction] = [:]
var exitActions: [State.HashableIdentifier: EnterExitAction] = [:]
states = definition.states.reduce(into: States()) { result, tuple in
let (state, events) = tuple
result[state] = events.reduce(into: Events()) {
switch $1.eventType {
case .onEnter(let action):
enterActions[state] = action
case .onExit(let action):
exitActions[state] = action
case .normal(let event, let action):
$0[event] = action
}
}
}
onEnterActions = enterActions
onExitActions = exitActions
observers = definition.callbacks.map {
Observer(object: self, callback: $0)
}
Expand Down Expand Up @@ -104,10 +121,18 @@ open class StateMachine<State: StateMachineHashable, Event: StateMachineHashable
event: event,
toState: action.toState ?? state,
sideEffects: action.sideEffects)
let fromState = state
if let toState: State = action.toState {
state = toState
}

result = .success(transition)

// if not `dontTransition`
if action.toState != nil {
try? onExitActions[stateIdentifier]?(fromState)
try? onEnterActions[state.hashableIdentifier]?(state)
}
} else {
result = .failure(Transition.Invalid())
}
Expand Down Expand Up @@ -172,25 +197,41 @@ extension StateMachineBuilder {
.state(state: state, events: build())
}

public static func onEnter(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onEnter(perform))]
}

public static func onExit(_ perform: @escaping (State) throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onExit(perform))]
}

public static func onEnter(_ perform: @escaping () throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onEnter({ _ in try perform() }))]
}

public static func onExit(_ perform: @escaping () throws -> Void) -> [EventHandler] {
[EventHandler(eventType: .onExit({ _ in try perform() }))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping (State, Event) throws -> Action
) -> [EventHandler] {
[EventHandler(event: event, action: perform)]
[EventHandler(eventType: .normal(event, perform))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping (State) throws -> Action
) -> [EventHandler] {
[EventHandler(event: event) { state, _ in try perform(state) }]
[EventHandler(eventType: .normal(event, { state, _ in try perform(state) }))]
}

public static func on(
_ event: Event.HashableIdentifier,
perform: @escaping () throws -> Action
) -> [EventHandler] {
[EventHandler(event: event) { _, _ in try perform() }]
[EventHandler(eventType: .normal(event, { _, _ in try perform() }))]
}

public static func transition(
Expand Down Expand Up @@ -277,8 +318,16 @@ public enum StateMachineTypes {

public struct EventHandler<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate let event: Event.HashableIdentifier
fileprivate let action: Action<State, Event, SideEffect>.Factory
fileprivate var eventType: EventType<State, Event, SideEffect>

fileprivate enum EventType<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {

fileprivate typealias EnterExitAction = (State) throws -> Void

case normal(Event.HashableIdentifier, Action<State, Event, SideEffect>.Factory)
case onEnter(EnterExitAction)
case onExit(EnterExitAction)
}
}

public struct Action<State: StateMachineHashable, Event: StateMachineHashable, SideEffect> {
Expand Down
10 changes: 10 additions & 0 deletions Swift/Tests/StateMachineTests/StateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,13 @@ func log(_ expectedMessages: String...) -> Predicate<Logger> {
return PredicateResult(bool: actualMessages == expectedMessages, message: message)
}
}

func noLog() -> Predicate<Logger> {
return Predicate {
let actualMessages: [String]? = try $0.evaluate()?.messages
let actualString: String = stringify(actualMessages?.joined(separator: "\\n"))
let message: ExpectationMessage = .expectedCustomValueTo("no logs",
actual: "<\(actualString)>")
return PredicateResult(bool: actualString.count == 0, message: message)
}
}
52 changes: 38 additions & 14 deletions Swift/Tests/StateMachineTests/StateMachine_Matter_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,41 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
typealias ValidTransition = MatterStateMachine.Transition.Valid
typealias InvalidTransition = MatterStateMachine.Transition.Invalid

enum Message {

static let melted: String = "I melted"
static let frozen: String = "I froze"
static let vaporized: String = "I vaporized"
static let condensed: String = "I condensed"
enum Message: String {

case melted = "I melted"
case frozen = "I froze"
case vaporized = "I vaporized"
case condensed = "I condensed"
case enteredSolid
case exitedSolid
case enteredLiquid
case exitedLiquid
case enteredGas
case exitedGas
}

static func matterStateMachine(withInitialState _state: State, logger: Logger) -> MatterStateMachine {
MatterStateMachine {
initialState(_state)
state(.solid) {
onEnter { _ in
logger.log(Message.enteredSolid.rawValue)
}
onExit { _ in
logger.log(Message.exitedSolid.rawValue)
}
on(.melt) {
transition(to: .liquid, emit: .logMelted)
}
}
state(.liquid) {
onEnter { _ in
logger.log(Message.enteredLiquid.rawValue)
}
onExit { _ in
logger.log(Message.exitedLiquid.rawValue)
}
on(.freeze) {
transition(to: .solid, emit: .logFrozen)
}
Expand All @@ -53,6 +71,12 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
}
}
state(.gas) {
onEnter { _ in
logger.log(Message.enteredGas.rawValue)
}
onExit { _ in
logger.log(Message.exitedGas.rawValue)
}
on(.condense) {
transition(to: .liquid, emit: .logCondensed)
}
Expand All @@ -61,10 +85,10 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
guard case let .success(transition) = $0 else { return }
transition.sideEffects.forEach { sideEffect in
switch sideEffect {
case .logMelted: logger.log(Message.melted)
case .logFrozen: logger.log(Message.frozen)
case .logVaporized: logger.log(Message.vaporized)
case .logCondensed: logger.log(Message.condensed)
case .logMelted: logger.log(Message.melted.rawValue)
case .logFrozen: logger.log(Message.frozen.rawValue)
case .logVaporized: logger.log(Message.vaporized.rawValue)
case .logCondensed: logger.log(Message.condensed.rawValue)
}
}
}
Expand Down Expand Up @@ -103,7 +127,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .melt,
toState: .liquid,
sideEffects: [.logMelted])))
expect(self.logger).to(log(Message.melted))
expect(self.logger).to(log(Message.exitedSolid.rawValue, Message.enteredLiquid.rawValue, Message.melted.rawValue))
}

func test_givenStateIsSolid_whenFrozen_shouldThrowInvalidTransitionError() throws {
Expand Down Expand Up @@ -136,7 +160,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .freeze,
toState: .solid,
sideEffects: [.logFrozen])))
expect(self.logger).to(log(Message.frozen))
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredSolid.rawValue, Message.frozen.rawValue))
}

func test_givenStateIsLiquid_whenVaporized_shouldTransitionToGasState() throws {
Expand All @@ -153,7 +177,7 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .vaporize,
toState: .gas,
sideEffects: [.logVaporized])))
expect(self.logger).to(log(Message.vaporized))
expect(self.logger).to(log(Message.exitedLiquid.rawValue, Message.enteredGas.rawValue, Message.vaporized.rawValue))
}

func test_givenStateIsGas_whenCondensed_shouldTransitionToLiquidState() throws {
Expand All @@ -170,6 +194,6 @@ final class StateMachine_Matter_Tests: XCTestCase, StateMachineBuilder {
event: .condense,
toState: .liquid,
sideEffects: [.logCondensed])))
expect(self.logger).to(log(Message.condensed))
expect(self.logger).to(log(Message.exitedGas.rawValue, Message.enteredLiquid.rawValue, Message.condensed.rawValue))
}
}
34 changes: 34 additions & 0 deletions Swift/Tests/StateMachineTests/StateMachine_Turnstile_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,25 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
typealias TurnstileStateMachine = StateMachine<State, Event, SideEffect>
typealias ValidTransition = TurnstileStateMachine.Transition.Valid

enum Message: String {
case enteredLocked
case exitedLocked
case enteredUnlocked
case exitedUnlocked
case enteredBroken
case exitedBroken
}

static func turnstileStateMachine(withInitialState _state: State, logger: Logger) -> TurnstileStateMachine {
TurnstileStateMachine {
initialState(_state)
state(.locked) {
onEnter { state in
logger.log("\(Message.enteredLocked.rawValue) \(try state.credit() as Int)")
}
onExit {
logger.log(Message.exitedLocked.rawValue)
}
on(.insertCoin) { locked, insertCoin in
let newCredit: Int = try locked.credit() + insertCoin.value()
if newCredit >= Constant.farePrice {
Expand All @@ -52,11 +67,23 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
}
}
state(.unlocked) {
onEnter {
logger.log(Message.enteredUnlocked.rawValue)
}
onExit {
logger.log(Message.exitedUnlocked.rawValue)
}
on(.admitPerson) {
transition(to: .locked(credit: 0), emit: .closeDoors)
}
}
state(.broken) {
onEnter {
logger.log(Message.enteredBroken.rawValue)
}
onExit {
logger.log(Message.exitedBroken.rawValue)
}
on(.machineRepairDidComplete) { broken in
transition(to: try broken.oldState())
}
Expand Down Expand Up @@ -96,6 +123,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .insertCoin(10),
toState: .locked(credit: 10),
sideEffects: [])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, "\(Message.enteredLocked.rawValue) 10"))
}

func test_givenStateIsLocked_whenInsertCoin_andCreditEqualsFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws {
Expand All @@ -112,6 +140,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .insertCoin(15),
toState: .unlocked,
sideEffects: [.openDoors])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue))
}

func test_givenStateIsLocked_whenInsertCoin_andCreditMoreThanFarePrice_shouldTransitionToUnlockedStateAndOpenDoors() throws {
Expand All @@ -128,6 +157,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .insertCoin(20),
toState: .unlocked,
sideEffects: [.openDoors])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredUnlocked.rawValue))
}

func test_givenStateIsLocked_whenAdmitPerson_shouldTransitionToLockedStateAndSoundAlarm() throws {
Expand All @@ -144,6 +174,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .admitPerson,
toState: .locked(credit: 35),
sideEffects: [.soundAlarm])))
expect(self.logger).to(noLog())
}

func test_givenStateIsLocked_whenMachineDidFail_shouldTransitionToBrokenStateAndOrderRepair() throws {
Expand All @@ -160,6 +191,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .machineDidFail,
toState: .broken(oldState: .locked(credit: 15)),
sideEffects: [.orderRepair])))
expect(self.logger).to(log(Message.exitedLocked.rawValue, Message.enteredBroken.rawValue))
}

func test_givenStateIsUnlocked_whenAdmitPerson_shouldTransitionToLockedStateAndCloseDoors() throws {
Expand All @@ -176,6 +208,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .admitPerson,
toState: .locked(credit: 0),
sideEffects: [.closeDoors])))
expect(self.logger).to(log(Message.exitedUnlocked.rawValue, "\(Message.enteredLocked.rawValue) 0"))
}

func test_givenStateIsBroken_whenMachineRepairDidComplete_shouldTransitionToLockedState() throws {
Expand All @@ -192,6 +225,7 @@ final class StateMachine_Turnstile_Tests: XCTestCase, StateMachineBuilder {
event: .machineRepairDidComplete,
toState: .locked(credit: 15),
sideEffects: [])))
expect(self.logger).to(log(Message.exitedBroken.rawValue, "\(Message.enteredLocked.rawValue) 15"))
}
}

Expand Down

0 comments on commit 578b1c9

Please sign in to comment.