Skip to content

Commit 3203f3f

Browse files
nnegreyanguillanneuf
authored andcommitted
samples: automl: create model tests (#1933)
* automl: add create model tests * Use a fake model to speed up batch predict test * Update method names and clean up Translate model test * Rename translate model * Run code formatter * Import order * copy paste typo * License year * Fix sample error, fix typos in test * lint: line length
1 parent 1e4dbba commit 3203f3f

9 files changed

+561
-187
lines changed

automl/snippets/src/main/java/com/example/automl/VisionClassificationCreateModel.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,7 @@ static void createModel(String projectId, String datasetId, String displayName)
4949
LocationName projectLocation = LocationName.of(projectId, "us-central1");
5050
// Set model metadata.
5151
ImageClassificationModelMetadata metadata =
52-
ImageClassificationModelMetadata.newBuilder()
53-
.setTrainBudgetMilliNodeHours(
54-
8) // The train budget of creating this model, expressed in hours.
55-
.build();
52+
ImageClassificationModelMetadata.newBuilder().setTrainBudgetMilliNodeHours(24000).build();
5653
Model model =
5754
Model.newBuilder()
5855
.setDisplayName(displayName)

automl/snippets/src/test/java/com/example/automl/BatchPredictTest.java

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
public class BatchPredictTest {
4545
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
4646
private static final String BUCKET_ID = PROJECT_ID + "-lcm";
47-
private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
47+
private static final String MODEL_ID = "TEN0000000000000000000";
4848
private ByteArrayOutputStream bout;
4949
private PrintStream out;
5050

@@ -62,59 +62,31 @@ public static void checkRequirements() {
6262
}
6363

6464
@Before
65-
public void setUp() throws IOException, ExecutionException, InterruptedException {
66-
// Verify that the model is deployed for prediction
67-
try (AutoMlClient client = AutoMlClient.create()) {
68-
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
69-
Model model = client.getModel(modelFullId);
70-
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
71-
// Deploy the model if not deployed
72-
DeployModelRequest request =
73-
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
74-
client.deployModelAsync(request).get();
75-
}
76-
}
77-
65+
public void setUp() {
7866
bout = new ByteArrayOutputStream();
7967
out = new PrintStream(bout);
8068
System.setOut(out);
8169
}
8270

8371
@After
8472
public void tearDown() {
85-
// Delete the created files from GCS
86-
Storage storage = StorageOptions.getDefaultInstance().getService();
87-
Page<Blob> blobs =
88-
storage.list(
89-
BUCKET_ID,
90-
Storage.BlobListOption.currentDirectory(),
91-
Storage.BlobListOption.prefix("TEST_BATCH_PREDICT/"));
92-
93-
for (Blob blob : blobs.iterateAll()) {
94-
Page<Blob> fileBlobs =
95-
storage.list(
96-
BUCKET_ID,
97-
Storage.BlobListOption.currentDirectory(),
98-
Storage.BlobListOption.prefix(blob.getName()));
99-
for (Blob fileBlob : fileBlobs.iterateAll()) {
100-
if (!fileBlob.isDirectory()) {
101-
fileBlob.delete();
102-
}
103-
}
104-
}
105-
10673
System.setOut(null);
10774
}
10875

10976
@Test
110-
public void testBatchPredict() throws IOException, ExecutionException, InterruptedException {
111-
String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID);
112-
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
113-
// Act
114-
BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri);
115-
116-
// Assert
117-
String got = bout.toString();
118-
assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket");
77+
public void testBatchPredict() {
78+
// As batch prediction can take a long time. Try to batch predict on a model and confirm that
79+
// the model was not found, but other elements of the request were valid.
80+
try {
81+
String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID);
82+
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
83+
BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri);
84+
String got = bout.toString();
85+
assertThat(got)
86+
.contains("The model is either not found or not supported for prediction yet.");
87+
} catch (IOException | ExecutionException | InterruptedException e) {
88+
assertThat(e.getMessage())
89+
.contains("The model is either not found or not supported for prediction yet.");
90+
}
11991
}
12092
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.automl;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import java.io.ByteArrayOutputStream;
23+
import java.io.IOException;
24+
import java.io.PrintStream;
25+
import java.util.UUID;
26+
import java.util.concurrent.ExecutionException;
27+
28+
import org.junit.After;
29+
import org.junit.Before;
30+
import org.junit.BeforeClass;
31+
import org.junit.Test;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
35+
@RunWith(JUnit4.class)
36+
@SuppressWarnings("checkstyle:abbreviationaswordinname")
37+
public class LanguageEntityExtractionCreateModelTest {
38+
39+
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
40+
private static final String DATASET_ID = "TEN0000000000000000000";
41+
private ByteArrayOutputStream bout;
42+
private PrintStream out;
43+
44+
private static void requireEnvVar(String varName) {
45+
assertNotNull(
46+
System.getenv(varName),
47+
"Environment variable '%s' is required to perform these tests.".format(varName));
48+
}
49+
50+
@BeforeClass
51+
public static void checkRequirements() {
52+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
53+
requireEnvVar("AUTOML_PROJECT_ID");
54+
}
55+
56+
@Before
57+
public void setUp() {
58+
bout = new ByteArrayOutputStream();
59+
out = new PrintStream(bout);
60+
System.setOut(out);
61+
}
62+
63+
@After
64+
public void tearDown() {
65+
System.setOut(null);
66+
}
67+
68+
@Test
69+
public void testLanguageEntityExtractionCreateModel() {
70+
// As entity extraction does not let you cancel model creation, instead try to create a model
71+
// from a nonexistent dataset, but other elements of the request were valid.
72+
try {
73+
// Create a random dataset name with a length of 32 characters (max allowed by AutoML)
74+
// To prevent name collisions when running tests in multiple java versions at once.
75+
// AutoML doesn't allow "-", but accepts "_"
76+
String modelName =
77+
String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26));
78+
LanguageEntityExtractionCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName);
79+
String got = bout.toString();
80+
assertThat(got).contains("Dataset does not exist");
81+
} catch (IOException | ExecutionException | InterruptedException e) {
82+
assertThat(e.getMessage()).contains("Dataset does not exist");
83+
}
84+
}
85+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.automl;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.automl.v1.AutoMlClient;
23+
24+
import java.io.ByteArrayOutputStream;
25+
import java.io.IOException;
26+
import java.io.PrintStream;
27+
import java.util.UUID;
28+
import java.util.concurrent.ExecutionException;
29+
30+
import org.junit.After;
31+
import org.junit.Before;
32+
import org.junit.BeforeClass;
33+
import org.junit.Test;
34+
import org.junit.runner.RunWith;
35+
import org.junit.runners.JUnit4;
36+
37+
@RunWith(JUnit4.class)
38+
@SuppressWarnings("checkstyle:abbreviationaswordinname")
39+
public class LanguageSentimentAnalysisCreateModelTest {
40+
41+
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
42+
private static final String DATASET_ID = System.getenv("SENTIMENT_ANALYSIS_DATASET_ID");
43+
private ByteArrayOutputStream bout;
44+
private PrintStream out;
45+
private String operationId;
46+
47+
private static void requireEnvVar(String varName) {
48+
assertNotNull(
49+
System.getenv(varName),
50+
"Environment variable '%s' is required to perform these tests.".format(varName));
51+
}
52+
53+
@BeforeClass
54+
public static void checkRequirements() {
55+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
56+
requireEnvVar("AUTOML_PROJECT_ID");
57+
requireEnvVar("SENTIMENT_ANALYSIS_DATASET_ID");
58+
}
59+
60+
@Before
61+
public void setUp() {
62+
bout = new ByteArrayOutputStream();
63+
out = new PrintStream(bout);
64+
System.setOut(out);
65+
}
66+
67+
@After
68+
public void tearDown() throws IOException {
69+
// Cancel the operation
70+
try (AutoMlClient client = AutoMlClient.create()) {
71+
client.getOperationsClient().cancelOperation(operationId);
72+
}
73+
74+
System.setOut(null);
75+
}
76+
77+
@Test
78+
public void testLanguageSentimentAnalysisCreateModel()
79+
throws IOException, ExecutionException, InterruptedException {
80+
// Create a random dataset name with a length of 32 characters (max allowed by AutoML)
81+
// To prevent name collisions when running tests in multiple java versions at once.
82+
// AutoML doesn't allow "-", but accepts "_"
83+
String modelName =
84+
String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26));
85+
LanguageSentimentAnalysisCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName);
86+
87+
String got = bout.toString();
88+
assertThat(got).contains("Training started");
89+
90+
operationId = got.split("Training operation name: ")[1].split("\n")[0];
91+
}
92+
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.automl;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static junit.framework.TestCase.assertNotNull;
21+
22+
import com.google.cloud.automl.v1.AutoMlClient;
23+
24+
import java.io.ByteArrayOutputStream;
25+
import java.io.IOException;
26+
import java.io.PrintStream;
27+
import java.util.UUID;
28+
import java.util.concurrent.ExecutionException;
29+
30+
import org.junit.After;
31+
import org.junit.Before;
32+
import org.junit.BeforeClass;
33+
import org.junit.Test;
34+
import org.junit.runner.RunWith;
35+
import org.junit.runners.JUnit4;
36+
37+
@RunWith(JUnit4.class)
38+
@SuppressWarnings("checkstyle:abbreviationaswordinname")
39+
public class LanguageTextClassificationCreateModelTest {
40+
41+
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
42+
private static final String DATASET_ID = System.getenv("TEXT_CLASSIFICATION_DATASET_ID");
43+
private ByteArrayOutputStream bout;
44+
private PrintStream out;
45+
private String operationId;
46+
47+
private static void requireEnvVar(String varName) {
48+
assertNotNull(
49+
System.getenv(varName),
50+
"Environment variable '%s' is required to perform these tests.".format(varName));
51+
}
52+
53+
@BeforeClass
54+
public static void checkRequirements() {
55+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
56+
requireEnvVar("AUTOML_PROJECT_ID");
57+
requireEnvVar("TEXT_CLASSIFICATION_DATASET_ID");
58+
}
59+
60+
@Before
61+
public void setUp() {
62+
bout = new ByteArrayOutputStream();
63+
out = new PrintStream(bout);
64+
System.setOut(out);
65+
}
66+
67+
@After
68+
public void tearDown() throws IOException {
69+
// Cancel the operation
70+
try (AutoMlClient client = AutoMlClient.create()) {
71+
client.getOperationsClient().cancelOperation(operationId);
72+
}
73+
74+
System.setOut(null);
75+
}
76+
77+
@Test
78+
public void testLanguageTextClassificationCreateModel()
79+
throws IOException, ExecutionException, InterruptedException {
80+
// Create a random dataset name with a length of 32 characters (max allowed by AutoML)
81+
// To prevent name collisions when running tests in multiple java versions at once.
82+
// AutoML doesn't allow "-", but accepts "_"
83+
String modelName =
84+
String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26));
85+
LanguageTextClassificationCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName);
86+
87+
String got = bout.toString();
88+
assertThat(got).contains("Training started");
89+
90+
operationId = got.split("Training operation name: ")[1].split("\n")[0];
91+
}
92+
}

0 commit comments

Comments
 (0)