Skip to content

Commit

Permalink
Update code to explict handle HTTP request methods
Browse files Browse the repository at this point in the history
  • Loading branch information
pamella committed May 27, 2024
1 parent 1d0ad37 commit 58e95c3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
4 changes: 2 additions & 2 deletions drf_rw_serializers/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_serializer_class(self):
(Eg. admins get full serialization, others get basic serialization)
"""
if hasattr(self, "request"):
if self.request.method in ["GET"]:
if self.request.method in ["GET", "HEAD", "OPTIONS", "TRACE"]:
assert (
getattr(self, "read_serializer_class", None) is not None
or self.serializer_class is not None
Expand All @@ -53,7 +53,7 @@ def get_serializer_class(self):
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_read_serializer_class()
else:
elif self.request.method in ["POST", "PUT", "PATCH", "DELETE"]:
assert (
getattr(self, "write_serializer_class", None) is not None
or self.serializer_class is not None
Expand Down
35 changes: 26 additions & 9 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,47 @@ def test_no_request_provided(self):
self.FullSerializerView().get_serializer_class(), OrderedMealDetailsSerializer
)

def test_get_request_provided(self):
def test_read_request_method_provided(self):
read_methods = ["GET", "HEAD", "OPTIONS", "TRACE"]

# Return read_serializer_class
self.RWSerializerView.request = mock.Mock(method="GET")
self.assertEqual(self.RWSerializerView().get_serializer_class(), OrderListSerializer)
for method in read_methods:
self.RWSerializerView.request = mock.Mock(method=method)
self.assertEqual(self.RWSerializerView().get_serializer_class(), OrderListSerializer)

# Return read_serializer_class even if serializer_class is provided
self.FullSerializerView.request = mock.Mock(method="GET")
self.assertEqual(self.FullSerializerView().get_serializer_class(), OrderListSerializer)
for method in read_methods:
self.FullSerializerView.request = mock.Mock(method=method)
self.assertEqual(self.FullSerializerView().get_serializer_class(), OrderListSerializer)

def test_non_get_request_provided(self):
non_get_methods = ["POST", "PUT", "PATCH", "DELETE"]
def test_write_request_method_provided(self):
write_methods = ["POST", "PUT", "PATCH", "DELETE"]

# Return write_serializer_class
for method in non_get_methods:
for method in write_methods:
self.RWSerializerView.request = mock.Mock(method=method)
self.assertEqual(self.RWSerializerView().get_serializer_class(), OrderCreateSerializer)

# Return write_serializer_class even if serializer_class is provided
for method in non_get_methods:
for method in write_methods:
self.FullSerializerView.request = mock.Mock(method=method)
self.assertEqual(
self.FullSerializerView().get_serializer_class(), OrderCreateSerializer
)

def test_non_read_write_request_method_provided(self):
non_read_write_method = "CONNECT"

# Return default serializer_class
self.RWSerializerView.request = mock.Mock(method=non_read_write_method)
self.assertIsNone(self.RWSerializerView().get_serializer_class())

# Return default serializer_class even if read/write serializer classes are provided
self.FullSerializerView.request = mock.Mock(method=non_read_write_method)
self.assertEqual(
self.FullSerializerView().get_serializer_class(), OrderedMealDetailsSerializer
)


class GenericAPIViewGetReadSerializerClassTests(BaseTestCase):
def test_read_serializer_class_not_provided(self):
Expand Down

0 comments on commit 58e95c3

Please sign in to comment.