JDBetteridge/parallel assert #6

Merged
merged 5 commits into from
Dec 18, 2024

Conversation

JDBetteridge
Member

Add a parallel assertion feature

Collaborator

@connorjward connorjward left a comment

There are quite a few small things that need fixing and comments that should be addressed. In general this seems good though.

@JHopeCollins
Member

JHopeCollins commented Oct 15, 2024

@JDBetteridge:

@connorjward and I were chatting yesterday and I think we're both pretty happy with this; it's something we really need.
We've made various suggestions across different threads, so to make them easier to follow I've written out below what the function would look like with all of our suggestions applied. You shouldn't feel obliged to accept all of them; this is just to collect the discussion in one place.

Changes are:

  1. Remove the `subset` argument and just pass `None` as the `assertion` argument on ranks that are not participating. This way each rank doesn't need to work out or know which other ranks are participating.
  2. Make the `assertion` argument a plain bool on participating ranks. It isn't clear why evaluation of the condition should be delayed until inside the `parallel_assert` function.
  3. Use `if not all` instead of `if not min` to check whether all ranks have passed, because mpi4py won't implicitly convert the bool array to ints.
  4. I've also added type hints and a couple of things to the docstring: a "Raises" section, and an example with only a subset of ranks participating. I've not made these into proper review suggestions yet because they are combined with the other suggestions.

from typing import Union

from mpi4py import MPI


def parallel_assert(assertion: Union[bool, None], msg: str = ""):
    """Make an assertion across MPI.COMM_WORLD

    Parameters:
    -----------
    assertion:
        Boolean that will be tested for truthiness. This should be `None`
        on any rank that is not participating in the assertion.
    msg:
        An informative message to be printed if the assertion fails on any rank.

    Raises:
    -------
    AssertionError
        Raised on all ranks if the `assertion` argument is `False` on any rank.

    Example:
    --------
    Where in serial code one would have previously written:
    ```python
    x = f()
    assert x < 5
    ```

    Now write, if all ranks are participating:
    ```python
    x = f()
    parallel_assert(x < 5)
    ```

    Or if only the first 2 ranks are participating, and with a helpful message:
    ```python
    x = f()
    rank = MPI.COMM_WORLD.rank
    parallel_assert(x < 5 if (rank < 2) else None,
                    msg="x is not less than 5")
    ```
    """
    if assertion is None:
        assertion = True
    all_assertions = MPI.COMM_WORLD.allgather(assertion)
    if not all(all_assertions):
        raise AssertionError(
            "Parallel assertion failed on ranks: "
            f"{[ii for ii, b in enumerate(all_assertions) if not b]}\n" + msg
        )
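
To make the behaviour easier to check without launching under MPI, here is a minimal sketch of the same pass/fail logic applied to a plain list. The names `failing_ranks` and `check` are hypothetical and not part of this PR, and the `gathered` list is only a stand-in for what `MPI.COMM_WORLD.allgather(assertion)` would return. The sketch converts `None` to `True` after gathering rather than before, which should give the same result.

```python
from typing import List, Optional


def failing_ranks(gathered: List[Optional[bool]]) -> List[int]:
    """Return the ranks whose assertion failed.

    `gathered` is a hypothetical stand-in for the allgather result:
    one entry per rank, with None meaning "not participating".
    """
    # Non-participating ranks (None) count as passing, as in the proposal.
    normalised = [True if a is None else a for a in gathered]
    return [rank for rank, ok in enumerate(normalised) if not ok]


def check(gathered: List[Optional[bool]], msg: str = "") -> None:
    """Raise AssertionError if any participating rank failed."""
    failed = failing_ranks(gathered)
    if failed:
        raise AssertionError(
            f"Parallel assertion failed on ranks: {failed}\n{msg}"
        )


# Four simulated ranks: ranks 0 and 1 participate, and rank 1 fails.
print(failing_ranks([True, False, None, None]))  # prints [1]
```

Because every rank sees the same gathered list, every rank reaches the same decision, which is why the error can be raised on all ranks at once rather than only on the failing one.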

JDBetteridge and others added 2 commits October 21, 2024 16:33
@connorjward connorjward marked this pull request as ready for review December 18, 2024 14:49
@connorjward connorjward merged commit f5668e4 into main Dec 18, 2024
@connorjward connorjward deleted the JDBetteridge/parallel_assert branch December 18, 2024 14:51
3 participants