Adding support for upserts of nested arrays #1838

Merged · 7 commits · Jan 20, 2022
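The tests added in this change exercise the new behavior end to end: a DataFrame column holding a nested array (or a map) is passed as a script parameter to an upsert. Below is a minimal standalone sketch of that flow, mirroring the first new test; it assumes the es-hadoop Spark connector is on the classpath and an Elasticsearch cluster is reachable, and the SparkSession setup and index name are illustrative only.

import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

object NestedUpsertSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("nested-upsert-sketch").getOrCreate()

    // Schema with a nested array of structs, as in testNestedFieldsUpsert below.
    val schema = new StructType()
      .add("id", StringType, nullable = false)
      .add("samples", ArrayType(new StructType().add("text", StringType)))

    val rows = Seq(Row("1", List(Row("hello"), Row("world"))))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

    df.write
      .format("org.elasticsearch.spark.sql")
      .options(Map(
        "es.mapping.id" -> "id",
        "es.write.operation" -> "upsert",
        // The nested column is handed to the update script as a parameter;
        // serializing it is what previously failed, since DataFrameValueWriter
        // only accepted (Row, StructType) values.
        "es.update.script.params" -> "new_samples: samples",
        "es.update.script.inline" -> "ctx._source.samples = params.new_samples"))
      .mode(SaveMode.Append)
      .save("nested_fields_upsert_test")
  }
}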
@@ -36,6 +36,7 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
@@ -2306,6 +2307,108 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
assertEquals(0, result(0).size)
}

@Test
def testNestedFieldsUpsert(): Unit = {
val update_params = "new_samples: samples"
val update_script = "ctx._source.samples = params.new_samples"
val es_conf = Map(
"es.mapping.id" -> "id",
"es.write.operation" -> "upsert",
"es.update.script.params" -> update_params,
"es.update.script.inline" -> update_script
)
// First do an upsert with two completely new rows:
var data = Seq(Row("2", List(Row("hello"), Row("world"))), Row("1", List()))
var rdd: RDD[Row] = sc.parallelize(data)
val schema = new StructType()
.add("id", StringType, nullable = false)
.add("samples", new ArrayType(new StructType()
.add("text", StringType), true))
var df = sqc.createDataFrame(rdd, schema)
df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("nested_fields_upsert_test")

val reader = sqc.read.schema(schema).format("org.elasticsearch.spark.sql").option("es.read.field.as.array.include","samples")
var resultDf = reader.load("nested_fields_upsert_test")
assertEquals(2, resultDf.count())
var samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first().getAs[IndexedSeq[Row]](0)
assertEquals(2, samples.size)
assertEquals("hello", samples(0).get(0))
assertEquals("world", samples(1).get(0))

//Now, do an upsert on the one with the empty samples list:
data = Seq(Row("1", List(Row("goodbye"), Row("world"))))
rdd = sc.parallelize(data)
df = sqc.createDataFrame(rdd, schema)
df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("nested_fields_upsert_test")

resultDf = reader.load("nested_fields_upsert_test")
samples = resultDf.where(resultDf("id").equalTo("1")).select("samples").first().getAs[IndexedSeq[Row]](0)
assertEquals(2, samples.size)
assertEquals("goodbye", samples(0).get(0))
assertEquals("world", samples(1).get(0))

// Finally, an upsert on the row that had samples values:
data = Seq(Row("2", List(Row("goodbye"), Row("again"))))
rdd = sc.parallelize(data)
df = sqc.createDataFrame(rdd, schema)
df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("nested_fields_upsert_test")

resultDf = reader.load("nested_fields_upsert_test")
samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first().getAs[IndexedSeq[Row]](0)
assertEquals(2, samples.size)
assertEquals("goodbye", samples(0).get(0))
assertEquals("again", samples(1).get(0))
}

@Test
def testMapsUpsert(): Unit = {
val update_params = "new_samples: samples"
val update_script = "ctx._source.samples = params.new_samples"
val es_conf = Map(
"es.mapping.id" -> "id",
"es.write.operation" -> "upsert",
"es.update.script.params" -> update_params,
"es.update.script.inline" -> update_script
)
// First do an upsert with two completely new rows:
var data = Seq(Row("2", Map(("hello", "world"))), Row("1", Map()))
var rdd: RDD[Row] = sc.parallelize(data)
val schema = new StructType()
.add("id", StringType, nullable = false)
.add("samples", new MapType(StringType, StringType, true))
var df = sqc.createDataFrame(rdd, schema)
df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("map_fields_upsert_test")

val reader = sqc.read.format("org.elasticsearch.spark.sql")
var resultDf = reader.load("map_fields_upsert_test")
assertEquals(2, resultDf.count())
var samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first()
assertEquals(1, samples.size)
assertEquals("world", samples.get(0).asInstanceOf[Row].get(0))

// Now, do an upsert on the one with the empty samples map:
data = Seq(Row("1", Map(("goodbye", "all"))))
rdd = sc.parallelize(data)
df = sqc.createDataFrame(rdd, schema)
df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("map_fields_upsert_test")

resultDf = reader.load("map_fields_upsert_test")
samples = resultDf.where(resultDf("id").equalTo("1")).select("samples").first()
assertEquals(1, samples.size)
assertEquals("all", samples.get(0).asInstanceOf[Row].get(0))

// Finally, an upsert on the row that had samples values:
data = Seq(Row("2", Map(("goodbye", "again"))))
rdd = sc.parallelize(data)
df = sqc.createDataFrame(rdd, schema)
df.write.format("org.elasticsearch.spark.sql").options(es_conf).mode(SaveMode.Append).save("map_fields_upsert_test")

resultDf = reader.load("map_fields_upsert_test")
samples = resultDf.where(resultDf("id").equalTo("2")).select("samples").first()
assertEquals(1, samples.size)
assertEquals("again", samples.get(0).asInstanceOf[Row].get(0))
}

@Test
def testWildcard() {
val mapping = wrapMapping("data", s"""{
@@ -21,13 +21,11 @@ package org.elasticsearch.spark.sql
import java.sql.Date
import java.sql.Timestamp
import java.util.{Map => JMap}

import scala.collection.JavaConverters.mapAsScalaMapConverter
import scala.collection.{Map => SMap}
import scala.collection.Seq
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
import org.apache.spark.sql.types.DataTypes.BinaryType
import org.apache.spark.sql.types.DataTypes.BooleanType
import org.apache.spark.sql.types.DataTypes.ByteType
@@ -39,8 +37,6 @@ import org.apache.spark.sql.types.DataTypes.LongType
import org.apache.spark.sql.types.DataTypes.ShortType
import org.apache.spark.sql.types.DataTypes.StringType
import org.apache.spark.sql.types.DataTypes.TimestampType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.StructType
import org.elasticsearch.hadoop.cfg.ConfigurationOptions.ES_SPARK_DATAFRAME_WRITE_NULL_VALUES_DEFAULT
import org.elasticsearch.hadoop.cfg.Settings
import org.elasticsearch.hadoop.serialization.EsHadoopSerializationException
@@ -51,7 +47,7 @@ import org.elasticsearch.hadoop.serialization.builder.ValueWriter.Result
import org.elasticsearch.hadoop.util.unit.Booleans


class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends FilteringValueWriter[(Row, StructType)] with SettingsAware {
class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends FilteringValueWriter[Any] with SettingsAware {

def this() {
this(false)
@@ -64,11 +60,28 @@ class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends Filtering
writeNullValues = settings.getDataFrameWriteNullValues
}

override def write(value: (Row, StructType), generator: Generator): Result = {
val row = value._1
val schema = value._2
override def write(value: Any, generator: Generator): Result = {
value match {
case Tuple2(row, schema: StructType) =>
writeStruct(schema, row, generator)
case map: Map[_, _] =>
writeMapWithInferredSchema(map, generator)
case seq: Seq[Row] =>
writeArray(seq, generator)
}
}

writeStruct(schema, row, generator)
private[spark] def writeArray(value: Seq[Row], generator: Generator): Result = {
if (value.nonEmpty) {
val schema = value.head.schema
val result = write(DataTypes.createArrayType(schema), value, generator)
if (!result.isSuccesful) {
return handleUnknown(value, generator)
}
} else {
generator.writeBeginArray().writeEndArray()
}
Result.SUCCESFUL()
}

private[spark] def writeStruct(schema: StructType, value: Any, generator: Generator): Result = {
@@ -157,6 +170,81 @@ class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends Filtering
Result.SUCCESFUL()
}

private def writeMapWithInferredSchema(value: Any, generator: Generator): Result = {
value match {
case sm: SMap[_, _] => doWriteMapWithInferredSchema(sm, generator)
case jm: JMap[_, _] => doWriteMapWithInferredSchema(jm.asScala, generator)
// unknown map type
case _ => handleUnknown(value, generator)
}
}

private def doWriteMapWithInferredSchema(map: SMap[_, _], generator: Generator): Result = {
if (map != null && map.valuesIterator.hasNext) {
val sampleValueOption = getFirstNotNullElement(map.valuesIterator)
val schema = inferMapSchema(sampleValueOption)
doWriteMap(schema, map, generator)
} else {
writeEmptyMap(generator)
}
}

private def writeEmptyMap(generator: Generator): Result = {
generator.writeBeginObject().writeEndObject()
Result.SUCCESFUL()
}

private def inferMapSchema(valueOption: Option[Any]): MapType = {
if(valueOption.isDefined) {
val valueType = inferType(valueOption.get)
MapType(StringType, valueType) //The key type is never read
} else {
MapType(StringType, StringType) //Does not matter if the map is empty or has no values
}
}

def inferArraySchema(array: Array[_]): DataType = {
val EMPTY_ARRAY_TYPE = StringType //Makes no difference for an empty array
if (array.isEmpty) {
EMPTY_ARRAY_TYPE
} else {
val sampleValueOption = getFirstNotNullElement(array.iterator)
if (sampleValueOption.isDefined) {
inferType(sampleValueOption.get)
}
else {
EMPTY_ARRAY_TYPE
}
}
}

def getFirstNotNullElement(iterator: Iterator[_]): Option[Any] = {
iterator.find(value => Option(value).isDefined)
}

private def inferType(value: Any): DataType = {
value match {
case _: String => StringType
case _: Int => IntegerType
case _: Integer => IntegerType
case _: Boolean => BooleanType
case _: java.lang.Boolean => BooleanType
case _: Short => ShortType
case _: java.lang.Short => ShortType
case _: Long => LongType
case _: java.lang.Long => LongType
case _: Double => DoubleType
case _: java.lang.Double => DoubleType
case _: Float => FloatType
case _: java.lang.Float => FloatType
case _: Timestamp => TimestampType
case _: Date => DateType
case _: Array[Byte] => BinaryType
case array: Array[_] => ArrayType(inferArraySchema(array))
case map: Map[_, _] => inferMapSchema(getFirstNotNullElement(map.valuesIterator))
}
}

private[spark] def writePrimitive(schema: DataType, value: Any, generator: Generator): Result = {
if (value == null) {
generator.writeNull()
@@ -120,4 +120,39 @@ class DataFrameValueWriterTest {
assertEquals("""{"skey":{"jkey":"value"}}""", serialized)
}

@Test
def testWriteMaps(): Unit = {
val settings = new TestSettings()
val writer = new DataFrameValueWriter()
if (settings != null) {
writer.setSettings(settings)
}
{
val out = new ByteArrayOutputStream()
val generator = new JacksonJsonGenerator(out)
val inputMap = Map(("key1", Map(("key2", Array(1, 2, 3)))))
val result = writer.write(inputMap, generator)
assertTrue(result.isSuccesful)
generator.flush()
assertEquals("{\"key1\":{\"key2\":[1,2,3]}}", new String(out.toByteArray))
}
{
val out = new ByteArrayOutputStream()
val generator = new JacksonJsonGenerator(out)
val inputMap = Map(("key1", "value1"), ("key2", "value2"))
val result = writer.write(inputMap, generator)
assertTrue(result.isSuccesful)
generator.flush()
assertEquals("{\"key1\":\"value1\",\"key2\":\"value2\"}", new String(out.toByteArray))
}
{
val out = new ByteArrayOutputStream()
val generator = new JacksonJsonGenerator(out)
val inputMap = Map()
val result = writer.write(inputMap, generator)
assertTrue(result.isSuccesful)
generator.flush()
assertEquals("{}", new String(out.toByteArray))
}
}
}