@@ -94,64 +94,92 @@ def test_frame_in_list(self):
94
94
with pytest .raises (ValueError , match = msg ):
95
95
df in [None ]
96
96
97
- def test_comparison_invalid (self ):
98
- def check (df , df2 ):
99
-
100
- for (x , y ) in [(df , df2 ), (df2 , df )]:
101
- # we expect the result to match Series comparisons for
102
- # == and !=, inequalities should raise
103
- result = x == y
104
- expected = DataFrame (
105
- {col : x [col ] == y [col ] for col in x .columns },
106
- index = x .index ,
107
- columns = x .columns ,
108
- )
109
- tm .assert_frame_equal (result , expected )
110
-
111
- result = x != y
112
- expected = DataFrame (
113
- {col : x [col ] != y [col ] for col in x .columns },
114
- index = x .index ,
115
- columns = x .columns ,
116
- )
117
- tm .assert_frame_equal (result , expected )
118
-
119
- msgs = [
120
- r"Invalid comparison between dtype=datetime64\[ns\] and ndarray" ,
121
- "invalid type promotion" ,
122
- (
123
- # npdev 1.20.0
124
- r"The DTypes <class 'numpy.dtype\[.*\]'> and "
125
- r"<class 'numpy.dtype\[.*\]'> do not have a common DType."
126
- ),
127
- ]
128
- msg = "|" .join (msgs )
129
- with pytest .raises (TypeError , match = msg ):
130
- x >= y
131
- with pytest .raises (TypeError , match = msg ):
132
- x > y
133
- with pytest .raises (TypeError , match = msg ):
134
- x < y
135
- with pytest .raises (TypeError , match = msg ):
136
- x <= y
137
-
97
+ @pytest .mark .parametrize (
98
+ "arg, arg2" ,
99
+ [
100
+ [
101
+ {
102
+ "a" : np .random .randint (10 , size = 10 ),
103
+ "b" : pd .date_range ("20010101" , periods = 10 ),
104
+ },
105
+ {
106
+ "a" : np .random .randint (10 , size = 10 ),
107
+ "b" : np .random .randint (10 , size = 10 ),
108
+ },
109
+ ],
110
+ [
111
+ {
112
+ "a" : np .random .randint (10 , size = 10 ),
113
+ "b" : np .random .randint (10 , size = 10 ),
114
+ },
115
+ {
116
+ "a" : np .random .randint (10 , size = 10 ),
117
+ "b" : pd .date_range ("20010101" , periods = 10 ),
118
+ },
119
+ ],
120
+ [
121
+ {
122
+ "a" : pd .date_range ("20010101" , periods = 10 ),
123
+ "b" : pd .date_range ("20010101" , periods = 10 ),
124
+ },
125
+ {
126
+ "a" : np .random .randint (10 , size = 10 ),
127
+ "b" : np .random .randint (10 , size = 10 ),
128
+ },
129
+ ],
130
+ [
131
+ {
132
+ "a" : np .random .randint (10 , size = 10 ),
133
+ "b" : pd .date_range ("20010101" , periods = 10 ),
134
+ },
135
+ {
136
+ "a" : pd .date_range ("20010101" , periods = 10 ),
137
+ "b" : pd .date_range ("20010101" , periods = 10 ),
138
+ },
139
+ ],
140
+ ],
141
+ )
142
+ def test_comparison_invalid (self , arg , arg2 ):
138
143
# GH4968
139
144
# invalid date/int comparisons
140
- df = DataFrame (np .random .randint (10 , size = (10 , 1 )), columns = ["a" ])
141
- df ["dates" ] = pd .date_range ("20010101" , periods = len (df ))
142
-
143
- df2 = df .copy ()
144
- df2 ["dates" ] = df ["a" ]
145
- check (df , df2 )
145
+ x = DataFrame (arg )
146
+ y = DataFrame (arg2 )
147
+ # we expect the result to match Series comparisons for
148
+ # == and !=, inequalities should raise
149
+ result = x == y
150
+ expected = DataFrame (
151
+ {col : x [col ] == y [col ] for col in x .columns },
152
+ index = x .index ,
153
+ columns = x .columns ,
154
+ )
155
+ tm .assert_frame_equal (result , expected )
146
156
147
- df = DataFrame (np .random .randint (10 , size = (10 , 2 )), columns = ["a" , "b" ])
148
- df2 = DataFrame (
149
- {
150
- "a" : pd .date_range ("20010101" , periods = len (df )),
151
- "b" : pd .date_range ("20100101" , periods = len (df )),
152
- }
157
+ result = x != y
158
+ expected = DataFrame (
159
+ {col : x [col ] != y [col ] for col in x .columns },
160
+ index = x .index ,
161
+ columns = x .columns ,
153
162
)
154
- check (df , df2 )
163
+ tm .assert_frame_equal (result , expected )
164
+
165
+ msgs = [
166
+ r"Invalid comparison between dtype=datetime64\[ns\] and ndarray" ,
167
+ "invalid type promotion" ,
168
+ (
169
+ # npdev 1.20.0
170
+ r"The DTypes <class 'numpy.dtype\[.*\]'> and "
171
+ r"<class 'numpy.dtype\[.*\]'> do not have a common DType."
172
+ ),
173
+ ]
174
+ msg = "|" .join (msgs )
175
+ with pytest .raises (TypeError , match = msg ):
176
+ x >= y
177
+ with pytest .raises (TypeError , match = msg ):
178
+ x > y
179
+ with pytest .raises (TypeError , match = msg ):
180
+ x < y
181
+ with pytest .raises (TypeError , match = msg ):
182
+ x <= y
155
183
156
184
def test_timestamp_compare (self ):
157
185
# make sure we can compare Timestamps on the right AND left hand side
0 commit comments