Commit 008791d

Supply expected dtypes to to_dataframe()
1 parent 78ffa8e commit 008791d

5 files changed: +36 −51 lines changed
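For context on the diffs below: with dtypes supplied to to_dataframe(), query results should come back from read_gbq() with the expected column dtypes rather than relying on post-hoc pandas inference. A minimal caller-side sketch, assuming this commit's pandas-gbq and an already-configured default project (the query itself is hypothetical):

import pandas_gbq

# Hypothetical query; project and credentials setup omitted.
df = pandas_gbq.read_gbq(
    "SELECT CURRENT_TIMESTAMP() AS created_at, 1.5 AS score",
    dialect="standard",
)

# Per the dtype_map in pandas_gbq/gbq.py below, TIMESTAMP columns arrive as
# timezone-aware datetime64[ns, UTC] and FLOAT columns as float64.
print(df.dtypes)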

benchmark/read_gbq_large_results.py

Lines changed: 2 additions & 1 deletion
@@ -5,4 +5,5 @@
 # result sets.
 df = pandas_gbq.read_gbq(
     "SELECT * FROM `bigquery-public-data.usa_names.usa_1910_2013`",
-    dialect="standard")
+    dialect="standard",
+)

benchmark/read_gbq_small_results.py

Lines changed: 2 additions & 1 deletion
@@ -4,4 +4,5 @@
 # Select a few KB worth of data, to time downloading small result sets.
 df = pandas_gbq.read_gbq(
     "SELECT * FROM `bigquery-public-data.utility_us.country_code_iso`",
-    dialect="standard")
+    dialect="standard",
+)

pandas_gbq/gbq.py

Lines changed: 17 additions & 9 deletions
@@ -480,7 +480,10 @@ def run_query(self, query, **kwargs):
             rows_iter = query_reply.result()
         except self.http_error as ex:
             self.process_http_error(ex)
-        df = rows_iter.to_dataframe()
+
+        schema_fields = [field.to_api_repr() for field in rows_iter.schema]
+        dtypes = _bqschema_to_dtypes(schema_fields)
+        df = rows_iter.to_dataframe(dtypes=dtypes)
         logger.debug("Got {} rows.\n".format(rows_iter.total_rows))
         return df
 
@@ -630,27 +633,32 @@ def delete_and_recreate_table(self, dataset_id, table_id, table_schema):
         table.create(table_id, table_schema)
 
 
-def _parse_schema(schema_fields):
+def _bqschema_to_dtypes(schema_fields):
+    # Only specify dtype when the dtype allows nulls. Otherwise, use pandas's
+    # default dtype choice.
+    #
     # see:
     # http://pandas.pydata.org/pandas-docs/dev/missing_data.html
     # #missing-data-casting-rules-and-indexing
     dtype_map = {
         "FLOAT": np.dtype(float),
-        "TIMESTAMP": "datetime64[ns]",
+        "TIMESTAMP": "datetime64[ns, UTC]",
         "TIME": "datetime64[ns]",
         "DATE": "datetime64[ns]",
         "DATETIME": "datetime64[ns]",
-        "BOOLEAN": bool,
-        "INTEGER": np.int64,
     }
 
+    dtypes = {}
     for field in schema_fields:
         name = str(field["name"])
         if field["mode"].upper() == "REPEATED":
-            yield name, object
-        else:
-            dtype = dtype_map.get(field["type"].upper())
-            yield name, dtype
+            continue
+
+        dtype = dtype_map.get(field["type"].upper())
+        if dtype:
+            dtypes[name] = dtype
+
+    return dtypes
 
 
 def read_gbq(
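A minimal sketch of how the renamed helper behaves, using a hypothetical schema in the field.to_api_repr() dict form that run_query() now feeds it; only types with a null-safe dtype make it into the mapping, and REPEATED fields are skipped:

from pandas_gbq import gbq

# Hypothetical schema fields, shaped like field.to_api_repr() output.
schema_fields = [
    {"name": "name", "type": "STRING", "mode": "NULLABLE"},
    {"name": "number", "type": "INTEGER", "mode": "NULLABLE"},
    {"name": "created", "type": "TIMESTAMP", "mode": "NULLABLE"},
    {"name": "tags", "type": "STRING", "mode": "REPEATED"},
]

dtypes = gbq._bqschema_to_dtypes(schema_fields)

# STRING has no dtype_map entry, INTEGER was dropped because int64 cannot
# represent NULL, and the REPEATED field is skipped, so only TIMESTAMP remains.
assert dtypes == {"created": "datetime64[ns, UTC]"}

# run_query() then forwards this mapping to the row iterator:
# df = rows_iter.to_dataframe(dtypes=dtypes)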

tests/system/test_gbq.py

Lines changed: 0 additions & 8 deletions
@@ -138,14 +138,6 @@ def test_should_be_able_to_get_a_bigquery_client(self, gbq_connector):
         bigquery_client = gbq_connector.get_client()
         assert bigquery_client is not None
 
-    def test_should_be_able_to_get_schema_from_query(self, gbq_connector):
-        schema, pages = gbq_connector.run_query("SELECT 1")
-        assert schema is not None
-
-    def test_should_be_able_to_get_results_from_query(self, gbq_connector):
-        schema, pages = gbq_connector.run_query("SELECT 1")
-        assert pages is not None
-
 
 def test_should_read(project, credentials):
     query = 'SELECT "PI" AS valid_string'

tests/unit/test_gbq.py

Lines changed: 15 additions & 32 deletions
@@ -2,6 +2,7 @@
 
 import pandas.util.testing as tm
 import pytest
+import numpy
 from pandas import DataFrame
 from pandas.compat.numpy import np_datetime64_compat
 
@@ -65,26 +66,23 @@ def no_auth(monkeypatch):
 
 
 @pytest.mark.parametrize(
-    ("input", "type_", "expected"),
+    ("type_", "expected"),
     [
-        (1, "INTEGER", int(1)),
-        (1, "FLOAT", float(1)),
-        pytest.param("false", "BOOLEAN", False, marks=pytest.mark.xfail),
-        pytest.param(
-            "0e9",
-            "TIMESTAMP",
-            np_datetime64_compat("1970-01-01T00:00:00Z"),
-            marks=pytest.mark.xfail,
-        ),
-        ("STRING", "STRING", "STRING"),
+        ("INTEGER", None),  # Can't handle NULL
+        ("BOOLEAN", None),  # Can't handle NULL
+        ("FLOAT", numpy.dtype(float)),
+        ("TIMESTAMP", "datetime64[ns, UTC]"),
+        ("DATETIME", "datetime64[ns]"),
     ],
 )
-def test_should_return_bigquery_correctly_typed(input, type_, expected):
-    result = gbq._parse_data(
-        dict(fields=[dict(name="x", type=type_, mode="NULLABLE")]),
-        rows=[[input]],
-    ).iloc[0, 0]
-    assert result == expected
+def test_should_return_bigquery_correctly_typed(type_, expected):
+    result = gbq._bqschema_to_dtypes(
+        [dict(name="x", type=type_, mode="NULLABLE")]
+    )
+    if not expected:
+        assert result == {}
+    else:
+        assert result == {"x": expected}
 
 
 def test_to_gbq_should_fail_if_invalid_table_name_passed():
@@ -264,21 +262,6 @@ def test_read_gbq_with_inferred_project_id(monkeypatch):
     assert df is not None
 
 
-def test_that_parse_data_works_properly():
-    from google.cloud.bigquery.table import Row
-
-    test_schema = {
-        "fields": [{"mode": "NULLABLE", "name": "column_x", "type": "STRING"}]
-    }
-    field_to_index = {"column_x": 0}
-    values = ("row_value",)
-    test_page = [Row(values, field_to_index)]
-
-    test_output = gbq._parse_data(test_schema, test_page)
-    correct_output = DataFrame({"column_x": ["row_value"]})
-    tm.assert_frame_equal(test_output, correct_output)
-
-
 def test_read_gbq_with_invalid_private_key_json_should_fail():
     with pytest.raises(pandas_gbq.exceptions.InvalidPrivateKeyFormat):
         gbq.read_gbq(
