diff --git a/src/extra/components/pulses.py b/src/extra/components/pulses.py index ab0c5cce..aa80442c 100644 --- a/src/extra/components/pulses.py +++ b/src/extra/components/pulses.py @@ -1153,6 +1153,7 @@ def _get_pulse_ids(self): def _get_pulse_mask(self, reduced=False): # Actually returns flags instead of a mask. + train_ids = self._get_train_ids() pulse_ids = self.pulse_ids(copy=False) pids_by_train = pulse_ids.groupby(level=0) @@ -1163,11 +1164,28 @@ def _get_pulse_mask(self, reduced=False): pid_offset = 0 table_len = self._bunch_pattern_table_len - flags = np.zeros((pids_by_train.ngroups, table_len), dtype=np.int8) + num_trains = len(train_ids) + flags = np.zeros((num_trains, table_len), dtype=np.int8) + train_idx = 0 - for i, (_, train_pids) in enumerate(pids_by_train): - flags[i, train_pids.loc[:, :, True, :] - pid_offset] |= 1 - flags[i, train_pids.loc[:, :, :, True] - pid_offset] |= 2 + for train_id, train_pids in pulse_ids.groupby(level=0): + # See PulsePattern._get_pulse_mask. + for i in range(train_idx, num_trains): + if train_ids[i] == train_id: + train_idx = i + break + + try: + flags[train_idx, train_pids.loc[:, :, True, :] - pid_offset] |= 1 + except KeyError: + # No FEL pulses in this train. + pass + + try: + flags[train_idx, train_pids.loc[:, :, :, True] - pid_offset] |= 2 + except KeyError: + # No PPL pulses in this train. + pass return flags diff --git a/tests/conftest.py b/tests/conftest.py index c44e1424..3d2b4fde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,7 +65,7 @@ def mock_spb_aux_directory(): """Mock run directory with SPB auxiliary sources. Pulse pattern per train: - - 0:10, no pulses + - 0:5, no pulses - SA1 - 10:50, 50 pulses at 1000:1300:6 - 50:100, 25 pulses at 1000:1300:12 @@ -74,7 +74,7 @@ def mock_spb_aux_directory(): - SA3 - 10:100, 1 pulse at 200 - LP_SPB - - 10:100, 50 pulses at 0:300:6 + - 5:100, 50 pulses at 0:300:6 """ sources = [ diff --git a/tests/mockdata/timeserver.py b/tests/mockdata/timeserver.py index b88a6b95..2c513eb2 100644 --- a/tests/mockdata/timeserver.py +++ b/tests/mockdata/timeserver.py @@ -18,7 +18,7 @@ def _fill_bunch_pattern_table(table, num_rows, offset=10): table[offset:, 200] |= (DESTINATION_T4D | PHOTON_LINE_DEFLECTION) # LP_SPB - table[offset:, 0:300:6] |= PPL_BITS.LP_SPB + table[offset//2:, 0:300:6] |= PPL_BITS.LP_SPB class Timeserver(DeviceBase): diff --git a/tests/test_components_pulses.py b/tests/test_components_pulses.py index 391cfd00..a8cddb78 100644 --- a/tests/test_components_pulses.py +++ b/tests/test_components_pulses.py @@ -8,7 +8,8 @@ from euxfel_bunch_pattern import PPL_BITS from extra.data import RunDirectory, SourceData, KeyData, by_id -from extra.components import XrayPulses, OpticalLaserPulses, PumpProbePulses, DldPulses +from extra.components import XrayPulses, OpticalLaserPulses, PumpProbePulses, \ + DldPulses pattern_sources = dict( @@ -555,6 +556,22 @@ def test_pump_probe_basic(mock_spb_aux_run, source): # Pulse mask assert pulses.pulse_mask(labelled=False)[0, 1000:1306:6].all() + # Obtain a pulse mask for the entire run, including trains without + # any pulses and trains without FEL pulses. + # Requires use of bunch_table_position to avoid extrapolating FEL + # pulses at the beginning of the run. + mask = PumpProbePulses( + run, source=source, bunch_table_position=1001 + ).pulse_mask(labelled=False) + + assert not mask[:5].any() # No pulses at all. + + # No FEL pulses but PPL pulses. + assert not mask[5:10, 1000:1300:6].any() and mask[5:10, 1001:1301:6].all() + + # FEL and PPL pulses. + assert mask[10, 1000:1300:6].all() and mask[10, 1001:1301:6].all() + # Is constant pattern? assert not pulses.is_constant_pattern() assert pulses.select_trains(np.s_[:1]).is_constant_pattern()