diff --git a/server/migrations/versions/2024-11-26-1444_add_order_billing_address.py b/server/migrations/versions/2024-11-26-1444_add_order_billing_address.py new file mode 100644 index 0000000000..3d3ed6c458 --- /dev/null +++ b/server/migrations/versions/2024-11-26-1444_add_order_billing_address.py @@ -0,0 +1,159 @@ +"""Add Order.billing_address + +Revision ID: 1769a6e618a4 +Revises: 6cbeabf73caf +Create Date: 2024-11-26 14:44:03.569035 + +""" + +import concurrent.futures +import random +import time +from typing import Any, TypedDict, cast + +import sqlalchemy as sa +import stripe as stripe_lib +from alembic import op +from pydantic import ValidationError + +from polar import payment_method +from polar.config import settings + +# Polar Custom Imports +from polar.integrations.stripe.utils import get_expandable_id +from polar.kit.address import Address, AddressType + +# revision identifiers, used by Alembic. +revision = "1769a6e618a4" +down_revision = "6cbeabf73caf" +branch_labels: tuple[str] | None = None +depends_on: tuple[str] | None = None + + +stripe_client = stripe_lib.StripeClient( + settings.STRIPE_SECRET_KEY, + http_client=stripe_lib.HTTPXClient(allow_sync_methods=True), +) + + +class MigratedOrder(TypedDict): + order_id: str + amount: int + billing_address: dict[str, Any] | None + + +def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool: + return customer_address is None or customer_address["country"] is None + + +def migrate_order( + order: tuple[str, int, str | None, str | None], retry: int = 1 +) -> MigratedOrder: + order_id, amount, stripe_invoice_id, stripe_charge_id = order + + if stripe_invoice_id is None and stripe_charge_id is None: + raise ValueError(f"No invoice or charge: {order_id}") + + customer_address: Any | None = None + try: + # Get from invoice + if stripe_invoice_id is not None: + invoice = stripe_client.invoices.retrieve(stripe_invoice_id) + customer_address = invoice.customer_address + # No address on invoice, try to get from charge + if ( + _is_empty_customer_address(customer_address) + and invoice.charge is not None + ): + return migrate_order( + (order_id, amount, None, get_expandable_id(invoice.charge)) + ) + # Get from charge + elif stripe_charge_id is not None: + charge = stripe_client.charges.retrieve( + stripe_charge_id, + params={ + "expand": ["payment_method_details", "payment_method_details.card"] + }, + ) + customer_address = charge.billing_details.address + # No address on charge, try to get from payment method + if _is_empty_customer_address(customer_address): + if payment_method_details := charge.payment_method_details: + if card := getattr(payment_method_details, "card", None): + customer_address = {"country": card.country} + except stripe_lib.RateLimitError: + time.sleep(retry + random.random()) + return migrate_order(order, retry=retry + 1) + + billing_address: dict[str, Any] | None = None + if not _is_empty_customer_address(customer_address): + try: + billing_address = ( + Address.model_validate(customer_address).model_dump() + if customer_address + else None + ) + except ValidationError as e: + raise ValueError(f"Invalid address for order {order_id}: {e}") + + return {"order_id": order_id, "amount": amount, "billing_address": billing_address} + + +def migrate_orders( + results: sa.CursorResult[tuple[str, int, str | None, str | None]], +) -> list[MigratedOrder]: + migrated_orders: list[MigratedOrder] = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + futures = [executor.submit(migrate_order, order._tuple()) for order in results] + for future in concurrent.futures.as_completed(futures): + migrated_orders.append(future.result()) + return migrated_orders + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "orders", + sa.Column( + "billing_address", + AddressType(astext_type=sa.Text()), + nullable=True, + ), + ) + + connection = op.get_bind() + orders = connection.execute( + sa.text(""" + SELECT orders.id, orders.amount, orders.stripe_invoice_id, orders.user_metadata->>'charge_id' AS stripe_charge_id + FROM orders + """) + ) + migrated_orders = migrate_orders(orders) + for migrated_order in migrated_orders: + if migrated_order["billing_address"] is None: + if migrated_order["amount"] != 0: + print("No billing address for paid order", migrated_order["order_id"]) # noqa: T201 + continue + op.execute( + sa.text( + """ + UPDATE orders + SET billing_address = :billing_address + WHERE id = :order_id + """ + ).bindparams( + sa.bindparam( + "billing_address", migrated_order["billing_address"], type_=sa.JSON + ), + order_id=migrated_order["order_id"], + ) + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("orders", "billing_address") + # ### end Alembic commands ### diff --git a/server/polar/models/order.py b/server/polar/models/order.py index eb78f17231..244399ff2c 100644 --- a/server/polar/models/order.py +++ b/server/polar/models/order.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from polar.custom_field.data import CustomFieldDataMixin +from polar.kit.address import Address, AddressType from polar.kit.db.models import RecordModel from polar.kit.metadata import MetadataMixin @@ -38,6 +39,7 @@ class Order(CustomFieldDataMixin, MetadataMixin, RecordModel): billing_reason: Mapped[OrderBillingReason] = mapped_column( String, nullable=False, index=True ) + billing_address: Mapped[Address | None] = mapped_column(AddressType, nullable=True) stripe_invoice_id: Mapped[str | None] = mapped_column( String, nullable=True, unique=True ) diff --git a/server/polar/order/schemas.py b/server/polar/order/schemas.py index 38770aa113..28e41c44ea 100644 --- a/server/polar/order/schemas.py +++ b/server/polar/order/schemas.py @@ -7,6 +7,7 @@ from polar.discount.schemas import ( DiscountMinimal, ) +from polar.kit.address import Address from polar.kit.metadata import MetadataOutputMixin from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema from polar.models.order import OrderBillingReason @@ -21,6 +22,7 @@ class OrderBase( tax_amount: int currency: str billing_reason: OrderBillingReason + billing_address: Address | None user_id: UUID4 product_id: UUID4 diff --git a/server/polar/order/service.py b/server/polar/order/service.py index 22af0c2172..0de90a0940 100644 --- a/server/polar/order/service.py +++ b/server/polar/order/service.py @@ -20,6 +20,7 @@ from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import stripe as stripe_service from polar.integrations.stripe.utils import get_expandable_id +from polar.kit.address import Address from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams, paginate from polar.kit.services import ResourceServiceReader @@ -149,6 +150,10 @@ def __init__(self, order: Order) -> None: super().__init__(message, 404) +def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool: + return customer_address is None or customer_address["country"] is None + + class OrderService(ResourceServiceReader[Order]): async def list( self, @@ -303,6 +308,16 @@ async def create_order_from_stripe( product = product_price.product + billing_address: Address | None = None + if not _is_empty_customer_address(invoice.customer_address): + billing_address = Address.model_validate(invoice.customer_address) + # Try to retrieve the country from the payment method + elif invoice.charge is not None: + charge = await stripe_service.get_charge(get_expandable_id(invoice.charge)) + if payment_method_details := charge.payment_method_details: + if card := getattr(payment_method_details, "card", None): + billing_address = Address.model_validate({"country": card.country}) + # Get Discount if available discount: Discount | None = None if invoice.discount is not None: @@ -329,7 +344,6 @@ async def create_order_from_stripe( user: User | None = None billing_reason: OrderBillingReason = OrderBillingReason.purchase - tax = invoice.tax or 0 amount = invoice.total - tax @@ -368,6 +382,7 @@ async def create_order_from_stripe( tax_amount=tax, currency=invoice.currency, billing_reason=billing_reason, + billing_address=billing_address, stripe_invoice_id=invoice.id, product=product, product_price=product_price, diff --git a/server/tests/order/test_service.py b/server/tests/order/test_service.py index 32f05e26bc..23cbcce85c 100644 --- a/server/tests/order/test_service.py +++ b/server/tests/order/test_service.py @@ -9,6 +9,7 @@ from polar.held_balance.service import held_balance as held_balance_service from polar.integrations.stripe.schemas import ProductType from polar.integrations.stripe.service import StripeService +from polar.kit.address import Address from polar.kit.db.postgres import AsyncSession from polar.kit.pagination import PaginationParams from polar.models import ( @@ -55,6 +56,7 @@ def construct_stripe_invoice( lines: list[tuple[str, bool, dict[str, str] | None]] = [("PRICE_ID", False, None)], metadata: dict[str, str] = {}, billing_reason: str = "subscription_create", + customer_address: dict[str, Any] | None = {"country": "FR"}, paid_out_of_band: bool = False, discount: Discount | None = None, ) -> stripe_lib.Invoice: @@ -69,6 +71,7 @@ def construct_stripe_invoice( "subscription": subscription_id, "subscription_details": subscription_details, "customer": customer_id, + "customer_address": customer_address, "lines": { "data": [ { @@ -604,6 +607,7 @@ async def test_one_time_product( assert order.product_price == product_one_time.prices[0] assert order.subscription is None assert order.billing_reason == OrderBillingReason.purchase + assert order.billing_address == Address(country="FR") # pyright: ignore enqueue_job_mock.assert_any_call( "order.discord_notification", @@ -862,6 +866,131 @@ async def test_charge_from_metadata( order_id=order.id, ) + @pytest.mark.parametrize( + "customer_address", + [ + None, + {"country": None}, + ], + ) + async def test_no_billing_address( + self, + customer_address: dict[str, Any] | None, + save_fixture: SaveFixture, + mocker: MockerFixture, + session: AsyncSession, + product: Product, + user: User, + organization_account: Account, + ) -> None: + mock = MagicMock(spec=StripeService) + mocker.patch("polar.order.service.stripe_service", new=mock) + mock.get_charge.return_value = stripe_lib.Charge.construct_from( + {"id": "CHARGE_ID", "payment_method_details": None}, + key=None, + ) + invoice = construct_stripe_invoice( + lines=[(product.prices[0].stripe_price_id, False, None)], + customer_address=customer_address, + subscription_id=None, + billing_reason="manual", + ) + invoice_total = invoice.total - (invoice.tax or 0) + + user.stripe_customer_id = "CUSTOMER_ID" + await save_fixture(user) + + payment_transaction = await create_transaction( + save_fixture, type=TransactionType.payment + ) + payment_transaction.charge_id = "CHARGE_ID" + await save_fixture(payment_transaction) + + transaction_service_mock = mocker.patch( + "polar.order.service.balance_transaction_service", + spec=BalanceTransactionService, + ) + transaction_service_mock.get_by.return_value = payment_transaction + transaction_service_mock.create_balance_from_charge.return_value = ( + Transaction(type=TransactionType.balance, amount=-invoice_total), + Transaction( + type=TransactionType.balance, + amount=invoice_total, + account_id=organization_account.id, + ), + ) + mocker.patch( + "polar.order.service.platform_fee_transaction_service", + spec=PlatformFeeTransactionService, + ) + + order = await order_service.create_order_from_stripe(session, invoice=invoice) + assert order.billing_address is None + + async def test_billing_address_from_payment_method( + self, + mocker: MockerFixture, + save_fixture: SaveFixture, + session: AsyncSession, + product_one_time: Product, + user: User, + organization_account: Account, + ) -> None: + mock = MagicMock(spec=StripeService) + mocker.patch("polar.order.service.stripe_service", new=mock) + mock.get_charge.return_value = stripe_lib.Charge.construct_from( + { + "id": "CHARGE_ID", + "payment_method_details": { + "card": { + "country": "US", + } + }, + }, + key=None, + ) + invoice = construct_stripe_invoice( + charge_id="CHARGE_ID", + lines=[(product_one_time.prices[0].stripe_price_id, False, None)], + customer_address=None, + subscription_id=None, + billing_reason="manual", + ) + invoice_total = invoice.total - (invoice.tax or 0) + + user.stripe_customer_id = "CUSTOMER_ID" + await save_fixture(user) + + payment_transaction = await create_transaction( + save_fixture, type=TransactionType.payment + ) + payment_transaction.charge_id = "CHARGE_ID" + await save_fixture(payment_transaction) + + transaction_service_mock = mocker.patch( + "polar.order.service.balance_transaction_service", + spec=BalanceTransactionService, + ) + transaction_service_mock.get_by.return_value = payment_transaction + transaction_service_mock.create_balance_from_charge.return_value = ( + Transaction(type=TransactionType.balance, amount=-invoice_total), + Transaction( + type=TransactionType.balance, + amount=invoice_total, + account_id=organization_account.id, + ), + ) + mocker.patch( + "polar.order.service.platform_fee_transaction_service", + spec=PlatformFeeTransactionService, + ) + + user.stripe_customer_id = "CUSTOMER_ID" + await save_fixture(user) + + order = await order_service.create_order_from_stripe(session, invoice=invoice) + assert order.billing_address == Address(country="US") # type: ignore + @pytest.mark.asyncio @pytest.mark.skip_db_asserts