diff --git a/can/io/mf4.py b/can/io/mf4.py index 042bf8765..4f5336b42 100644 --- a/can/io/mf4.py +++ b/can/io/mf4.py @@ -5,16 +5,18 @@ the ASAM MDF standard (see https://www.asam.net/standards/detail/mdf/) """ +import abc +import heapq import logging from datetime import datetime from hashlib import md5 from io import BufferedIOBase, BytesIO from pathlib import Path -from typing import Any, BinaryIO, Generator, Optional, Union, cast +from typing import Any, BinaryIO, Dict, Generator, Iterator, List, Optional, Union, cast from ..message import Message from ..typechecking import StringPathLike -from ..util import channel2int, dlc2len, len2dlc +from ..util import channel2int, len2dlc from .generic import BinaryIOMessageReader, BinaryIOMessageWriter logger = logging.getLogger("can.io.mf4") @@ -22,10 +24,10 @@ try: import asammdf import numpy as np - from asammdf import Signal + from asammdf import Signal, Source from asammdf.blocks.mdf_v4 import MDF4 - from asammdf.blocks.v4_blocks import SourceInformation - from asammdf.blocks.v4_constants import BUS_TYPE_CAN, SOURCE_BUS + from asammdf.blocks.v4_blocks import ChannelGroup, SourceInformation + from asammdf.blocks.v4_constants import BUS_TYPE_CAN, FLAG_CG_BUS_EVENT, SOURCE_BUS from asammdf.mdf import MDF STD_DTYPE = np.dtype( @@ -70,6 +72,8 @@ ) except ImportError: asammdf = None + MDF4 = None + Signal = None CAN_MSG_EXT = 0x80000000 @@ -266,13 +270,179 @@ def on_message_received(self, msg: Message) -> None: self._rtr_buffer = np.zeros(1, dtype=RTR_DTYPE) +class FrameIterator(metaclass=abc.ABCMeta): + """ + Iterator helper class for common handling among CAN DataFrames, ErrorFrames and RemoteFrames. + """ + + # Number of records to request for each asammdf call + _chunk_size = 1000 + + def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float, name: str): + self._mdf = mdf + self._group_index = group_index + self._start_timestamp = start_timestamp + self._name = name + + # Extract names + channel_group: ChannelGroup = self._mdf.groups[self._group_index] + + self._channel_names = [] + + for channel in channel_group.channels: + if str(channel.name).startswith(f"{self._name}."): + self._channel_names.append(channel.name) + + def _get_data(self, current_offset: int) -> Signal: + # NOTE: asammdf suggests using select instead of get. Select seem to miss converting some + # channels which get does convert as expected. + data_raw = self._mdf.get( + self._name, + self._group_index, + record_offset=current_offset, + record_count=self._chunk_size, + raw=False, + ) + + return data_raw + + @abc.abstractmethod + def __iter__(self) -> Generator[Message, None, None]: + pass + + class MF4Reader(BinaryIOMessageReader): """ Iterator of CAN messages from a MF4 logging file. - The MF4Reader only supports MF4 files that were recorded with python-can. + The MF4Reader only supports MF4 files with CAN bus logging. """ + # NOTE: Readout based on the bus logging code from asammdf GUI + + class _CANDataFrameIterator(FrameIterator): + + def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float): + super().__init__(mdf, group_index, start_timestamp, "CAN_DataFrame") + + def __iter__(self) -> Generator[Message, None, None]: + for current_offset in range( + 0, + self._mdf.groups[self._group_index].channel_group.cycles_nr, + self._chunk_size, + ): + data = self._get_data(current_offset) + names = data.samples[0].dtype.names + + for i in range(len(data)): + data_length = int(data["CAN_DataFrame.DataLength"][i]) + + kv: Dict[str, Any] = { + "timestamp": float(data.timestamps[i]) + self._start_timestamp, + "arbitration_id": int(data["CAN_DataFrame.ID"][i]) & 0x1FFFFFFF, + "data": data["CAN_DataFrame.DataBytes"][i][ + :data_length + ].tobytes(), + } + + if "CAN_DataFrame.BusChannel" in names: + kv["channel"] = int(data["CAN_DataFrame.BusChannel"][i]) + if "CAN_DataFrame.Dir" in names: + kv["is_rx"] = int(data["CAN_DataFrame.Dir"][i]) == 0 + if "CAN_DataFrame.IDE" in names: + kv["is_extended_id"] = bool(data["CAN_DataFrame.IDE"][i]) + if "CAN_DataFrame.EDL" in names: + kv["is_fd"] = bool(data["CAN_DataFrame.EDL"][i]) + if "CAN_DataFrame.BRS" in names: + kv["bitrate_switch"] = bool(data["CAN_DataFrame.BRS"][i]) + if "CAN_DataFrame.ESI" in names: + kv["error_state_indicator"] = bool(data["CAN_DataFrame.ESI"][i]) + + yield Message(**kv) + + class _CANErrorFrameIterator(FrameIterator): + + def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float): + super().__init__(mdf, group_index, start_timestamp, "CAN_ErrorFrame") + + def __iter__(self) -> Generator[Message, None, None]: + for current_offset in range( + 0, + self._mdf.groups[self._group_index].channel_group.cycles_nr, + self._chunk_size, + ): + data = self._get_data(current_offset) + names = data.samples[0].dtype.names + + for i in range(len(data)): + kv: Dict[str, Any] = { + "timestamp": float(data.timestamps[i]) + self._start_timestamp, + "is_error_frame": True, + } + + if "CAN_ErrorFrame.BusChannel" in names: + kv["channel"] = int(data["CAN_ErrorFrame.BusChannel"][i]) + if "CAN_ErrorFrame.Dir" in names: + kv["is_rx"] = int(data["CAN_ErrorFrame.Dir"][i]) == 0 + if "CAN_ErrorFrame.ID" in names: + kv["arbitration_id"] = ( + int(data["CAN_ErrorFrame.ID"][i]) & 0x1FFFFFFF + ) + if "CAN_ErrorFrame.IDE" in names: + kv["is_extended_id"] = bool(data["CAN_ErrorFrame.IDE"][i]) + if "CAN_ErrorFrame.EDL" in names: + kv["is_fd"] = bool(data["CAN_ErrorFrame.EDL"][i]) + if "CAN_ErrorFrame.BRS" in names: + kv["bitrate_switch"] = bool(data["CAN_ErrorFrame.BRS"][i]) + if "CAN_ErrorFrame.ESI" in names: + kv["error_state_indicator"] = bool( + data["CAN_ErrorFrame.ESI"][i] + ) + if "CAN_ErrorFrame.RTR" in names: + kv["is_remote_frame"] = bool(data["CAN_ErrorFrame.RTR"][i]) + if ( + "CAN_ErrorFrame.DataLength" in names + and "CAN_ErrorFrame.DataBytes" in names + ): + data_length = int(data["CAN_ErrorFrame.DataLength"][i]) + kv["data"] = data["CAN_ErrorFrame.DataBytes"][i][ + :data_length + ].tobytes() + + yield Message(**kv) + + class _CANRemoteFrameIterator(FrameIterator): + + def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float): + super().__init__(mdf, group_index, start_timestamp, "CAN_RemoteFrame") + + def __iter__(self) -> Generator[Message, None, None]: + for current_offset in range( + 0, + self._mdf.groups[self._group_index].channel_group.cycles_nr, + self._chunk_size, + ): + data = self._get_data(current_offset) + names = data.samples[0].dtype.names + + for i in range(len(data)): + kv: Dict[str, Any] = { + "timestamp": float(data.timestamps[i]) + self._start_timestamp, + "arbitration_id": int(data["CAN_RemoteFrame.ID"][i]) + & 0x1FFFFFFF, + "dlc": int(data["CAN_RemoteFrame.DLC"][i]), + "is_remote_frame": True, + } + + if "CAN_RemoteFrame.BusChannel" in names: + kv["channel"] = int(data["CAN_RemoteFrame.BusChannel"][i]) + if "CAN_RemoteFrame.Dir" in names: + kv["is_rx"] = int(data["CAN_RemoteFrame.Dir"][i]) == 0 + if "CAN_RemoteFrame.IDE" in names: + kv["is_extended_id"] = bool(data["CAN_RemoteFrame.IDE"][i]) + + yield Message(**kv) + def __init__( self, file: Union[StringPathLike, BinaryIO], @@ -293,193 +463,65 @@ def __init__( self._mdf: MDF4 if isinstance(file, BufferedIOBase): - self._mdf = MDF(BytesIO(file.read())) + self._mdf = cast(MDF4, MDF(BytesIO(file.read()))) else: - self._mdf = MDF(file) - - self.start_timestamp = self._mdf.header.start_time.timestamp() - - masters = [self._mdf.get_master(i) for i in range(3)] - - masters = [ - np.core.records.fromarrays((master, np.ones(len(master)) * i)) - for i, master in enumerate(masters) - ] - - self.masters = np.sort(np.concatenate(masters)) - - def __iter__(self) -> Generator[Message, None, None]: - standard_counter = 0 - error_counter = 0 - rtr_counter = 0 - - for timestamp, group_index in self.masters: - # standard frames - if group_index == 0: - sample = self._mdf.get( - "CAN_DataFrame", - group=group_index, - raw=True, - record_offset=standard_counter, - record_count=1, - ) - - try: - channel = int(sample["CAN_DataFrame.BusChannel"][0]) - except ValueError: - channel = None - - if sample["CAN_DataFrame.EDL"] == 0: - is_extended_id = bool(sample["CAN_DataFrame.IDE"][0]) - arbitration_id = int(sample["CAN_DataFrame.ID"][0]) - is_rx = int(sample["CAN_DataFrame.Dir"][0]) == 0 - size = int(sample["CAN_DataFrame.DataLength"][0]) - dlc = int(sample["CAN_DataFrame.DLC"][0]) - data = sample["CAN_DataFrame.DataBytes"][0, :size].tobytes() - - msg = Message( - timestamp=timestamp + self.start_timestamp, - is_error_frame=False, - is_remote_frame=False, - is_fd=False, - is_extended_id=is_extended_id, - channel=channel, - is_rx=is_rx, - arbitration_id=arbitration_id, - data=data, - dlc=dlc, - ) - - else: - is_extended_id = bool(sample["CAN_DataFrame.IDE"][0]) - arbitration_id = int(sample["CAN_DataFrame.ID"][0]) - is_rx = int(sample["CAN_DataFrame.Dir"][0]) == 0 - size = int(sample["CAN_DataFrame.DataLength"][0]) - dlc = dlc2len(sample["CAN_DataFrame.DLC"][0]) - data = sample["CAN_DataFrame.DataBytes"][0, :size].tobytes() - error_state_indicator = bool(sample["CAN_DataFrame.ESI"][0]) - bitrate_switch = bool(sample["CAN_DataFrame.BRS"][0]) - - msg = Message( - timestamp=timestamp + self.start_timestamp, - is_error_frame=False, - is_remote_frame=False, - is_fd=True, - is_extended_id=is_extended_id, - channel=channel, - arbitration_id=arbitration_id, - is_rx=is_rx, - data=data, - dlc=dlc, - bitrate_switch=bitrate_switch, - error_state_indicator=error_state_indicator, + self._mdf = cast(MDF4, MDF(file)) + + self._start_timestamp = self._mdf.header.start_time.timestamp() + + def __iter__(self) -> Iterator[Message]: + # To handle messages split over multiple channel groups, create a single iterator per + # channel group and merge these iterators into a single iterator using heapq. + iterators: List[FrameIterator] = [] + for group_index, group in enumerate(self._mdf.groups): + channel_group: ChannelGroup = group.channel_group + + if not channel_group.flags & FLAG_CG_BUS_EVENT: + # Not a bus event, skip + continue + + if channel_group.cycles_nr == 0: + # No data, skip + continue + + acquisition_source: Optional[Source] = channel_group.acq_source + + if acquisition_source is None: + # No source information, skip + continue + if not acquisition_source.source_type & Source.SOURCE_BUS: + # Not a bus type (likely already covered by the channel group flag), skip + continue + + channel_names = [channel.name for channel in group.channels] + + if acquisition_source.bus_type == Source.BUS_TYPE_CAN: + if "CAN_DataFrame" in channel_names: + iterators.append( + self._CANDataFrameIterator( + self._mdf, group_index, self._start_timestamp + ) ) - - yield msg - standard_counter += 1 - - # error frames - elif group_index == 1: - sample = self._mdf.get( - "CAN_ErrorFrame", - group=group_index, - raw=True, - record_offset=error_counter, - record_count=1, - ) - - try: - channel = int(sample["CAN_ErrorFrame.BusChannel"][0]) - except ValueError: - channel = None - - if sample["CAN_ErrorFrame.EDL"] == 0: - is_extended_id = bool(sample["CAN_ErrorFrame.IDE"][0]) - arbitration_id = int(sample["CAN_ErrorFrame.ID"][0]) - is_rx = int(sample["CAN_ErrorFrame.Dir"][0]) == 0 - size = int(sample["CAN_ErrorFrame.DataLength"][0]) - dlc = int(sample["CAN_ErrorFrame.DLC"][0]) - data = sample["CAN_ErrorFrame.DataBytes"][0, :size].tobytes() - - msg = Message( - timestamp=timestamp + self.start_timestamp, - is_error_frame=True, - is_remote_frame=False, - is_fd=False, - is_extended_id=is_extended_id, - channel=channel, - arbitration_id=arbitration_id, - is_rx=is_rx, - data=data, - dlc=dlc, + elif "CAN_ErrorFrame" in channel_names: + iterators.append( + self._CANErrorFrameIterator( + self._mdf, group_index, self._start_timestamp + ) ) - - else: - is_extended_id = bool(sample["CAN_ErrorFrame.IDE"][0]) - arbitration_id = int(sample["CAN_ErrorFrame.ID"][0]) - is_rx = int(sample["CAN_ErrorFrame.Dir"][0]) == 0 - size = int(sample["CAN_ErrorFrame.DataLength"][0]) - dlc = dlc2len(sample["CAN_ErrorFrame.DLC"][0]) - data = sample["CAN_ErrorFrame.DataBytes"][0, :size].tobytes() - error_state_indicator = bool(sample["CAN_ErrorFrame.ESI"][0]) - bitrate_switch = bool(sample["CAN_ErrorFrame.BRS"][0]) - - msg = Message( - timestamp=timestamp + self.start_timestamp, - is_error_frame=True, - is_remote_frame=False, - is_fd=True, - is_extended_id=is_extended_id, - channel=channel, - arbitration_id=arbitration_id, - is_rx=is_rx, - data=data, - dlc=dlc, - bitrate_switch=bitrate_switch, - error_state_indicator=error_state_indicator, + elif "CAN_RemoteFrame" in channel_names: + iterators.append( + self._CANRemoteFrameIterator( + self._mdf, group_index, self._start_timestamp + ) ) - - yield msg - error_counter += 1 - - # remote frames else: - sample = self._mdf.get( - "CAN_RemoteFrame", - group=group_index, - raw=True, - record_offset=rtr_counter, - record_count=1, - ) - - try: - channel = int(sample["CAN_RemoteFrame.BusChannel"][0]) - except ValueError: - channel = None - - is_extended_id = bool(sample["CAN_RemoteFrame.IDE"][0]) - arbitration_id = int(sample["CAN_RemoteFrame.ID"][0]) - is_rx = int(sample["CAN_RemoteFrame.Dir"][0]) == 0 - dlc = int(sample["CAN_RemoteFrame.DLC"][0]) - - msg = Message( - timestamp=timestamp + self.start_timestamp, - is_error_frame=False, - is_remote_frame=True, - is_fd=False, - is_extended_id=is_extended_id, - channel=channel, - arbitration_id=arbitration_id, - is_rx=is_rx, - dlc=dlc, - ) - - yield msg - - rtr_counter += 1 - - self.stop() + # Unknown bus type, skip + continue + + # Create merged iterator over all the groups, using the timestamps as comparison key + return iter(heapq.merge(*iterators, key=lambda x: x.timestamp)) def stop(self) -> None: self._mdf.close() + self._mdf = None super().stop()