Skip to content

Commit c2c90fd

Browse files
13sindavidhewitt
andauthored
Fix TzInfo equality check based on offset (#1197)
Co-authored-by: David Hewitt <[email protected]>
1 parent 8c6e2bd commit c2c90fd

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/input/datetime.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,19 @@ impl TzInfo {
563563
hasher.finish()
564564
}
565565

566-
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
567-
op.matches(self.seconds.cmp(&other.seconds))
566+
fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult<Py<PyAny>> {
567+
let py = other.py();
568+
if other.is_instance_of::<PyTzInfo>() {
569+
let offset_delta = other.call_method1(intern!(py, "utcoffset"), (py.None(),))?;
570+
if offset_delta.is_none() {
571+
return Ok(py.NotImplemented());
572+
}
573+
let offset_seconds: f64 = offset_delta.call_method0(intern!(py, "total_seconds"))?.extract()?;
574+
let offset = offset_seconds.round() as i32;
575+
Ok(op.matches(self.seconds.cmp(&offset)).into_py(py))
576+
} else {
577+
Ok(py.NotImplemented())
578+
}
568579
}
569580

570581
fn __deepcopy__(&self, py: Python, _memo: &PyDict) -> PyResult<Py<Self>> {

tests/test_tzinfo.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import copy
22
import functools
33
import pickle
4+
import sys
45
import unittest
56
from datetime import datetime, timedelta, timezone, tzinfo
67

78
from pydantic_core import SchemaValidator, TzInfo, core_schema
89

10+
if sys.version_info >= (3, 9):
11+
from zoneinfo import ZoneInfo
12+
913

1014
class _ALWAYS_EQ:
1115
"""
@@ -80,6 +84,7 @@ class TestTzInfo(unittest.TestCase):
8084
def setUp(self):
8185
self.ACDT = TzInfo(timedelta(hours=9.5).total_seconds())
8286
self.EST = TzInfo(-timedelta(hours=5).total_seconds())
87+
self.UTC = TzInfo(timedelta(0).total_seconds())
8388
self.DT = datetime(2010, 1, 1)
8489

8590
def test_str(self):
@@ -163,6 +168,17 @@ def test_comparison(self):
163168
self.assertFalse(tz <= SMALLEST)
164169
self.assertTrue(tz >= SMALLEST)
165170

171+
# offset based comparion tests for tzinfo derived classes like datetime.timezone.
172+
utcdatetime = self.DT.replace(tzinfo=timezone.utc)
173+
self.assertTrue(tz == utcdatetime.tzinfo)
174+
estdatetime = self.DT.replace(tzinfo=timezone(-timedelta(hours=5)))
175+
self.assertTrue(self.EST == estdatetime.tzinfo)
176+
self.assertTrue(tz > estdatetime.tzinfo)
177+
if sys.version_info >= (3, 9) and sys.platform == 'linux':
178+
self.assertFalse(tz == ZoneInfo('Europe/London'))
179+
with self.assertRaises(TypeError):
180+
tz > ZoneInfo('Europe/London')
181+
166182
def test_copy(self):
167183
for tz in self.ACDT, self.EST:
168184
tz_copy = copy.copy(tz)

0 commit comments

Comments
 (0)