From 0f912c10dd37a07c5ccbcbea88c2e5a65b9b4fca Mon Sep 17 00:00:00 2001
From: James Moxon <james.moxon@ripjar.com>
Date: Fri, 21 Jun 2024 11:47:26 +0100
Subject: [PATCH] Support a deep path when sampling geo_shape fields

---
 .../hadoop/rest/RestRepository.java           |  2 +-
 .../dto/mapping/MappingUtils.java             | 19 +++++++
 .../integration/AbstractScalaEsSparkSQL.scala | 50 ++++++++++++++++++
 .../integration/AbstractScalaEsSparkSQL.scala | 52 ++++++++++++++++++-
 .../integration/AbstractScalaEsSparkSQL.scala | 50 ++++++++++++++++++
 5 files changed, 171 insertions(+), 2 deletions(-)

diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java
index a7424f432..9fee68188 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java
@@ -300,7 +300,7 @@ public Map<String, GeoField> sampleGeoFields(Mapping mapping) {
         Map<String, GeoField> geoInfo = new LinkedHashMap<String, GeoField>();
         for (Entry<String, GeoType> geoEntry : fields.entrySet()) {
             String fieldName = geoEntry.getKey();
-            geoInfo.put(fieldName, MappingUtils.parseGeoInfo(geoEntry.getValue(), geoMapping.get(fieldName)));
+            geoInfo.put(fieldName, MappingUtils.parseGeoInfo(geoEntry.getValue(), MappingUtils.getGeoMapping(geoMapping, fieldName)));
         }
 
         return geoInfo;
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingUtils.java b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingUtils.java
index b310438f3..b11b27dbf 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingUtils.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingUtils.java
@@ -47,6 +47,25 @@ public abstract class MappingUtils {
                 "_parent", "_routing", "_index", "_size", "_timestamp", "_ttl", "_field_names", "_meta"));
     }
 
