Skip to content

Commit

Permalink
Refactor tax calculation logic in AbstractOrder.recalculate
Browse files Browse the repository at this point in the history
  • Loading branch information
milano-slesarik authored and PetrDlouhy committed Mar 7, 2025
1 parent 1601c5c commit 1d6878b
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 24 deletions.
33 changes: 9 additions & 24 deletions plans/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from django.contrib.auth import get_user_model
from django.db import models, transaction

from plans import utils

try:
from django.contrib.sites.models import Site
except RuntimeError:
Expand All @@ -30,7 +32,6 @@

from plans.contrib import get_user_language, send_template_email
from plans.enumeration import Enumeration
from plans.importer import import_name
from plans.signals import (
account_activated,
account_change_plan,
Expand Down Expand Up @@ -993,43 +994,27 @@ def recalculate(self, amount, billing_info, request=None, use_default=True):
"""
self.amount = amount
self.currency = get_currency()

country = getattr(billing_info, "country", None)
if country is None:
country = get_country_code(request)
else:
country = country.code

if hasattr(billing_info, "tax_number") and billing_info.tax_number:
from plans.base.models import AbstractBillingInfo

tax_number = AbstractBillingInfo.get_full_tax_number(
billing_info.tax_number, country
)
else:
tax_number = None
# Calculating tax can be complex task (e.g. VIES webservice call)
# To ensure that tax calculated on order preview will be the same on final order
# tax rate is cached for a given billing data (as this value only depends on it)
tax_session_key = "tax_%s_%s" % (tax_number, country)
request_successful = True
if request:
tax = request.session.get(tax_session_key)
else:
tax = None
if tax is None:
taxation_policy = getattr(settings, "PLANS_TAXATION_POLICY", None)
if not taxation_policy:
raise ImproperlyConfigured("PLANS_TAXATION_POLICY is not set")
taxation_policy = import_name(taxation_policy)
tax, request_successful = taxation_policy.get_tax_rate(
tax_number, country, request
)
tax = str(tax)
# Because taxation policy could return None which clutters with saving this value
# into cache, we use str() representation of this value
if request and request_successful:
request.session[tax_session_key] = tax

tax_rate, request_successful = utils.get_tax_rate(country, tax_number, request)
if (
use_default or request_successful
): # Don't change the tax, if the request was not successful
self.tax = Decimal(tax) if tax != "None" else None
self.tax = tax_rate

class Meta:
ordering = ("-created",)
Expand Down
59 changes: 59 additions & 0 deletions plans/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from decimal import Decimal

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured

from plans.importer import import_name


def get_client_ip(request):
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
Expand Down Expand Up @@ -43,3 +47,58 @@ def country_code_transform(country_code):
"GR": "EL",
}
return transform_dict.get(country_code, country_code)


def calculate_tax_rate(tax_number, country_code, request=None):
taxation_policy = getattr(settings, "PLANS_TAXATION_POLICY", None)
if not taxation_policy:
raise ImproperlyConfigured("PLANS_TAXATION_POLICY is not set")
taxation_policy = import_name(taxation_policy)
tax, request_successful = taxation_policy.get_tax_rate(
tax_number, country_code, request
)
if request_successful and request:
TaxCacheService.cache_tax_rate(request, tax, tax_number, country_code)
return Decimal(tax) if tax is not None else None, request_successful


def get_tax_rate(country_code, tax_number, request=None):
"""Get tax rate for given country and tax number
1. Try to get tax rate from cache
2. If not in cache, calculate it (and possibly cache it)
Returns tax rate and if the request was successful (False means default tax rate was used)
"""
if request:
try:
tax_from_cache = TaxCacheService.get_tax_rate(
request, tax_number, country_code
)
return tax_from_cache, True
except KeyError:
pass

tax, request_successful = calculate_tax_rate(tax_number, country_code, request)
return tax, request_successful


class TaxCacheService:
@classmethod
def get_cache_key(cls, tax_number, country):
return "tax_%s_%s" % (tax_number, country)

@classmethod
def cache_tax_rate(cls, request, tax, tax_number, country):
request.session[cls.get_cache_key(tax_number, country)] = str(tax)

@classmethod
def get_tax_rate(cls, request, tax_number, country):
key = cls.get_cache_key(tax_number, country)
if key not in request.session:
raise KeyError(
f"Tax rate for {tax_number} and {country} not found in cache"
)
raw = request.session[key]
if raw == "None":
return None
return Decimal(raw)

0 comments on commit 1d6878b

Please sign in to comment.