|
20 | 20 | import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.SELECT;
|
21 | 21 | import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.UPDATE;
|
22 | 22 | import static com.google.cloud.spanner.connection.AbstractConnectionImplTest.expectSpannerException;
|
| 23 | +import static com.google.cloud.spanner.connection.ConnectionImpl.checkResultTypeAllowed; |
23 | 24 | import static org.hamcrest.CoreMatchers.equalTo;
|
24 | 25 | import static org.hamcrest.CoreMatchers.is;
|
25 | 26 | import static org.hamcrest.CoreMatchers.notNullValue;
|
|
28 | 29 | import static org.junit.Assert.assertFalse;
|
29 | 30 | import static org.junit.Assert.assertNotNull;
|
30 | 31 | import static org.junit.Assert.assertNull;
|
| 32 | +import static org.junit.Assert.assertThrows; |
31 | 33 | import static org.junit.Assert.assertTrue;
|
32 | 34 | import static org.junit.Assert.fail;
|
33 | 35 | import static org.mockito.Mockito.any;
|
|
73 | 75 | import com.google.cloud.spanner.connection.StatementResult.ResultType;
|
74 | 76 | import com.google.cloud.spanner.connection.UnitOfWork.CallType;
|
75 | 77 | import com.google.cloud.spanner.connection.UnitOfWork.UnitOfWorkState;
|
| 78 | +import com.google.common.collect.ImmutableSet; |
76 | 79 | import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata;
|
77 | 80 | import com.google.spanner.v1.ExecuteSqlRequest.QueryOptions;
|
78 | 81 | import com.google.spanner.v1.ResultSetStats;
|
@@ -1624,4 +1627,115 @@ UnitOfWork createNewUnitOfWork(boolean isInternalMetadataQuery) {
|
1624 | 1627 | assertNull(connection.getTransactionTag());
|
1625 | 1628 | }
|
1626 | 1629 | }
|
| 1630 | + |
| 1631 | + @Test |
| 1632 | + public void testCheckResultTypeAllowed() { |
| 1633 | + AbstractStatementParser parser = |
| 1634 | + AbstractStatementParser.getInstance(Dialect.GOOGLE_STANDARD_SQL); |
| 1635 | + String query = "select * from foo"; |
| 1636 | + String dml = "update foo set bar=1 where true"; |
| 1637 | + String dmlReturning = "insert into foo (id, value) values (1, 'One') then return id"; |
| 1638 | + String ddl = "create table foo"; |
| 1639 | + String set = "set readonly=true"; |
| 1640 | + String show = "show variable readonly"; |
| 1641 | + String start = "start batch dml"; |
| 1642 | + |
| 1643 | + // null means all statements should be allowed. |
| 1644 | + ImmutableSet<ResultType> allowedResultTypes = null; |
| 1645 | + checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes); |
| 1646 | + checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes); |
| 1647 | + checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes); |
| 1648 | + checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes); |
| 1649 | + checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes); |
| 1650 | + checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes); |
| 1651 | + checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes); |
| 1652 | + |
| 1653 | + allowedResultTypes = ImmutableSet.of(); |
| 1654 | + assertThrowResultNotAllowed(parser, query, allowedResultTypes); |
| 1655 | + assertThrowResultNotAllowed(parser, dml, allowedResultTypes); |
| 1656 | + assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes); |
| 1657 | + assertThrowResultNotAllowed(parser, ddl, allowedResultTypes); |
| 1658 | + assertThrowResultNotAllowed(parser, set, allowedResultTypes); |
| 1659 | + assertThrowResultNotAllowed(parser, show, allowedResultTypes); |
| 1660 | + assertThrowResultNotAllowed(parser, start, allowedResultTypes); |
| 1661 | + |
| 1662 | + allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET); |
| 1663 | + checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes); |
| 1664 | + assertThrowResultNotAllowed(parser, dml, allowedResultTypes); |
| 1665 | + checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes); |
| 1666 | + assertThrowResultNotAllowed(parser, ddl, allowedResultTypes); |
| 1667 | + assertThrowResultNotAllowed(parser, set, allowedResultTypes); |
| 1668 | + checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes); |
| 1669 | + assertThrowResultNotAllowed(parser, start, allowedResultTypes); |
| 1670 | + |
| 1671 | + allowedResultTypes = ImmutableSet.of(ResultType.UPDATE_COUNT); |
| 1672 | + assertThrowResultNotAllowed(parser, query, allowedResultTypes); |
| 1673 | + checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes); |
| 1674 | + assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes); |
| 1675 | + assertThrowResultNotAllowed(parser, ddl, allowedResultTypes); |
| 1676 | + assertThrowResultNotAllowed(parser, set, allowedResultTypes); |
| 1677 | + assertThrowResultNotAllowed(parser, show, allowedResultTypes); |
| 1678 | + assertThrowResultNotAllowed(parser, start, allowedResultTypes); |
| 1679 | + |
| 1680 | + allowedResultTypes = ImmutableSet.of(ResultType.NO_RESULT); |
| 1681 | + assertThrowResultNotAllowed(parser, query, allowedResultTypes); |
| 1682 | + assertThrowResultNotAllowed(parser, dml, allowedResultTypes); |
| 1683 | + assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes); |
| 1684 | + checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes); |
| 1685 | + checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes); |
| 1686 | + assertThrowResultNotAllowed(parser, show, allowedResultTypes); |
| 1687 | + checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes); |
| 1688 | + |
| 1689 | + allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET, ResultType.UPDATE_COUNT); |
| 1690 | + checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes); |
| 1691 | + checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes); |
| 1692 | + checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes); |
| 1693 | + assertThrowResultNotAllowed(parser, ddl, allowedResultTypes); |
| 1694 | + assertThrowResultNotAllowed(parser, set, allowedResultTypes); |
| 1695 | + checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes); |
| 1696 | + assertThrowResultNotAllowed(parser, start, allowedResultTypes); |
| 1697 | + |
| 1698 | + allowedResultTypes = ImmutableSet.of(ResultType.RESULT_SET, ResultType.NO_RESULT); |
| 1699 | + checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes); |
| 1700 | + assertThrowResultNotAllowed(parser, dml, allowedResultTypes); |
| 1701 | + checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes); |
| 1702 | + checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes); |
| 1703 | + checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes); |
| 1704 | + checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes); |
| 1705 | + checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes); |
| 1706 | + |
| 1707 | + allowedResultTypes = ImmutableSet.of(ResultType.UPDATE_COUNT, ResultType.NO_RESULT); |
| 1708 | + assertThrowResultNotAllowed(parser, query, allowedResultTypes); |
| 1709 | + checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes); |
| 1710 | + assertThrowResultNotAllowed(parser, dmlReturning, allowedResultTypes); |
| 1711 | + checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes); |
| 1712 | + checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes); |
| 1713 | + assertThrowResultNotAllowed(parser, show, allowedResultTypes); |
| 1714 | + checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes); |
| 1715 | + |
| 1716 | + allowedResultTypes = |
| 1717 | + ImmutableSet.of(ResultType.RESULT_SET, ResultType.UPDATE_COUNT, ResultType.NO_RESULT); |
| 1718 | + checkResultTypeAllowed(parser.parse(Statement.of(query)), allowedResultTypes); |
| 1719 | + checkResultTypeAllowed(parser.parse(Statement.of(dml)), allowedResultTypes); |
| 1720 | + checkResultTypeAllowed(parser.parse(Statement.of(dmlReturning)), allowedResultTypes); |
| 1721 | + checkResultTypeAllowed(parser.parse(Statement.of(ddl)), allowedResultTypes); |
| 1722 | + checkResultTypeAllowed(parser.parse(Statement.of(set)), allowedResultTypes); |
| 1723 | + checkResultTypeAllowed(parser.parse(Statement.of(show)), allowedResultTypes); |
| 1724 | + checkResultTypeAllowed(parser.parse(Statement.of(start)), allowedResultTypes); |
| 1725 | + } |
| 1726 | + |
| 1727 | + private void assertThrowResultNotAllowed( |
| 1728 | + AbstractStatementParser parser, String sql, ImmutableSet<ResultType> allowedResultTypes) { |
| 1729 | + SpannerException exception = |
| 1730 | + assertThrows( |
| 1731 | + SpannerException.class, |
| 1732 | + () -> checkResultTypeAllowed(parser.parse(Statement.of(sql)), allowedResultTypes)); |
| 1733 | + assertEquals(ErrorCode.INVALID_ARGUMENT, exception.getErrorCode()); |
| 1734 | + assertTrue( |
| 1735 | + exception.getMessage(), |
| 1736 | + exception |
| 1737 | + .getMessage() |
| 1738 | + .contains( |
| 1739 | + "Only statements that return a result of one of the following types are allowed")); |
| 1740 | + } |
1627 | 1741 | }
|
0 commit comments