Skip to content
This repository has been archived by the owner on Sep 2, 2020. It is now read-only.

Commit

Permalink
Merge pull request #45 from intelligenia/feature/geospatial
Browse files Browse the repository at this point in the history
Feature/geospatial (issue #36)
  • Loading branch information
myarik authored Oct 29, 2019
2 parents 57965ae + 1b4809e commit 13bed74
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 4 deletions.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Blog(models.Model):
body = models.TextField(_('Body'))
tags = ArrayField(models.CharField(max_length=200), blank=True, null=True)
is_published = models.BooleanField(_('Is published'), default=False)


def __str__(self):
return self.title
Expand All @@ -33,7 +34,7 @@ class Blog(models.Model):
Create a `DocType` to represent our Blog model
```python

from elasticsearch_dsl import Document, Date, Integer, Keyword, Text
from elasticsearch_dsl import Document, Date, Integer, Keyword, Text, GeoPoint

class BlogIndex(Document):
pk = Integer()
Expand All @@ -42,6 +43,7 @@ class BlogIndex(Document):
body = Text()
tags = Keyword(multi=True)
is_published = Boolean()
location = GeoPoint()

class Index:
name = 'blog'
Expand All @@ -64,6 +66,7 @@ class BlogView(es_views.ListElasticAPIView):
es_filters.ElasticFieldsRangeFilter,
es_filters.ElasticSearchFilter,
es_filters.ElasticOrderingFilter,
es_filters.ElasticGeoBoundingBoxFilter
)
es_ordering = 'created_at'
es_filter_fields = (
Expand All @@ -76,6 +79,10 @@ class BlogView(es_views.ListElasticAPIView):
'tags',
'title',
)

es_geo_location_field = es_filters.ESFieldFilter('location')
es_geo_location_field_name = 'location'

```

This will allow the client to filter the items in the list by making queries such as:
Expand All @@ -84,6 +91,13 @@ http://example.com/blogs/api/list?search=elasticsearch
http://example.com/blogs/api/list?tag=opensource
http://example.com/blogs/api/list?tag=opensource,aws
http://example.com/blogs/api/list?to_created_at=2020-10-01&from_created_at=2017-09-01
# ElasticGeoBoundingBoxFilter expects format {top left lat, lon}|{bottom right lat, lon}
http://example.com/blogs/api/list?location=25.55235365216549,120.245361328125|21.861498734372567,122.728271484375
# ElasticGeoDistanceFilter expects format {distance}{unit}|{lat}|{lon}
http://example.com/blogs/api/list?location=100000km|12.04|-63.93
```


Expand Down
180 changes: 180 additions & 0 deletions rest_framework_elasticsearch/es_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,183 @@ def get_schema_fields(self, view):
)
)
]


GEO_BOUNDING_BOX = 'geo_bounding_box'

class ElasticGeoBoundingBoxFilter(BaseEsFilterBackend):
geo_bounding_box_param = ""
geo_bounding_box_title = _('Geo Bounding Box')
geo_bounding_box_description = _("""A Geo Bounding Box filter.
Expects format {top left lat, lon}|{bottom right lat, lon}
ex. Filter documents that are located in the given bounding box.
44.87,40.07|43.87,41.11""")

def get_geo_bounding_box_params(self, request, view):
"""
Geo bounding box
?location__geo_bounding_box={top left lat, lon}|{bottom right lat, lon}
ex. Filter documents that are located in the given bounding box.
?location__geo_bounding_box=44.87,40.07|43.87,41.11
"""

location_field = view.get_es_geo_location_field()
location_field_name = view.get_es_geo_location_field_name()

if not location_field:
return {}

self.geo_bounding_box_param = location_field_name
values = request.query_params.get(location_field_name, '').split('|')

if len(values) < 2:
return {}

top_left_points = {}
bottom_right_points = {}
options = {}

# Top left
lat_lon = values[0].split(
','
)
if len(lat_lon) >= 2:
top_left_points.update({
'lat': float(lat_lon[0]),
'lon': float(lat_lon[1]),
})

# Bottom right
lat_lon = values[1].split(
','
)
if len(lat_lon) >= 2:
bottom_right_points.update({
'lat': float(lat_lon[0]),
'lon': float(lat_lon[1]),
})

# Options
for value in values[2:]:
if ':' in value:
opt_name_val = value.split(
':'
)
if len(opt_name_val) >= 2:
if opt_name_val[0] in ('_name', 'validation_method', 'type'):
options.update(
{
opt_name_val[0]: opt_name_val[1]
}
)

if not top_left_points or not bottom_right_points:
return {}

params = {
self.geo_bounding_box_param: {
'top_left': top_left_points,
'bottom_right': bottom_right_points,
}
}

params.update(options)
return params

def filter_search(self, request, search, view):
geo_params = self.get_geo_bounding_box_params(request, view)

if not geo_params:
return search

q = Q(GEO_BOUNDING_BOX, **geo_params)
search = search.filter(q)
return search

def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.geo_bounding_box_param,
required=False,
location='query',
schema=coreschema.String(
title=force_text(self.geo_bounding_box_title),
description=force_text(self.geo_bounding_box_description)
)
)
]

GEO_DISTANCE = 'geo_distance'

class ElasticGeoDistanceFilter(BaseEsFilterBackend):
geo_distance_param = ''
geo_distance_title = _('Geo Distance')
geo_distance_description = _("""A Geo Distance filter.
Expects format {distance}{unit}|{lat}|{lon}.
ex. Filter documents by radius of 100000km from the given location.
100000km|12.04|-63.93""")

def get_geo_distance_params(self, request, view):
"""
Geo distance
?location__geo_distance={distance}{unit}|{lat}|{lon}
ex. Filter documents by radius of 100000km from the given location.
?location__geo_distance=100000km|12.04|-63.93
"""
location_field = view.get_es_geo_location_field()
location_field_name = view.get_es_geo_location_field_name()

if not location_field:
return {}

self.geo_distance_param = location_field_name
values = request.query_params.get(self.geo_distance_param, '').split('|', 2)
len_values = len(values)

if len_values < 2:
return {}

lat_lon = values[1].split(',')

params = {
'distance': values[0],
}
if len(lat_lon) >= 2:
params.update({
self.geo_distance_param: {
'lat': float(lat_lon[0]),
'lon': float(lat_lon[1]),
}
})

if len_values == 3:
params['distance_type'] = values[2]

return params

def filter_search(self, request, search, view):
geo_params = self.get_geo_distance_params(request, view)

if not geo_params:
return search

q = Q(GEO_DISTANCE, **geo_params)
search = search.query(q)
return search

def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.geo_distance_param,
required=False,
location='query',
schema=coreschema.String(
title=force_text(self.geo_distance_title),
description=force_text(self.geo_distance_description)
)
)
]
12 changes: 12 additions & 0 deletions rest_framework_elasticsearch/es_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,18 @@ def get_es_excludes_fields(self):
"""
return getattr(self, 'es_excludes_fields', None)

def get_es_geo_location_field(self):
"""
Return field or fields used for search.
The return value must be an iterable.
"""
return getattr(self, 'es_geo_location_field', None)

def get_es_geo_location_field_name(self):
"""
"""
return getattr(self, 'es_geo_location_field_name', None)

def get_es_client(self):
"""
You may want to override this if you need to provide different
Expand Down
Loading

0 comments on commit 13bed74

Please sign in to comment.