Adding support for update output mode to structured streaming #1839

Merged
@@ -51,7 +51,7 @@ public abstract class InitializationUtils {
public static void checkIdForOperation(Settings settings) {
String operation = settings.getOperation();

if (ConfigurationOptions.ES_OPERATION_UPDATE.equals(operation)) {
if (ConfigurationOptions.ES_OPERATION_UPDATE.equals(operation) || ConfigurationOptions.ES_OPERATION_UPSERT.equals(operation)) {
Assert.isTrue(StringUtils.hasText(settings.getMappingId()),
String.format("Operation [%s] requires an id but none (%s) was specified", operation, ConfigurationOptions.ES_MAPPING_ID));
}
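
For context, the practical effect of this check is easiest to see from the caller's side. A minimal sketch (not part of this diff; index name, schema, and SparkSession setup are illustrative only) of a batch write using the upsert operation, which after this change must also declare es.mapping.id:

// Hypothetical example: an upsert write that now fails fast in
// checkIdForOperation if es.mapping.id is left out.
import org.apache.spark.sql.SparkSession
import org.elasticsearch.spark.sql._

val spark = SparkSession.builder().appName("upsert-id-check-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1, "Spark"), (2, "Hadoop")).toDF("id", "name")

df.saveToEs("upsert-sketch-index", Map(
  "es.write.operation" -> "upsert",
  "es.mapping.id"      -> "id"   // required for upsert, just as it already was for update
))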
@@ -19,13 +19,14 @@

package org.elasticsearch.spark.integration

import com.fasterxml.jackson.databind.ObjectMapper

import java.io.File
import java.nio.file.Files
import java.sql.Timestamp
import java.util.concurrent.TimeUnit
import java.{lang => jl}
import java.{util => ju}

import javax.xml.bind.DatatypeConverter
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkConf
@@ -56,8 +57,7 @@ import org.hamcrest.Matchers.containsString
import org.hamcrest.Matchers.is
import org.hamcrest.Matchers.not
import org.junit.{AfterClass, Assert, Assume, BeforeClass, ClassRule, FixMethodOrder, Rule, Test}
import org.junit.Assert.assertThat
import org.junit.Assert.assertTrue
import org.junit.Assert.{assertEquals, assertThat, assertTrue}
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
@@ -585,4 +585,63 @@ class AbstractScalaEsSparkStructuredStreaming(prefix: String, something: Boolean
.start(target)
}
}

@Test
def testUpdate(): Unit = {
val target = wrapIndex(resource("test-update", "data", version))
val docPath = wrapIndex(docEndpoint("test-update", "data", version))
val test = new StreamingQueryTestHarness[Record](spark)

test.withInput(Record(1, "Spark"))
.withInput(Record(2, "Hadoop"))
.withInput(Record(3, "YARN"))
.startTest {
test.stream
.writeStream
.outputMode("update")
.option("checkpointLocation", checkpoint(target))
.option(ES_MAPPING_ID, "id")
.format("es")
.start(target)
}
test.waitForPartialCompletion()

assertTrue(RestUtils.exists(target))
assertTrue(RestUtils.exists(docPath + "/1"))
assertTrue(RestUtils.exists(docPath + "/2"))
assertTrue(RestUtils.exists(docPath + "/3"))
var searchResult = RestUtils.get(target + "/_search?")
assertThat(searchResult, containsString("Spark"))
assertThat(searchResult, containsString("Hadoop"))
assertThat(searchResult, containsString("YARN"))

test.withInput(Record(1, "Spark"))
.withInput(Record(2, "Hadoop2"))
.withInput(Record(3, "YARN"))
test.waitForCompletion()
searchResult = RestUtils.get(target + "/_search?version=true")
val result: java.util.Map[String, Object] = new ObjectMapper().readValue(searchResult, classOf[java.util.Map[String, Object]])
val hits = result.get("hits").asInstanceOf[java.util.Map[String, Object]].get("hits").asInstanceOf[java.util.List[java.util.Map[String, Object]]]
hits.forEach(hit => {
hit.get("_id").asInstanceOf[String] match {
case "1" => {
assertEquals(1, hit.get("_version"))
val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
assertEquals("Spark", value)
}
case "2" => {
assertEquals(2, hit.get("_version")) // The only one that should have been updated
val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
assertEquals("Hadoop2", value)
}
case "3" => {
assertEquals(1, hit.get("_version"))
val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
assertEquals("YARN", value)
}
case _ => throw new AssertionError("Unexpected result")
}
})
}
}
@@ -76,7 +76,7 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
private var foundExpectedException: Boolean = false
private var encounteredException: Option[String] = None

private val latch = new CountDownLatch(1)
private var latch = new CountDownLatch(1) // expects just a single batch

def incrementExpected(): Unit = inputsRequired = inputsRequired + 1

@@ -153,6 +153,10 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe

def waitOnComplete(timeValue: TimeValue): Boolean = latch.await(timeValue.millis, TimeUnit.MILLISECONDS)

def expectAnotherBatch(): Unit = {
latch = new CountDownLatch(1)
}

def assertExpectedExceptions(message: Option[String]): Unit = {
expectingToThrow match {
case Some(exceptionClass) =>
@@ -211,7 +215,7 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
* Add input to test server. Updates listener's bookkeeping to know when it's safe to shut down the stream
*/
def withInput(data: S): StreamingQueryTestHarness[S] = {
ensureState(Init) {
ensureState(Init, Running) {
testingServer.sendData(TestingSerde.serialize(data))
listener.incrementExpected()
}
@@ -320,6 +324,30 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
}
}

/**
* Waits until all inputs are processed on the streaming query, but leaves the query open with the listener still in place, expecting
* another batch of inputs.
*/
def waitForPartialCompletion(): Unit = {
Member: Nice!
ensureState(Running) {
currentState match {
case Running =>
try {
// Wait for query to complete consuming records
if (!listener.waitOnComplete(testTimeout)) {
throw new TimeoutException("Timed out on waiting for stream to complete.")
}
listener.expectAnotherBatch()
} catch {
case e: Throwable =>
// Best effort to shutdown queries before throwing
scrubState()
throw e
}
}
}
}
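
The testUpdate case above is the first user of this two-batch flow. In outline (reusing the Record case class and the target, checkpoint, and ES_MAPPING_ID helpers from that test, so this is a sketch rather than a standalone program), a multi-batch harness run now looks like this:

val test = new StreamingQueryTestHarness[Record](spark)

// Batch one: queue inputs while the harness is still in Init, then start the query.
test.withInput(Record(1, "Spark"))
  .startTest {
    test.stream.writeStream
      .outputMode("update")
      .option("checkpointLocation", checkpoint(target))
      .option(ES_MAPPING_ID, "id")
      .format("es")
      .start(target)
  }

// Block until the first batch is fully consumed, then re-arm the listener's latch.
test.waitForPartialCompletion()

// Batch two: withInput is now also accepted in the Running state.
test.withInput(Record(1, "Spark2"))
test.waitForCompletion()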

// tears down literally everything indiscriminately, mostly for cleanup after a failure
private[this] def scrubState(): Unit = {
sparkSession.streams.removeListener(listener)
@@ -22,7 +22,6 @@ import java.util.Calendar
import java.util.Date
import java.util.Locale
import java.util.UUID

import javax.xml.bind.DatatypeConverter
import org.apache.commons.logging.LogFactory
import org.apache.spark.rdd.RDD
@@ -63,6 +62,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException
import org.elasticsearch.hadoop.EsHadoopIllegalStateException
import org.elasticsearch.hadoop.cfg.ConfigurationOptions
import org.elasticsearch.hadoop.cfg.ConfigurationOptions.ES_WRITE_OPERATION
import org.elasticsearch.hadoop.cfg.InternalConfigurationOptions
import org.elasticsearch.hadoop.cfg.InternalConfigurationOptions.INTERNAL_TRANSPORT_POOLING_KEY
import org.elasticsearch.hadoop.cfg.Settings
@@ -122,12 +122,6 @@ private[sql] class DefaultSource extends RelationProvider with SchemaRelationPro
// Verify compatiblity versions for alpha:
StructuredStreamingVersionLock.checkCompatibility(sparkSession)

// For now we only support Append style output mode
if (outputMode != OutputMode.Append()) {
throw new EsHadoopIllegalArgumentException("Append is only supported OutputMode for Elasticsearch. " +
s"Cannot continue with [$outputMode].")
}

// Should not support partitioning. We already allow people to split data into different
// indices with the index pattern functionality. Potentially could add this later if a need
// arises by appending patterns to the provided index, but that's probably feature overload.
@@ -142,6 +136,19 @@
.load(sqlContext.sparkContext.getConf)
.merge(streamParams(mapConfig.toMap, sparkSession).asJava)

// For now we only support Update and Append style output modes
if (outputMode == OutputMode.Update()) {
val writeOperation = jobSettings.getProperty(ES_WRITE_OPERATION);
if (writeOperation == null) {
jobSettings.setProperty(ES_WRITE_OPERATION, ConfigurationOptions.ES_OPERATION_UPSERT)
} else if (writeOperation != ConfigurationOptions.ES_OPERATION_UPSERT) {
throw new EsHadoopIllegalArgumentException("Output mode update is only supported if es.write.operation is unset or set to upsert")
}
} else if (outputMode != OutputMode.Append()) {
throw new EsHadoopIllegalArgumentException("Append and update are the only supported OutputModes for Elasticsearch. " +
s"Cannot continue with [$outputMode].")
}

InitializationUtils.discoverClusterInfo(jobSettings, LogFactory.getLog(classOf[DefaultSource]))
InitializationUtils.checkIdForOperation(jobSettings)
InitializationUtils.checkIndexExistence(jobSettings)
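
From the user's side, the net effect of this change is that outputMode("update") is now accepted: es.write.operation is defaulted to upsert (or must already be upsert if set explicitly), and es.mapping.id remains mandatory via the id check above. A rough, illustrative sketch of the kind of query this enables; the rate source, index name, checkpoint path, and bucketing logic are made up for the example:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().appName("es-update-mode-sketch").getOrCreate()

// A running aggregation whose rows change across batches, which is what update mode is for.
val counts = spark.readStream
  .format("rate")
  .load()
  .groupBy((col("value") % 10).as("bucket"))
  .count()

counts.writeStream
  .outputMode("update")                                   // mapped to es.write.operation=upsert by DefaultSource
  .option("checkpointLocation", "/tmp/es-update-sketch")  // illustrative path
  .option("es.mapping.id", "bucket")                      // upsert still requires a document id
  .format("es")
  .start("streaming-counts-sketch")
  .awaitTermination()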