+    /** Resolves {@code path} (possibly dotted, e.g. "location.deep") against a parsed
+     *  mapping tree, unwrapping list-wrapped levels; returns {@code null} when absent. */
+    public static Object getGeoMapping(Map<String, Object> map, String path) {
+        Object currentValue = map;
+        for (String key : path.split("\\.")) {
+            // Multi-valued mapping levels surface as lists; sample the first element.
+            if (currentValue instanceof ArrayList && !((ArrayList<?>) currentValue).isEmpty()) {
+                currentValue = ((ArrayList<?>) currentValue).get(0);
+            }
+            // Descend one level; anything other than a map means the path does not exist.
+            if (currentValue instanceof Map) {
+                currentValue = ((Map<?, ?>) currentValue).get(key);
+            } else {
+                return null;
+            }
+        }
+        return currentValue;
+    }
+
     public static void validateMapping(String fields, Mapping mapping, FieldPresenceValidation validation, Log log) {
         if (StringUtils.hasText(fields)) {
             validateMapping(StringUtils.tokenize(fields), mapping, validation, log);
diff --git a/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index 9b05a3db3..c4073b35a 100644
--- a/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -1819,6 +1819,56 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertThat(array(1), is(0.0d))
   }
 
+  @Test
+  def testGeoShapePointDeep() {
+    // Verifies that a geo_shape field nested under an object field ("location.deep")
+    // is sampled correctly and surfaces as a struct in the Spark SQL schema.
+    val mapping = wrapMapping("data", s"""{
+    |      "properties": {
+    |        "name": {
+    |          "type": "$keyword"
+    |        },
+    |        "location": {
+    |          "properties": {
+    |            "deep": {
+    |              "type": "geo_shape"
+    |            }
+    |          }
+    |        }
+    |      }
+    |    }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-geoshape-point-geoshape-deep")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val point = """{"name":"point", "location": { "deep":{ "type" : "point", "coordinates": [100.0, 0.0]  } }}""".stripMargin
+
+    sc.makeRDD(Seq(point)).saveJsonToEs(target)
+    val df = sqc.read.format("es").load(index)
+
+    val dataType = df.schema("location").dataType.asInstanceOf[StructType]("deep").dataType
+    assertEquals("struct", dataType.typeName)
+
+    val struct = dataType.asInstanceOf[StructType]
+    assertTrue(struct.fieldNames.contains("type"))
+    var coords = struct("coordinates").dataType
+    assertEquals("array", coords.typeName)
+    coords = coords.asInstanceOf[ArrayType].elementType
+    assertEquals("double", coords.typeName)
+
+    val head = df.select("location.*").head()
+
+    val obj = head.getStruct(0)
+    assertThat(obj.getString(0), is("point"))
+    val array = obj.getSeq[Double](1)
+    assertThat(array(0), is(100.0d))
+    assertThat(array(1), is(0.0d))
+  }
+
   @Test
   def testGeoShapeLine() {
     val mapping = wrapMapping("data", s"""{
diff --git a/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index 9be766b9c..ae91aaa19 100644
--- a/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -1881,6 +1881,56 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertThat(array(1), is(0.0d))
   }
 
+  @Test
+  def testGeoShapePointDeep() {
+    // Verifies that a geo_shape field nested under an object field ("location.deep")
+    // is sampled correctly and surfaces as a struct in the Spark SQL schema.
+    val mapping = wrapMapping("data", s"""{
+    |      "properties": {
+    |        "name": {
+    |          "type": "$keyword"
+    |        },
+    |        "location": {
+    |          "properties": {
+    |            "deep": {
+    |              "type": "geo_shape"
+    |            }
+    |          }
+    |        }
+    |      }
+    |    }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-geoshape-point-geoshape-deep")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val point = """{"name":"point", "location": { "deep":{ "type" : "point", "coordinates": [100.0, 0.0]  } }}""".stripMargin
+
+    sc.makeRDD(Seq(point)).saveJsonToEs(target)
+    val df = sqc.read.format("es").load(index)
+
+    val dataType = df.schema("location").dataType.asInstanceOf[StructType]("deep").dataType
+    assertEquals("struct", dataType.typeName)
+
+    val struct = dataType.asInstanceOf[StructType]
+    assertTrue(struct.fieldNames.contains("type"))
+    var coords = struct("coordinates").dataType
+    assertEquals("array", coords.typeName)
+    coords = coords.asInstanceOf[ArrayType].elementType
+    assertEquals("double", coords.typeName)
+
+    val head = df.select("location.*").head()
+
+    val obj = head.getStruct(0)
+    assertThat(obj.getString(0), is("point"))
+    val array = obj.getSeq[Double](1)
+    assertThat(array(0), is(100.0d))
+    assertThat(array(1), is(0.0d))
+  }
+
   @Test
   def testGeoShapeLine() {
     val mapping = wrapMapping("data", s"""{
@@ -1905,7 +1955,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
       
     sc.makeRDD(Seq(line)).saveJsonToEs(target)
     val df = sqc.read.format("es").load(index)
- 
+
     val dataType = df.schema("location").dataType
     assertEquals("struct", dataType.typeName)
 
diff --git a/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index 86acac5cb..d32a309f0 100644
--- a/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -1881,6 +1881,56 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertThat(array(1), is(0.0d))
   }
 
+  @Test
+  def testGeoShapePointDeep() {
+    // Verifies that a geo_shape field nested under an object field ("location.deep")
+    // is sampled correctly and surfaces as a struct in the Spark SQL schema.
+    val mapping = wrapMapping("data", s"""{
+    |      "properties": {
+    |        "name": {
+    |          "type": "$keyword"
+    |        },
+    |        "location": {
+    |          "properties": {
+    |            "deep": {
+    |              "type": "geo_shape"
+    |            }
+    |          }
+    |        }
+    |      }
+    |    }
+    """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-geoshape-point-geoshape-deep")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val point = """{"name":"point", "location": { "deep":{ "type" : "point", "coordinates": [100.0, 0.0]  } }}""".stripMargin
+
+    sc.makeRDD(Seq(point)).saveJsonToEs(target)
+    val df = sqc.read.format("es").load(index)
+
+    val dataType = df.schema("location").dataType.asInstanceOf[StructType]("deep").dataType
+    assertEquals("struct", dataType.typeName)
+
+    val struct = dataType.asInstanceOf[StructType]
+    assertTrue(struct.fieldNames.contains("type"))
+    var coords = struct("coordinates").dataType
+    assertEquals("array", coords.typeName)
+    coords = coords.asInstanceOf[ArrayType].elementType
+    assertEquals("double", coords.typeName)
+
+    val head = df.select("location.*").head()
+
+    val obj = head.getStruct(0)
+    assertThat(obj.getString(0), is("point"))
+    val array = obj.getSeq[Double](1)
+    assertThat(array(0), is(100.0d))
+    assertThat(array(1), is(0.0d))
+  }
+
   @Test
   def testGeoShapeLine() {
     val mapping = wrapMapping("data", s"""{