diff --git a/addressfield/fields.py b/addressfield/fields.py index 1420c252..0f601527 100644 --- a/addressfield/fields.py +++ b/addressfield/fields.py @@ -29,17 +29,18 @@ def flatten_data(data): class AddressField(forms.CharField): widget = AddressWidget + data = VALIDATION_DATA def clean(self, value, **kwargs): country = value['country'] try: - country_data = VALIDATION_DATA[country] + country_data = self.data[country] except KeyError: raise ValidationError('Invalid country selected') fields = flatten_data(country_data['fields']) - missing_fields = set(country_data['required']) - set(value.keys()) + missing_fields = set(country_data['required']) - set(field for field, value in value.items() if value) if missing_fields: missing_field_name = [fields[field]['label'] for field in missing_fields] raise ValidationError('Please provide data for: {}'.format(', '.join(missing_field_name))) diff --git a/addressfield/tests.py b/addressfield/tests.py new file mode 100644 index 00000000..adbf1dad --- /dev/null +++ b/addressfield/tests.py @@ -0,0 +1,40 @@ +from django.core.exceptions import ValidationError +from django.test import TestCase + +from .fields import AddressField + + +class TestRequiredFields(TestCase): + def build_validation_data(self, fields=list(), required=list()): + fields = set(fields + required) + return {'COUNTRY': { + 'fields': [{field: {'label': field}} for field in fields], + 'required': required, + }} + + def test_non_required(self): + field = AddressField() + field.data = self.build_validation_data(fields=['postalcode']) + field.clean({'country': 'COUNTRY'}) + + def test_non_required_blank_data(self): + field = AddressField() + field.data = self.build_validation_data(fields=['postalcode']) + field.clean({'country': 'COUNTRY', 'postalcode': ''}) + + def test_one_field_required(self): + field = AddressField() + field.data = self.build_validation_data(required=['postalcode']) + with self.assertRaises(ValidationError): + field.clean({'country': 'COUNTRY'}) + + def test_one_field_required_blank_data(self): + field = AddressField() + field.data = self.build_validation_data(required=['postalcode']) + with self.assertRaises(ValidationError): + field.clean({'country': 'COUNTRY', 'postalcode': ''}) + + def test_one_field_required_supplied_data(self): + field = AddressField() + field.data = self.build_validation_data(required=['postalcode']) + field.clean({'country': 'COUNTRY', 'postalcode': 'BS1 2AB'})