Skip to content

Commit

Permalink
ivp test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin CO committed Mar 3, 2025
1 parent 198d945 commit 4abc61e
Showing 1 changed file with 96 additions and 114 deletions.
210 changes: 96 additions & 114 deletions tests/shard1/test_ivp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def test_values(result, tested_values):
for i in range(len(tested_index)):
for i in range(len(tested_values)):
np.testing.assert_almost_equal(
result[tested_index[i]],
tested_values[i],
Expand Down Expand Up @@ -308,134 +308,116 @@ def test_hmed2018_ivp(model):
test_values(result["F"][0], tested_values)


ding2003_with_fatigue_model = ModelMaker.create_model(
"ding2003_with_fatigue", stim_time=[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
)


@pytest.mark.parametrize("pulse_mode", ["single", "doublet", "triplet"])
def test_pulse_mode_ivp(pulse_mode):
fes_parameters = {
"model": ding2003_with_fatigue_model,
"pulse_mode": pulse_mode,
}
ivp_parameters = {"final_time": 0.3, "use_sx": True, "ode_solver": OdeSolver.RK4(n_integration_steps=10)}
ivp_parameters = {"final_time": 1, "use_sx": True, "ode_solver": OdeSolver.RK4(n_integration_steps=10)}

ivp = IvpFes(fes_parameters, ivp_parameters)

# Integrating the solution
result = ivp.integrate(return_time=False)

if pulse_mode == "single":
np.testing.assert_almost_equal(
result["F"][0],
np.array(
[
0.0,
15.85435155,
36.35204199,
54.97169718,
70.81575896,
83.41136615,
92.40789376,
97.57896902,
98.90445688,
96.65469613,
91.4120125,
99.09361882,
111.76148442,
123.22551071,
132.4700185,
138.93153923,
142.1847179,
141.94409218,
138.15150421,
131.06678227,
121.2908411,
125.7329513,
135.80039992,
144.91927299,
152.01179863,
156.47959033,
157.87252439,
155.88794042,
150.45844244,
141.84300151,
130.64878849,
]
),
)
tested_values = [
0.0,
91.41201249507682,
121.29084110441761,
130.64878849132887,
133.6325306398771,
134.63501242329733,
135.01838738959748,
135.2065215051987,
135.33181981213994,
135.43580597313056,
135.5315562363452,
]

elif pulse_mode == "doublet":
np.testing.assert_almost_equal(
result["F"][0][150:180],
np.array(
[
124.87618293,
125.18689644,
125.48798444,
125.77941427,
126.06115422,
126.33317358,
126.59544267,
126.8479329,
127.09061681,
127.32346812,
127.54646176,
127.75957395,
127.96278219,
128.15606537,
128.33940376,
128.51277911,
128.67617463,
128.8295751,
128.97296688,
129.10633793,
129.22967793,
129.34297824,
129.44623198,
129.53943411,
129.62258137,
129.69567244,
129.75870789,
129.81169025,
129.85462405,
129.88751587,
]
),
)
tested_values = [
0.0,
6.811659796099999,
17.836098653932126,
29.55736919641,
41.04200577764048,
52.06499584872172,
62.532496833034685,
72.39128730499773,
81.6025536882239,
90.13179437212492,
97.94429264124467,
105.00306253311439,
111.26822359129736,
116.6974534531757,
121.24742211246753,
124.87618292671563,
127.54646176449623,
129.22967792906385,
129.91037433984613,
129.59056707624796,
128.2933969321963,
130.2820796006739,
135.62676745353022,
141.89351264990967,
148.18010795280892,
154.24004557311832,
159.96083428619735,
165.27342115086742,
170.12491198575734,
174.4678252728422,
178.2551988671325,
181.43833498379723,
183.96608043102177,
185.78526600963679,
]

elif pulse_mode == "triplet":
np.testing.assert_almost_equal(
result["F"][0][350:380],
np.array(
[
220.086985,
220.22482724,
220.35632941,
220.48143445,
220.60008492,
220.71222302,
220.81779059,
220.91672912,
221.00897984,
221.09448364,
221.17318118,
221.24501288,
221.30991893,
221.36783933,
221.41871394,
221.46248248,
221.49908455,
221.5284597,
221.55054742,
221.56528722,
221.5726186,
221.57248114,
221.56481451,
221.54955853,
221.52665317,
221.49603862,
221.45765532,
221.41144399,
221.3573457,
221.29530191,
]
),
)
tested_values = [
0.0,
6.811659796099999,
18.419341569107342,
31.018933207521105,
43.51388159780265,
55.5996068206944,
67.18132459057564,
78.23205802958206,
88.74309318549606,
98.71107226716651,
108.13310254845092,
117.00434140759826,
125.31651184893666,
133.0568245159095,
140.20712082300022,
146.74319450865164,
152.63432797007994,
157.84313653511586,
162.32586010475507,
166.03327591175878,
168.91241867862357,
172.07412408967804,
176.7299944839003,
182.24856137475152,
187.95780010310577,
193.56606260351114,
198.96016141251192,
204.0961127202081,
208.95003601617074,
213.50462110419983,
217.74384599976568,
221.6503283192785,
225.20368978421823,
228.37937196951108,
]

test_values(result["F"][0], tested_values)


def test_ivp_methods():
Expand Down

0 comments on commit 4abc61e

Please sign in to comment.