Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PumpProbePulses.pulse_mask() for empty or partially empty trains #126

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/extra/components/pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,7 @@ def _get_pulse_ids(self):
def _get_pulse_mask(self, reduced=False):
# Actually returns flags instead of a mask.

train_ids = self._get_train_ids()
pulse_ids = self.pulse_ids(copy=False)
pids_by_train = pulse_ids.groupby(level=0)

Expand All @@ -1163,11 +1164,28 @@ def _get_pulse_mask(self, reduced=False):
pid_offset = 0
table_len = self._bunch_pattern_table_len

flags = np.zeros((pids_by_train.ngroups, table_len), dtype=np.int8)
num_trains = len(train_ids)
flags = np.zeros((num_trains, table_len), dtype=np.int8)
train_idx = 0

for i, (_, train_pids) in enumerate(pids_by_train):
flags[i, train_pids.loc[:, :, True, :] - pid_offset] |= 1
flags[i, train_pids.loc[:, :, :, True] - pid_offset] |= 2
for train_id, train_pids in pulse_ids.groupby(level=0):
# See PulsePattern._get_pulse_mask.
for i in range(train_idx, num_trains):
if train_ids[i] == train_id:
train_idx = i
break

try:
flags[train_idx, train_pids.loc[:, :, True, :] - pid_offset] |= 1
except KeyError:
# No FEL pulses in this train.
pass

try:
flags[train_idx, train_pids.loc[:, :, :, True] - pid_offset] |= 2
except KeyError:
# No PPL pulses in this train.
pass

return flags

Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def mock_spb_aux_directory():
"""Mock run directory with SPB auxiliary sources.

Pulse pattern per train:
- 0:10, no pulses
- 0:5, no pulses
- SA1
- 10:50, 50 pulses at 1000:1300:6
- 50:100, 25 pulses at 1000:1300:12
Expand All @@ -74,7 +74,7 @@ def mock_spb_aux_directory():
- SA3
- 10:100, 1 pulse at 200
- LP_SPB
- 10:100, 50 pulses at 0:300:6
- 5:100, 50 pulses at 0:300:6
"""

sources = [
Expand Down
2 changes: 1 addition & 1 deletion tests/mockdata/timeserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _fill_bunch_pattern_table(table, num_rows, offset=10):
table[offset:, 200] |= (DESTINATION_T4D | PHOTON_LINE_DEFLECTION)

# LP_SPB
table[offset:, 0:300:6] |= PPL_BITS.LP_SPB
table[offset//2:, 0:300:6] |= PPL_BITS.LP_SPB


class Timeserver(DeviceBase):
Expand Down
19 changes: 18 additions & 1 deletion tests/test_components_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from euxfel_bunch_pattern import PPL_BITS
from extra.data import RunDirectory, SourceData, KeyData, by_id
from extra.components import XrayPulses, OpticalLaserPulses, PumpProbePulses, DldPulses
from extra.components import XrayPulses, OpticalLaserPulses, PumpProbePulses, \
DldPulses


pattern_sources = dict(
Expand Down Expand Up @@ -555,6 +556,22 @@ def test_pump_probe_basic(mock_spb_aux_run, source):
# Pulse mask
assert pulses.pulse_mask(labelled=False)[0, 1000:1306:6].all()

# Obtain a pulse mask for the entire run, including trains without
# any pulses and trains without FEL pulses.
# Requires use of bunch_table_position to avoid extrapolating FEL
# pulses at the beginning of the run.
mask = PumpProbePulses(
run, source=source, bunch_table_position=1001
).pulse_mask(labelled=False)

assert not mask[:5].any() # No pulses at all.

# No FEL pulses but PPL pulses.
assert not mask[5:10, 1000:1300:6].any() and mask[5:10, 1001:1301:6].all()

# FEL and PPL pulses.
assert mask[10, 1000:1300:6].all() and mask[10, 1001:1301:6].all()

# Is constant pattern?
assert not pulses.is_constant_pattern()
assert pulses.select_trains(np.s_[:1]).is_constant_pattern()
Expand Down
Loading