8
8
use IteratorAggregate ;
9
9
use MongoDB \Builder \Type \StageInterface ;
10
10
use MongoDB \Exception \InvalidArgumentException ;
11
- use Traversable ;
12
11
13
12
use function array_is_list ;
14
13
use function array_merge ;
14
+ use function is_array ;
15
15
16
16
/**
17
17
* An aggregation pipeline consists of one or more stages that process documents.
18
18
*
19
19
* @see https://www.mongodb.com/docs/manual/core/aggregation-pipeline/
20
20
*
21
21
* @psalm-immutable
22
- * @implements IteratorAggregate<StageInterface>
22
+ * @implements IteratorAggregate<StageInterface|array<string,mixed>|object >
23
23
*/
24
- class Pipeline implements IteratorAggregate
24
+ final class Pipeline implements IteratorAggregate
25
25
{
26
- /** @var StageInterface[] */
27
26
private readonly array $ stages ;
28
27
29
- /** @no-named-arguments */
30
- public function __construct (StageInterface |Pipeline ...$ stagesOrPipelines )
28
+ /**
29
+ * @param StageInterface|Pipeline|list<StageInterface> ...$stagesOrPipelines
30
+ *
31
+ * @no-named-arguments
32
+ */
33
+ public function __construct (StageInterface |Pipeline |array ...$ stagesOrPipelines )
31
34
{
32
35
if (! array_is_list ($ stagesOrPipelines )) {
33
36
throw new InvalidArgumentException ('Named arguments are not supported for pipelines ' );
@@ -36,7 +39,9 @@ public function __construct(StageInterface|Pipeline ...$stagesOrPipelines)
36
39
$ stages = [];
37
40
38
41
foreach ($ stagesOrPipelines as $ stageOrPipeline ) {
39
- if ($ stageOrPipeline instanceof Pipeline) {
42
+ if (is_array ($ stageOrPipeline ) && array_is_list ($ stageOrPipeline )) {
43
+ $ stages = array_merge ($ stages , $ stageOrPipeline );
44
+ } elseif ($ stageOrPipeline instanceof Pipeline) {
40
45
$ stages = array_merge ($ stages , $ stageOrPipeline ->stages );
41
46
} else {
42
47
$ stages [] = $ stageOrPipeline ;
@@ -46,8 +51,7 @@ public function __construct(StageInterface|Pipeline ...$stagesOrPipelines)
46
51
$ this ->stages = $ stages ;
47
52
}
48
53
49
- /** @return Traversable<StageInterface> */
50
- public function getIterator (): Traversable
54
+ public function getIterator (): ArrayIterator
51
55
{
52
56
return new ArrayIterator ($ this ->stages );
53
57
}
0 commit comments