Commit 3c805a9

Adding support for upserts of nested arrays (#1838)
This commit adds support for upserts of nested array fields, plus limited support for upserts of maps (see the usage sketch below). Closes #1190
1 parent 62b4b86 commit 3c805a9
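
For orientation, here is a minimal sketch of the upsert path this commit enables, written against the Spark 1.x DataFrame API used by the sql-13 module. The index name nested_upsert_demo and the sc/sqlContext handles are illustrative assumptions; the configuration keys and the inline script are the ones exercised by the tests in this commit.

  // Sketch only: assumes an existing SparkContext (sc) and SQLContext (sqlContext),
  // and a hypothetical target index "nested_upsert_demo".
  import org.apache.spark.sql.{Row, SaveMode}
  import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

  val esConf = Map(
    "es.mapping.id"           -> "id",      // column used as the document _id
    "es.write.operation"      -> "upsert",  // insert when missing, script-update otherwise
    "es.update.script.params" -> "new_samples: samples",
    "es.update.script.inline" -> "ctx._source.samples = params.new_samples"
  )

  // One id column plus a nested array of structs, mirroring the test schema.
  val schema = new StructType()
    .add("id", StringType, nullable = false)
    .add("samples", ArrayType(new StructType().add("text", StringType)))

  val rows = sc.parallelize(Seq(Row("1", Seq(Row("hello"), Row("world")))))
  val df = sqlContext.createDataFrame(rows, schema)

  // Each write either creates the document or replaces its "samples" array in place.
  df.write
    .format("org.elasticsearch.spark.sql")
    .options(esConf)
    .mode(SaveMode.Append)
    .save("nested_upsert_demo")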

File tree: 9 files changed, +711 −30 lines changed

spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 103 additions & 0 deletions
@@ -36,6 +36,7 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.MapType
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
@@ -2306,6 +2307,108 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertEquals(0, result(0).size)
   }
 
+  @Test
+  def testNestedFieldsUpsert(): Unit = {
+    val update_params = "new_samples: samples"
+    val update_script = "ctx._source.samples = params.new_samples"
+    val es_conf = Map(
+      "es.mapping.id" -> "id",
+      "es.write.operation" -> "upsert",
+      "es.update.script.params" -> update_params,
+      "es.update.script.inline" -> update_script
+    )
+    // First do an upsert with two completely new rows:
+    var data = Seq(Row("2", List(Row("hello"), Row("world"))), Row("1", List()))
+    var rdd: RDD[Row] = sc.parallelize(data)
+    val schema = new StructType()
+      .add("id", StringType, nullable = false)
+      .add("samples", new ArrayType(new StructType()
+        .add("text", StringType), true))
+    var df = sqc.createDataFrame(rdd, schema)
+    df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("nested_fields_upsert_test")
+
+    val reader = sqc.read.schema(schema).format("org.elasticsearch.spark.sql").option("es.read.field.as.array.include", "samples")
+    var resultDf = reader.load("nested_fields_upsert_test")
+    assertEquals(2, resultDf.count())
+    var samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first().getAs[IndexedSeq[Row]](0)
+    assertEquals(2, samples.size)
+    assertEquals("hello", samples(0).get(0))
+    assertEquals("world", samples(1).get(0))
+
+    // Now, do an upsert on the row with the empty samples list:
+    data = Seq(Row("1", List(Row("goodbye"), Row("world"))))
+    rdd = sc.parallelize(data)
+    df = sqc.createDataFrame(rdd, schema)
+    df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("nested_fields_upsert_test")
+
+    resultDf = reader.load("nested_fields_upsert_test")
+    samples = resultDf.where(resultDf("id").equalTo("1")).select("samples").first().getAs[IndexedSeq[Row]](0)
+    assertEquals(2, samples.size)
+    assertEquals("goodbye", samples(0).get(0))
+    assertEquals("world", samples(1).get(0))
+
+    // Finally, an upsert on the row that already had samples values:
+    data = Seq(Row("2", List(Row("goodbye"), Row("again"))))
+    rdd = sc.parallelize(data)
+    df = sqc.createDataFrame(rdd, schema)
+    df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("nested_fields_upsert_test")
+
+    resultDf = reader.load("nested_fields_upsert_test")
+    samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first().getAs[IndexedSeq[Row]](0)
+    assertEquals(2, samples.size)
+    assertEquals("goodbye", samples(0).get(0))
+    assertEquals("again", samples(1).get(0))
+  }
+
+  @Test
+  def testMapsUpsert(): Unit = {
+    val update_params = "new_samples: samples"
+    val update_script = "ctx._source.samples = params.new_samples"
+    val es_conf = Map(
+      "es.mapping.id" -> "id",
+      "es.write.operation" -> "upsert",
+      "es.update.script.params" -> update_params,
+      "es.update.script.inline" -> update_script
+    )
+    // First do an upsert with two completely new rows:
+    var data = Seq(Row("2", Map(("hello", "world"))), Row("1", Map()))
+    var rdd: RDD[Row] = sc.parallelize(data)
+    val schema = new StructType()
+      .add("id", StringType, nullable = false)
+      .add("samples", new MapType(StringType, StringType, true))
+    var df = sqc.createDataFrame(rdd, schema)
+    df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("map_fields_upsert_test")
+
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    var resultDf = reader.load("map_fields_upsert_test")
+    assertEquals(2, resultDf.count())
+    var samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first()
+    assertEquals(1, samples.size)
+    assertEquals("world", samples.get(0).asInstanceOf[Row].get(0))
+
+    // Now, do an upsert on the row with the empty samples map:
+    data = Seq(Row("1", Map(("goodbye", "all"))))
+    rdd = sc.parallelize(data)
+    df = sqc.createDataFrame(rdd, schema)
+    df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("map_fields_upsert_test")
+
+    resultDf = reader.load("map_fields_upsert_test")
+    samples = resultDf.where(resultDf("id").equalTo("1")).select("samples").first()
+    assertEquals(1, samples.size)
+    assertEquals("all", samples.get(0).asInstanceOf[Row].get(0))
+
+    // Finally, an upsert on the row that already had samples values:
+    data = Seq(Row("2", Map(("goodbye", "again"))))
+    rdd = sc.parallelize(data)
+    df = sqc.createDataFrame(rdd, schema)
+    df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("map_fields_upsert_test")
+
+    resultDf = reader.load("map_fields_upsert_test")
+    samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first()
+    assertEquals(1, samples.size)
+    assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
+  }
+
   @Test
   def testWildcard() {
     val mapping = wrapMapping("data", s"""{

spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/DataFrameValueWriter.scala

Lines changed: 98 additions & 10 deletions
@@ -21,13 +21,11 @@ package org.elasticsearch.spark.sql
 import java.sql.Date
 import java.sql.Timestamp
 import java.util.{Map => JMap}
-
 import scala.collection.JavaConverters.mapAsScalaMapConverter
 import scala.collection.{Map => SMap}
 import scala.collection.Seq
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.ArrayType
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
 import org.apache.spark.sql.types.DataTypes.BinaryType
 import org.apache.spark.sql.types.DataTypes.BooleanType
 import org.apache.spark.sql.types.DataTypes.ByteType
@@ -39,8 +37,6 @@ import org.apache.spark.sql.types.DataTypes.LongType
 import org.apache.spark.sql.types.DataTypes.ShortType
 import org.apache.spark.sql.types.DataTypes.StringType
 import org.apache.spark.sql.types.DataTypes.TimestampType
-import org.apache.spark.sql.types.MapType
-import org.apache.spark.sql.types.StructType
 import org.elasticsearch.hadoop.cfg.ConfigurationOptions.ES_SPARK_DATAFRAME_WRITE_NULL_VALUES_DEFAULT
 import org.elasticsearch.hadoop.cfg.Settings
 import org.elasticsearch.hadoop.serialization.EsHadoopSerializationException
@@ -51,7 +47,7 @@ import org.elasticsearch.hadoop.serialization.builder.ValueWriter.Result
 import org.elasticsearch.hadoop.util.unit.Booleans
 
 
-class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends FilteringValueWriter[(Row, StructType)] with SettingsAware {
+class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends FilteringValueWriter[Any] with SettingsAware {
 
   def this() {
     this(false)
@@ -64,11 +60,28 @@ class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends Filtering
     writeNullValues = settings.getDataFrameWriteNullValues
   }
 
-  override def write(value: (Row, StructType), generator: Generator): Result = {
-    val row = value._1
-    val schema = value._2
+  override def write(value: Any, generator: Generator): Result = {
+    value match {
+      case Tuple2(row, schema: StructType) =>
+        writeStruct(schema, row, generator)
+      case map: Map[_, _] =>
+        writeMapWithInferredSchema(map, generator)
+      case seq: Seq[Row] =>
+        writeArray(seq, generator)
+    }
+  }
 
-    writeStruct(schema, row, generator)
+  private[spark] def writeArray(value: Seq[Row], generator: Generator): Result = {
+    if (value.nonEmpty) {
+      val schema = value.head.schema
+      val result = write(DataTypes.createArrayType(schema), value, generator)
+      if (!result.isSuccesful) {
+        return handleUnknown(value, generator)
+      }
+    } else {
+      generator.writeBeginArray().writeEndArray()
+    }
+    Result.SUCCESFUL()
   }
 
   private[spark] def writeStruct(schema: StructType, value: Any, generator: Generator): Result = {
@@ -157,6 +170,81 @@ class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends Filtering
     Result.SUCCESFUL()
   }
 
+  private def writeMapWithInferredSchema(value: Any, generator: Generator): Result = {
+    value match {
+      case sm: SMap[_, _] => doWriteMapWithInferredSchema(sm, generator)
+      case jm: JMap[_, _] => doWriteMapWithInferredSchema(jm.asScala, generator)
+      // unknown map type
+      case _ => handleUnknown(value, generator)
+    }
+  }
+
+  private def doWriteMapWithInferredSchema(map: SMap[_, _], generator: Generator): Result = {
+    if (map != null && map.valuesIterator.hasNext) {
+      val sampleValueOption = getFirstNotNullElement(map.valuesIterator)
+      val schema = inferMapSchema(sampleValueOption)
+      doWriteMap(schema, map, generator)
+    } else {
+      writeEmptyMap(generator)
+    }
+  }
+
+  private def writeEmptyMap(generator: Generator): Result = {
+    generator.writeBeginObject().writeEndObject()
+    Result.SUCCESFUL()
+  }
+
+  private def inferMapSchema(valueOption: Option[Any]): MapType = {
+    if (valueOption.isDefined) {
+      val valueType = inferType(valueOption.get)
+      MapType(StringType, valueType) // The key type is never read
+    } else {
+      MapType(StringType, StringType) // Does not matter if the map is empty or has no values
+    }
+  }
+
+  def inferArraySchema(array: Array[_]): DataType = {
+    val EMPTY_ARRAY_TYPE = StringType // Makes no difference for an empty array
+    if (array.isEmpty) {
+      EMPTY_ARRAY_TYPE
+    } else {
+      val sampleValueOption = getFirstNotNullElement(array.iterator)
+      if (sampleValueOption.isDefined) {
+        inferType(sampleValueOption.get)
+      } else {
+        EMPTY_ARRAY_TYPE
+      }
+    }
+  }
+
+  def getFirstNotNullElement(iterator: Iterator[_]): Option[Any] = {
+    iterator.find(value => Option(value).isDefined)
+  }
+
+  private def inferType(value: Any): DataType = {
+    value match {
+      case _: String => StringType
+      case _: Int => IntegerType
+      case _: Integer => IntegerType
+      case _: Boolean => BooleanType
+      case _: java.lang.Boolean => BooleanType
+      case _: Short => ShortType
+      case _: java.lang.Short => ShortType
+      case _: Long => LongType
+      case _: java.lang.Long => LongType
+      case _: Double => DoubleType
+      case _: java.lang.Double => DoubleType
+      case _: Float => FloatType
+      case _: java.lang.Float => FloatType
+      case _: Timestamp => TimestampType
+      case _: Date => DateType
+      case _: Array[Byte] => BinaryType
+      case array: Array[_] => ArrayType(inferArraySchema(array))
+      case map: Map[_, _] => inferMapSchema(getFirstNotNullElement(map.valuesIterator))
+    }
+  }
+
   private[spark] def writePrimitive(schema: DataType, value: Any, generator: Generator): Result = {
     if (value == null) {
       generator.writeNull()

spark/sql-13/src/test/scala/org/elasticsearch/spark/sql/DataFrameValueWriterTest.scala

Lines changed: 35 additions & 0 deletions
@@ -120,4 +120,39 @@ class DataFrameValueWriterTest {
     assertEquals("""{"skey":{"jkey":"value"}}""", serialized)
   }
 
+  @Test
+  def testWriteMaps(): Unit = {
+    val settings = new TestSettings()
+    val writer = new DataFrameValueWriter()
+    if (settings != null) {
+      writer.setSettings(settings)
+    }
+    {
+      val out = new ByteArrayOutputStream()
+      val generator = new JacksonJsonGenerator(out)
+      val inputMap = Map(("key1", Map(("key2", Array(1, 2, 3)))))
+      val result = writer.write(inputMap, generator)
+      assertTrue(result.isSuccesful)
+      generator.flush()
+      assertEquals("{\"key1\":{\"key2\":[1,2,3]}}", new String(out.toByteArray))
+    }
+    {
+      val out = new ByteArrayOutputStream()
+      val generator = new JacksonJsonGenerator(out)
+      val inputMap = Map(("key1", "value1"), ("key2", "value2"))
+      val result = writer.write(inputMap, generator)
+      assertTrue(result.isSuccesful)
+      generator.flush()
+      assertEquals("{\"key1\":\"value1\",\"key2\":\"value2\"}", new String(out.toByteArray))
+    }
+    {
+      val out = new ByteArrayOutputStream()
+      val generator = new JacksonJsonGenerator(out)
+      val inputMap = Map()
+      val result = writer.write(inputMap, generator)
+      assertTrue(result.isSuccesful)
+      generator.flush()
+      assertEquals("{}", new String(out.toByteArray))
+    }
+  }
 }
