Skip to content

Commit 2851ce8

Browse files
chore: add execute method with allowed return types (#2592)
* chore: add execute method with allowed return types Adds an execute method to the Connection API that allows the caller to specify the allowed result types. This can be used by driver implementations, such as JDBC, to use a single execute method in the Connection API, while still making sure that the method only executes statements that it should. The current executeUpdate method in the Connection API does not overlap completely with semantics of executeUpdate in JDBC, as JDBC allows any statement type that does not return a ResultSet to be executed with that method. The Connection API only allows statements that return an update count. Instead of modifying the executeUpdate method in the Connection API to match the semantics of JDBC (which would be a breaking change), this method can be used generically for all execute*** methods in the JDBC driver, which again can be used to clean up some code there. * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * fix: bad merge for clirr --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 1f850e9 commit 2851ce8

File tree

5 files changed

+344
-1
lines changed

5 files changed

+344
-1
lines changed

google-cloud-spanner/clirr-ignored-differences.xml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,18 @@
416416
<className>com/google/cloud/spanner/connection/Connection</className>
417417
<method>void setMaxPartitions(int)</method>
418418
</difference>
419+
<!-- Add an execute method that allows the driver to state what types should be allowed or not.
420+
This fixes the gap between what JDBC allows, and what is currently allowed in the Connection
421+
API:
422+
1. JDBC allows executeUpdate to be used for everything that does not return a ResultSet.
423+
2. Connection API requires executeUpdate to be used with something that returns an update
424+
count (i.e. no DDL and no client-side statements. -->
425+
<difference>
426+
<differenceType>7012</differenceType>
427+
<className>com/google/cloud/spanner/connection/Connection</className>
428+
<method>com.google.cloud.spanner.connection.StatementResult execute(com.google.cloud.spanner.Statement, java.util.Set)</method>
429+
</difference>
430+
419431
<!-- (Internal change, use stream timeout) -->
420432
<difference>
421433
<differenceType>7012</differenceType>

google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/Connection.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import com.google.spanner.v1.ExecuteBatchDmlRequest;
4242
import com.google.spanner.v1.ResultSetStats;
4343
import java.util.Iterator;
44+
import java.util.Set;
4445
import java.util.concurrent.ExecutionException;
4546
import java.util.concurrent.TimeUnit;
4647

@@ -949,6 +950,35 @@ default boolean isDelayTransactionStartUntilFirstWrite() {
949950
*/
950951
StatementResult execute(Statement statement);
951952

953+
/**
954+
* Executes the given statement if allowed in the current {@link TransactionMode} and connection
955+
* state, and if the result that would be returned is in the set of allowed result types. The
956+
* statement will not be sent to Cloud Spanner if the result type would not be allowed. This
957+
* method can be used by drivers that must limit the type of statements that are allowed for a
958+
* given method, e.g. for the {@link java.sql.Statement#executeQuery(String)} and {@link
959+
* java.sql.Statement#executeUpdate(String)} methods.
960+
*
961+
* <p>The returned value depends on the type of statement:
962+
*
963+
* <ul>
964+
* <li>Queries and DML statements with returning clause will return a {@link ResultSet}.
965+
* <li>Simple DML statements will return an update count
966+
* <li>DDL statements will return a {@link ResultType#NO_RESULT}
967+
* <li>Connection and transaction statements (SET AUTOCOMMIT=TRUE|FALSE, SHOW AUTOCOMMIT, SET
968+
* TRANSACTION READ ONLY, etc) will return either a {@link ResultSet} or {@link
969+
* ResultType#NO_RESULT}, depending on the type of statement (SHOW or SET)
970+
* </ul>
971+
*
972+
* @param statement The statement to execute
973+
* @param allowedResultTypes The result types that this method may return. The statement will not
974+
* be sent to Cloud Spanner if the statement would return a result that is not one of the
975+
* types in this set.
976+
* @return the result of the statement
977+
*/
978+
default StatementResult execute(Statement statement, Set<ResultType> allowedResultTypes) {
979+
throw new UnsupportedOperationException("Not implemented");
980+
}
981+
952982
/**
953983
* Executes the given statement if allowed in the current {@link TransactionMode} and connection
954984
* state asynchronously. The returned value depends on the type of statement:

google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import com.google.cloud.spanner.connection.AbstractStatementParser.ParsedStatement;
4848
import com.google.cloud.spanner.connection.AbstractStatementParser.StatementType;
4949
import com.google.cloud.spanner.connection.StatementExecutor.StatementTimeout;
50+
import com.google.cloud.spanner.connection.StatementResult.ResultType;
5051
import com.google.cloud.spanner.connection.UnitOfWork.CallType;
5152
import com.google.cloud.spanner.connection.UnitOfWork.UnitOfWorkState;
5253
import com.google.common.annotations.VisibleForTesting;
@@ -60,12 +61,15 @@
6061
import java.util.Iterator;
6162
import java.util.LinkedList;
6263
import java.util.List;
64+
import java.util.Set;
6365
import java.util.Stack;
6466
import java.util.concurrent.ExecutionException;
6567
import java.util.concurrent.RejectedExecutionException;
6668
import java.util.concurrent.ThreadFactory;
6769
import java.util.concurrent.TimeUnit;
6870
import java.util.concurrent.TimeoutException;
71+
import java.util.stream.Collectors;
72+
import javax.annotation.Nullable;
6973
import org.threeten.bp.Instant;
7074

7175
/** Implementation for {@link Connection}, the generic Spanner connection API (not JDBC). */
@@ -940,9 +944,20 @@ public void rollbackToSavepoint(String name) {
940944

941945
@Override
942946
public StatementResult execute(Statement statement) {
943-
Preconditions.checkNotNull(statement);
947+
return internalExecute(Preconditions.checkNotNull(statement), null);
948+
}
949+
950+
@Override
951+
public StatementResult execute(Statement statement, Set<ResultType> allowedResultTypes) {
952+
return internalExecute(
953+
Preconditions.checkNotNull(statement), Preconditions.checkNotNull(allowedResultTypes));
954+
}
955+
956+
private StatementResult internalExecute(
957+
Statement statement, @Nullable Set<ResultType> allowedResultTypes) {
944958
ConnectionPreconditions.checkState(!isClosed(), CLOSED_ERROR_MSG);
945959
ParsedStatement parsedStatement = getStatementParser().parse(statement, this.queryOptions);
960+
checkResultTypeAllowed(parsedStatement, allowedResultTypes);
946961
switch (parsedStatement.getType()) {
947962
case CLIENT_SIDE:
948963
return parsedStatement
@@ -969,6 +984,53 @@ public StatementResult execute(Statement statement) {
969984
"Unknown statement: " + parsedStatement.getSqlWithoutComments());
970985
}
971986

987+
@VisibleForTesting
988+
static void checkResultTypeAllowed(
989+
ParsedStatement parsedStatement, @Nullable Set<ResultType> allowedResultTypes) {
990+
if (allowedResultTypes == null) {
991+
return;
992+
}
993+
ResultType resultType = getResultType(parsedStatement);
994+
if (!allowedResultTypes.contains(resultType)) {
995+
throw SpannerExceptionFactory.newSpannerException(
996+
ErrorCode.INVALID_ARGUMENT,
997+
"This statement returns a result of type "
998+
+ resultType
999+
+ ". Only statements that return a result of one of the following types are allowed: "
1000+
+ allowedResultTypes.stream()
1001+
.map(ResultType::toString)
1002+
.collect(Collectors.joining(", ")));
1003+
}
1004+
}
1005+
1006+
private static ResultType getResultType(ParsedStatement parsedStatement) {
1007+
switch (parsedStatement.getType()) {
1008+
case CLIENT_SIDE:
1009+
if (parsedStatement.getClientSideStatement().isQuery()) {
1010+
return ResultType.RESULT_SET;
1011+
} else if (parsedStatement.getClientSideStatement().isUpdate()) {
1012+
return ResultType.UPDATE_COUNT;
1013+
} else {
1014+
return ResultType.NO_RESULT;
1015+
}
1016+
case QUERY:
1017+
return ResultType.RESULT_SET;
1018+
case UPDATE:
1019+
if (parsedStatement.hasReturningClause()) {
1020+
return ResultType.RESULT_SET;
1021+
} else {
1022+
return ResultType.UPDATE_COUNT;
1023+
}
1024+
case DDL:
1025+
return ResultType.NO_RESULT;
1026+
case UNKNOWN:
1027+
default:
1028+
throw SpannerExceptionFactory.newSpannerException(
1029+
ErrorCode.INVALID_ARGUMENT,
1030+
"Unknown statement: " + parsedStatement.getSqlWithoutComments());
1031+
}
1032+
}
1033+
9721034
@Override
9731035
public AsyncStatementResult executeAsync(Statement statement) {
9741036
Preconditions.checkNotNull(statement);

google-cloud-spanner/src/test/java/com/google/cloud/spanner/connection/ConnectionImplTest.java

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.SELECT;
2121
import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.UPDATE;
2222
import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.expectSpannerException;
23+
import static com.google.cloud.spanner.connection.ConnectionImpl.checkResultTypeAllowed;
2324
import static org.hamcrest.CoreMatchers.equalTo;
2425
import static org.hamcrest.CoreMatchers.is;
2526
import static org.hamcrest.CoreMatchers.notNullValue;
@@ -28,6 +29,7 @@
2829
import static org.junit.Assert.assertFalse;
2930
import static org.junit.Assert.assertNotNull;
3031
import static org.junit.Assert.assertNull;
32+
import static org.junit.Assert.assertThrows;
3133
import static org.junit.Assert.assertTrue;
3234
import static org.junit.Assert.fail;
3335
import static org.mockito.Mockito.any;
@@ -73,6 +75,7 @@
7375
import com.google.cloud.spanner.connection.StatementResult.ResultType;
7476
import com.google.cloud.spanner.connection.UnitOfWork.CallType;
7577
import com.google.cloud.spanner.connection.UnitOfWork.UnitOfWorkState;
78+
import com.google.common.collect.ImmutableSet;
7679
import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata;
7780
import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions;
7881
import com.google.spanner.v1.ResultSetStats;
@@ -1624,4 +1627,115 @@ UnitOfWork createNewUnitOfWork(boolean isInternalMetadataQuery) {
16241627
assertNull(connection.getTransactionTag());
16251628
}
16261629
}
1630+
1631+
@Test
1632+
public void testCheckResultTypeAllowed() {
1633+
AbstractStatementParser parser =
1634+
AbstractStatementParser.getInstance(Dialect.GOOGLE_STANDARD_SQL);
1635+
String query = "select * from foo";
1636+
String dml = "update foo set bar=1 where true";
1637+
String dmlReturning = "insert into foo (id, value) values (1, 'One') then return id";
1638+
String ddl = "create table foo";
1639+
String set = "set readonly=true";
1640+
String show = "show variable readonly";
1641+
String start = "start batch dml";
1642+
1643+
// null means all statements should be allowed.
1644+
ImmutableSet<ResultType> allowedResultTypes = null;
1645+
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
1646+
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
1647+
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
1648+
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
1649+
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
1650+
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
1651+
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);
1652+
1653+
allowedResultTypes = ImmutableSet.of();
1654+
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
1655+
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
1656+
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
1657+
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
1658+
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
1659+
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
1660+
assertThrowResultNotAllowed(parser, start, allowedResultTypes);
1661+
1662+
allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET);
1663+
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
1664+
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
1665+
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
1666+
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
1667+
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
1668+
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
1669+
assertThrowResultNotAllowed(parser, start, allowedResultTypes);
1670+
1671+
allowedResultTypes = ImmutableSet.of(ResultType.UPDATE_COUNT);
1672+
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
1673+
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
1674+
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
1675+
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
1676+
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
1677+
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
1678+
assertThrowResultNotAllowed(parser, start, allowedResultTypes);
1679+
1680+
allowedResultTypes = ImmutableSet.of(ResultType.NO_RESULT);
1681+
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
1682+
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
1683+
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
1684+
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
1685+
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
1686+
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
1687+
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);
1688+
1689+
allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET, ResultType.UPDATE_COUNT);
1690+
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
1691+
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
1692+
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
1693+
assertThrowResultNotAllowed(parser, ddl, allowedResultTypes);
1694+
assertThrowResultNotAllowed(parser, set, allowedResultTypes);
1695+
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
1696+
assertThrowResultNotAllowed(parser, start, allowedResultTypes);
1697+
1698+
allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET, ResultType.NO_RESULT);
1699+
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
1700+
assertThrowResultNotAllowed(parser, dml, allowedResultTypes);
1701+
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
1702+
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
1703+
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
1704+
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
1705+
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);
1706+
1707+
allowedResultTypes = ImmutableSet.of(ResultType.UPDATE_COUNT, ResultType.NO_RESULT);
1708+
assertThrowResultNotAllowed(parser, query, allowedResultTypes);
1709+
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
1710+
assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes);
1711+
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
1712+
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
1713+
assertThrowResultNotAllowed(parser, show, allowedResultTypes);
1714+
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);
1715+
1716+
allowedResultTypes =
1717+
ImmutableSet.of(ResultType.RESULT_SET, ResultType.UPDATE_COUNT, ResultType.NO_RESULT);
1718+
checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes);
1719+
checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes);
1720+
checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes);
1721+
checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes);
1722+
checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes);
1723+
checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes);
1724+
checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes);
1725+
}
1726+
1727+
private void assertThrowResultNotAllowed(
1728+
AbstractStatementParser parser, String sql, ImmutableSet<ResultType> allowedResultTypes) {
1729+
SpannerException exception =
1730+
assertThrows(
1731+
SpannerException.class,
1732+
() -> checkResultTypeAllowed(parser.parse(Statement.of(sql)), allowedResultTypes));
1733+
assertEquals(ErrorCode.INVALID_ARGUMENT, exception.getErrorCode());
1734+
assertTrue(
1735+
exception.getMessage(),
1736+
exception
1737+
.getMessage()
1738+
.contains(
1739+
"Only statements that return a result of one of the following types are allowed"));
1740+
}
16271741
}

0 commit comments

Comments
 (0)