Skip to content

Commit 66bce03

Browse files
authored
Reimplement and optimize intersection of map types (#13524)
* Reimplement interesection of map types * Extract and reuse symmetrical merge & intersection * Move to the bottom of the file
1 parent 47279b1 commit 66bce03

File tree

1 file changed

+126
-98
lines changed

1 file changed

+126
-98
lines changed

lib/elixir/lib/module/types/descr.ex

Lines changed: 126 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -130,24 +130,14 @@ defmodule Module.Types.Descr do
130130
cond do
131131
is_gradual_left and not is_gradual_right ->
132132
right_with_dynamic = Map.put(right, :dynamic, right)
133-
union_static(left, right_with_dynamic)
133+
symmetrical_merge(left, right_with_dynamic, &union/3)
134134

135135
is_gradual_right and not is_gradual_left ->
136136
left_with_dynamic = Map.put(left, :dynamic, left)
137-
union_static(left_with_dynamic, right)
137+
symmetrical_merge(left_with_dynamic, right, &union/3)
138138

139139
true ->
140-
union_static(left, right)
141-
end
142-
end
143-
144-
defp union_static(left, right) do
145-
# Erlang maps:merge_with/3 has to preserve the order in combiner.
146-
# We don't care about the order, so we have a faster implementation.
147-
if map_size(left) > map_size(right) do
148-
iterator_union(:maps.next(:maps.iterator(right)), left)
149-
else
150-
iterator_union(:maps.next(:maps.iterator(left)), right)
140+
symmetrical_merge(left, right, &union/3)
151141
end
152142
end
153143

@@ -167,24 +157,14 @@ defmodule Module.Types.Descr do
167157
cond do
168158
is_gradual_left and not is_gradual_right ->
169159
right_with_dynamic = Map.put(right, :dynamic, right)
170-
intersection_static(left, right_with_dynamic)
160+
symmetrical_intersection(left, right_with_dynamic, &intersection/3)
171161

172162
is_gradual_right and not is_gradual_left ->
173163
left_with_dynamic = Map.put(left, :dynamic, left)
174-
intersection_static(left_with_dynamic, right)
164+
symmetrical_intersection(left_with_dynamic, right, &intersection/3)
175165

176166
true ->
177-
intersection_static(left, right)
178-
end
179-
end
180-
181-
defp intersection_static(left, right) do
182-
# Erlang maps:intersect_with/3 has to preserve the order in combiner.
183-
# We don't care about the order, so we have a faster implementation.
184-
if map_size(left) > map_size(right) do
185-
iterator_intersection(:maps.next(:maps.iterator(right)), left, [])
186-
else
187-
iterator_intersection(:maps.next(:maps.iterator(left)), right, [])
167+
symmetrical_intersection(left, right, &intersection/3)
188168
end
189169
end
190170

@@ -272,58 +252,6 @@ defmodule Module.Types.Descr do
272252
|> IO.iodata_to_binary()
273253
end
274254

275-
## Iterator helpers
276-
277-
defp iterator_union({key, v1, iterator}, map) do
278-
acc =
279-
case map do
280-
%{^key => v2} -> %{map | key => union(key, v1, v2)}
281-
%{} -> Map.put(map, key, v1)
282-
end
283-
284-
iterator_union(:maps.next(iterator), acc)
285-
end
286-
287-
defp iterator_union(:none, map), do: map
288-
289-
defp iterator_intersection({key, v1, iterator}, map, acc) do
290-
acc =
291-
case map do
292-
%{^key => v2} ->
293-
case intersection(key, v1, v2) do
294-
0 -> acc
295-
[] -> acc
296-
value -> [{key, value} | acc]
297-
end
298-
299-
%{} ->
300-
acc
301-
end
302-
303-
iterator_intersection(:maps.next(iterator), map, acc)
304-
end
305-
306-
defp iterator_intersection(:none, _map, acc), do: :maps.from_list(acc)
307-
308-
defp iterator_difference({key, v2, iterator}, map) do
309-
acc =
310-
case map do
311-
%{^key => v1} ->
312-
case difference(key, v1, v2) do
313-
0 -> Map.delete(map, key)
314-
[] -> Map.delete(map, key)
315-
value -> %{map | key => value}
316-
end
317-
318-
%{} ->
319-
map
320-
end
321-
322-
iterator_difference(:maps.next(iterator), acc)
323-
end
324-
325-
defp iterator_difference(:none, map), do: map
326-
327255
## Type relations
328256

329257
@doc """
@@ -612,7 +540,7 @@ defmodule Module.Types.Descr do
612540
# (that is, there are no extra dynamic values).
613541

614542
defp dynamic_intersection(left, right) do
615-
inter = intersection_static(left, right)
543+
inter = symmetrical_intersection(left, right, &intersection/3)
616544
if empty?(inter), do: 0, else: inter
617545
end
618546

@@ -621,7 +549,7 @@ defmodule Module.Types.Descr do
621549
if empty?(diff), do: 0, else: diff
622550
end
623551

624-
defp dynamic_union(left, right), do: union_static(left, right)
552+
defp dynamic_union(left, right), do: symmetrical_merge(left, right, &union/3)
625553

626554
defp dynamic_to_quoted(%{} = descr) do
627555
cond do
@@ -746,27 +674,55 @@ defmodule Module.Types.Descr do
746674
end
747675

748676
# Intersects two map literals; throws if their intersection is empty.
749-
defp map_literal_intersection(tag1, map1, tag2, map2) do
750-
default1 = map_tag_to_type(tag1)
751-
default2 = map_tag_to_type(tag2)
677+
# Both open: the result is open.
678+
defp map_literal_intersection(:open, map1, :open, map2) do
679+
new_fields =
680+
symmetrical_merge(map1, map2, fn _, type1, type2 ->
681+
non_empty_intersection!(type1, type2)
682+
end)
752683

753-
# if any intersection of values is empty, the whole intersection is empty
684+
{:open, new_fields}
685+
end
686+
687+
# Both closed: the result is closed.
688+
defp map_literal_intersection(:closed, map1, :closed, map2) do
754689
new_fields =
755-
(for {key, value_type} <- map1 do
756-
value_type2 = Map.get(map2, key, default2)
757-
t = intersection(value_type, value_type2)
758-
if empty?(t), do: throw(:empty), else: {key, t}
759-
end ++
760-
for {key, value_type} <- map2, not is_map_key(map1, key) do
761-
t = intersection(default1, value_type)
762-
if empty?(t), do: throw(:empty), else: {key, t}
763-
end)
764-
|> Map.new()
765-
766-
case {tag1, tag2} do
767-
{:open, :open} -> {:open, new_fields}
768-
_ -> {:closed, new_fields}
690+
symmetrical_intersection(map1, map2, fn _, type1, type2 ->
691+
non_empty_intersection!(type1, type2)
692+
end)
693+
694+
if map_size(new_fields) < map_size(map1) or map_size(new_fields) < map_size(map2) do
695+
throw(:empty)
769696
end
697+
698+
{:closed, new_fields}
699+
end
700+
701+
# Open and closed: result is closed, all fields from open should be in closed
702+
defp map_literal_intersection(:open, open, :closed, closed) do
703+
:maps.iterator(open) |> :maps.next() |> map_literal_intersection_loop(closed)
704+
end
705+
706+
defp map_literal_intersection(:closed, closed, :open, open) do
707+
:maps.iterator(open) |> :maps.next() |> map_literal_intersection_loop(closed)
708+
end
709+
710+
defp map_literal_intersection_loop(:none, acc), do: {:closed, acc}
711+
712+
defp map_literal_intersection_loop({key, type1, iterator}, acc) do
713+
case acc do
714+
%{^key => type2} ->
715+
acc = %{acc | key => non_empty_intersection!(type1, type2)}
716+
:maps.next(iterator) |> map_literal_intersection_loop(acc)
717+
718+
_ ->
719+
throw(:empty)
720+
end
721+
end
722+
723+
defp non_empty_intersection!(type1, type2) do
724+
type = intersection(type1, type2)
725+
if empty?(type), do: throw(:empty), else: type
770726
end
771727

772728
defp map_difference(dnf1, dnf2) do
@@ -1076,4 +1032,76 @@ defmodule Module.Types.Descr do
10761032
end
10771033
end
10781034
end
1035+
1036+
## Map helpers
1037+
1038+
defp symmetrical_merge(left, right, fun) do
1039+
# Erlang maps:merge_with/3 has to preserve the order in combiner.
1040+
# We don't care about the order, so we have a faster implementation.
1041+
if map_size(left) > map_size(right) do
1042+
iterator_merge(:maps.next(:maps.iterator(right)), left, fun)
1043+
else
1044+
iterator_merge(:maps.next(:maps.iterator(left)), right, fun)
1045+
end
1046+
end
1047+
1048+
defp iterator_merge({key, v1, iterator}, map, fun) do
1049+
acc =
1050+
case map do
1051+
%{^key => v2} -> %{map | key => fun.(key, v1, v2)}
1052+
%{} -> Map.put(map, key, v1)
1053+
end
1054+
1055+
iterator_merge(:maps.next(iterator), acc, fun)
1056+
end
1057+
1058+
defp iterator_merge(:none, map, _fun), do: map
1059+
1060+
defp symmetrical_intersection(left, right, fun) do
1061+
# Erlang maps:intersect_with/3 has to preserve the order in combiner.
1062+
# We don't care about the order, so we have a faster implementation.
1063+
if map_size(left) > map_size(right) do
1064+
iterator_intersection(:maps.next(:maps.iterator(right)), left, [], fun)
1065+
else
1066+
iterator_intersection(:maps.next(:maps.iterator(left)), right, [], fun)
1067+
end
1068+
end
1069+
1070+
defp iterator_intersection({key, v1, iterator}, map, acc, fun) do
1071+
acc =
1072+
case map do
1073+
%{^key => v2} ->
1074+
case fun.(key, v1, v2) do
1075+
0 -> acc
1076+
[] -> acc
1077+
value -> [{key, value} | acc]
1078+
end
1079+
1080+
%{} ->
1081+
acc
1082+
end
1083+
1084+
iterator_intersection(:maps.next(iterator), map, acc, fun)
1085+
end
1086+
1087+
defp iterator_intersection(:none, _map, acc, _fun), do: :maps.from_list(acc)
1088+
1089+
defp iterator_difference({key, v2, iterator}, map) do
1090+
acc =
1091+
case map do
1092+
%{^key => v1} ->
1093+
case difference(key, v1, v2) do
1094+
0 -> Map.delete(map, key)
1095+
[] -> Map.delete(map, key)
1096+
value -> %{map | key => value}
1097+
end
1098+
1099+
%{} ->
1100+
map
1101+
end
1102+
1103+
iterator_difference(:maps.next(iterator), acc)
1104+
end
1105+
1106+
defp iterator_difference(:none, map), do: map
10791107
end

0 commit comments

Comments
 (0)