Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server/order: store billing address #4544

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Add Order.billing_address

Revision ID: 1769a6e618a4
Revises: 6cbeabf73caf
Create Date: 2024-11-26 14:44:03.569035

"""

import concurrent.futures
import random
import time
from typing import Any, TypedDict, cast

import sqlalchemy as sa
import stripe as stripe_lib
from alembic import op
from pydantic import ValidationError

from polar import payment_method
from polar.config import settings

# Polar Custom Imports
from polar.integrations.stripe.utils import get_expandable_id
from polar.kit.address import Address, AddressType

# revision identifiers, used by Alembic.
revision = "1769a6e618a4"
down_revision = "6cbeabf73caf"
branch_labels: tuple[str] | None = None
depends_on: tuple[str] | None = None


stripe_client = stripe_lib.StripeClient(
settings.STRIPE_SECRET_KEY,
http_client=stripe_lib.HTTPXClient(allow_sync_methods=True),
)


class MigratedOrder(TypedDict):
order_id: str
amount: int
billing_address: dict[str, Any] | None


def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool:
return customer_address is None or customer_address["country"] is None


def migrate_order(
order: tuple[str, int, str | None, str | None], retry: int = 1
) -> MigratedOrder:
order_id, amount, stripe_invoice_id, stripe_charge_id = order

if stripe_invoice_id is None and stripe_charge_id is None:
raise ValueError(f"No invoice or charge: {order_id}")

customer_address: Any | None = None
try:
# Get from invoice
if stripe_invoice_id is not None:
invoice = stripe_client.invoices.retrieve(stripe_invoice_id)
customer_address = invoice.customer_address
# No address on invoice, try to get from charge
if (
_is_empty_customer_address(customer_address)
and invoice.charge is not None
):
return migrate_order(
(order_id, amount, None, get_expandable_id(invoice.charge))
)
# Get from charge
elif stripe_charge_id is not None:
charge = stripe_client.charges.retrieve(
stripe_charge_id,
params={
"expand": ["payment_method_details", "payment_method_details.card"]
},
)
customer_address = charge.billing_details.address
# No address on charge, try to get from payment method
if _is_empty_customer_address(customer_address):
if payment_method_details := charge.payment_method_details:
if card := getattr(payment_method_details, "card", None):
customer_address = {"country": card.country}
except stripe_lib.RateLimitError:
time.sleep(retry + random.random())
return migrate_order(order, retry=retry + 1)

billing_address: dict[str, Any] | None = None
if not _is_empty_customer_address(customer_address):
try:
billing_address = (
Address.model_validate(customer_address).model_dump()
if customer_address
else None
)
except ValidationError as e:
raise ValueError(f"Invalid address for order {order_id}: {e}")

return {"order_id": order_id, "amount": amount, "billing_address": billing_address}


def migrate_orders(
results: sa.CursorResult[tuple[str, int, str | None, str | None]],
) -> list[MigratedOrder]:
migrated_orders: list[MigratedOrder] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
futures = [executor.submit(migrate_order, order._tuple()) for order in results]
for future in concurrent.futures.as_completed(futures):
migrated_orders.append(future.result())
return migrated_orders


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"orders",
sa.Column(
"billing_address",
AddressType(astext_type=sa.Text()),
nullable=True,
),
)

connection = op.get_bind()
orders = connection.execute(
sa.text("""
SELECT orders.id, orders.amount, orders.stripe_invoice_id, orders.user_metadata->>'charge_id' AS stripe_charge_id
FROM orders
""")
)
migrated_orders = migrate_orders(orders)
for migrated_order in migrated_orders:
if migrated_order["billing_address"] is None:
if migrated_order["amount"] != 0:
print("No billing address for paid order", migrated_order["order_id"]) # noqa: T201
continue
op.execute(
sa.text(
"""
UPDATE orders
SET billing_address = :billing_address
WHERE id = :order_id
"""
).bindparams(
sa.bindparam(
"billing_address", migrated_order["billing_address"], type_=sa.JSON
),
order_id=migrated_order["order_id"],
)
)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("orders", "billing_address")
# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions server/polar/models/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from polar.custom_field.data import CustomFieldDataMixin
from polar.kit.address import Address, AddressType
from polar.kit.db.models import RecordModel
from polar.kit.metadata import MetadataMixin

Expand Down Expand Up @@ -38,6 +39,7 @@ class Order(CustomFieldDataMixin, MetadataMixin, RecordModel):
billing_reason: Mapped[OrderBillingReason] = mapped_column(
String, nullable=False, index=True
)
billing_address: Mapped[Address | None] = mapped_column(AddressType, nullable=True)
stripe_invoice_id: Mapped[str | None] = mapped_column(
String, nullable=True, unique=True
)
Expand Down
2 changes: 2 additions & 0 deletions server/polar/order/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from polar.discount.schemas import (
DiscountMinimal,
)
from polar.kit.address import Address
from polar.kit.metadata import MetadataOutputMixin
from polar.kit.schemas import IDSchema, MergeJSONSchema, Schema, TimestampedSchema
from polar.models.order import OrderBillingReason
Expand All @@ -21,6 +22,7 @@ class OrderBase(
tax_amount: int
currency: str
billing_reason: OrderBillingReason
billing_address: Address | None

user_id: UUID4
product_id: UUID4
Expand Down
17 changes: 16 additions & 1 deletion server/polar/order/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from polar.integrations.stripe.schemas import ProductType
from polar.integrations.stripe.service import stripe as stripe_service
from polar.integrations.stripe.utils import get_expandable_id
from polar.kit.address import Address
from polar.kit.db.postgres import AsyncSession
from polar.kit.pagination import PaginationParams, paginate
from polar.kit.services import ResourceServiceReader
Expand Down Expand Up @@ -149,6 +150,10 @@ def __init__(self, order: Order) -> None:
super().__init__(message, 404)


def _is_empty_customer_address(customer_address: dict[str, Any] | None) -> bool:
return customer_address is None or customer_address["country"] is None


class OrderService(ResourceServiceReader[Order]):
async def list(
self,
Expand Down Expand Up @@ -303,6 +308,16 @@ async def create_order_from_stripe(

product = product_price.product

billing_address: Address | None = None
if not _is_empty_customer_address(invoice.customer_address):
billing_address = Address.model_validate(invoice.customer_address)
# Try to retrieve the country from the payment method
elif invoice.charge is not None:
charge = await stripe_service.get_charge(get_expandable_id(invoice.charge))
if payment_method_details := charge.payment_method_details:
if card := getattr(payment_method_details, "card", None):
billing_address = Address.model_validate({"country": card.country})

# Get Discount if available
discount: Discount | None = None
if invoice.discount is not None:
Expand All @@ -329,7 +344,6 @@ async def create_order_from_stripe(
user: User | None = None

billing_reason: OrderBillingReason = OrderBillingReason.purchase

tax = invoice.tax or 0
amount = invoice.total - tax

Expand Down Expand Up @@ -368,6 +382,7 @@ async def create_order_from_stripe(
tax_amount=tax,
currency=invoice.currency,
billing_reason=billing_reason,
billing_address=billing_address,
stripe_invoice_id=invoice.id,
product=product,
product_price=product_price,
Expand Down
Loading