diff --git a/smart_importer/entries.py b/smart_importer/entries.py index c336550..b05e1d2 100644 --- a/smart_importer/entries.py +++ b/smart_importer/entries.py @@ -1,22 +1,34 @@ """Helpers to work with Beancount entry objects.""" +from __future__ import annotations + from beancount.core.data import Posting, Transaction -def update_postings(transaction, accounts): - """Update the list of postings of a transaction to match the accounts.""" +def update_postings( + transaction: Transaction, accounts: list[str] +) -> Transaction: + """Update the list of postings of a transaction to match the accounts. + + Expects the transaction to be updated to have exactly one posting, + otherwise it is returned unchanged. Adds empty postings for all the + accounts - if the account of the single existing posting is found + in the list of accounts, it is placed there at the first occurence, + otherwise it is appended at the end. + """ if len(transaction.postings) != 1: return transaction + posting = transaction.postings[0] + new_postings = [ Posting(account, None, None, None, None, None) for account in accounts ] - for posting in transaction.postings: - if posting.account in accounts: - new_postings[accounts.index(posting.account)] = posting - else: - new_postings.append(posting) + if posting.account in accounts: + new_postings[accounts.index(posting.account)] = posting + else: + new_postings.append(posting) return transaction._replace(postings=new_postings) diff --git a/tests/entries_test.py b/tests/entries_test.py new file mode 100644 index 0000000..cc29de9 --- /dev/null +++ b/tests/entries_test.py @@ -0,0 +1,50 @@ +"""Tests for the entry helpers.""" + +from __future__ import annotations + +# pylint: disable=missing-docstring +from beancount.parser import parser + +from smart_importer.entries import update_postings + +TEST_DATA, _errors, _options = parser.parse_string( + """ +2016-01-06 * "Farmer Fresh" "Buying groceries" + Assets:US:BofA:Checking -10.00 USD + +2016-01-06 * "Farmer Fresh" "Buying groceries" + Assets:US:BofA:Checking -10.00 USD + Assets:US:BofA:Checking 10.00 USD +""" +) + + +def test_update_postings() -> None: + txn0 = TEST_DATA[0] + + def _update(accounts: list[str]) -> list[tuple[str, bool]]: + """Update, get accounts and whether this is the original posting.""" + updated = update_postings(txn0, accounts) + return [(p.account, p is txn0.postings[0]) for p in updated.postings] + + assert _update(["Assets:US:BofA:Checking", "Assets:Other"]) == [ + ("Assets:US:BofA:Checking", True), + ("Assets:Other", False), + ] + + assert _update( + ["Assets:US:BofA:Checking", "Assets:US:BofA:Checking", "Assets:Other"] + ) == [ + ("Assets:US:BofA:Checking", True), + ("Assets:US:BofA:Checking", False), + ("Assets:Other", False), + ] + + assert _update(["Assets:Other", "Assets:Other2"]) == [ + ("Assets:Other", False), + ("Assets:Other2", False), + ("Assets:US:BofA:Checking", True), + ] + + txn1 = TEST_DATA[1] + assert update_postings(txn1, ["Assets:Other"]) == txn1