Skip to content

Commit fe75806

Browse files
authored
Improve discriminated union implementation (#4556)
* WIP * Simplify discriminated union logic * Simplify disc union internals, improve perf * Improve docs * Typo
1 parent 5e31fd0 commit fe75806

File tree

17 files changed

+611
-2300
lines changed

17 files changed

+611
-2300
lines changed

packages/bench/discriminated-union.ts

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { makeData, makeSchema, randomPick, randomString } from "./benchUtil.js";
1+
import { makeData, randomPick, randomString } from "./benchUtil.js";
22
import { metabench } from "./metabench.js";
33

44

@@ -54,55 +54,54 @@ const z3DiscUnion = z3.discriminatedUnion("type", z3Union._def.options);
5454

5555
function makeSchema(z: typeof z4){
5656
const z4fields = {
57-
data1: z4.string(),
58-
data2: z4.string(),
59-
data3: z4.string(),
60-
// data4: z4.string(),
61-
// data5: z4.string(),
62-
// data6: z4.string(),
63-
// data7: z4.string(),
64-
// data8: z4.string(),
65-
// data9: z4.string(),
66-
// data10: z4.string(),
67-
}
68-
const z4Union = z4.union([
69-
z4.object({
70-
type: z4.literal("a"),
71-
...z4fields
72-
}),
73-
z4.object({
74-
type: z4.literal("b"),
75-
...z4fields
76-
}),
77-
z4.object({
78-
type: z4.literal("c"),
79-
...z4fields
80-
}),
81-
z4.object({
82-
type: z4.literal("d"),
83-
...z4fields
84-
}),
85-
z4.object({
86-
type: z4.literal("e"),
87-
...z4fields
88-
}),
89-
z4.object({
90-
type: z4.literal("f"),
91-
...z4fields
92-
}),
93-
z4.object({
94-
type: z4.literal("g"),
95-
...z4fields
96-
}),
97-
]);
98-
return z4Union;
57+
data1: z.string(),
58+
data2: z.string(),
59+
data3: z.string(),
60+
// data4: z.string(),
61+
// data5: z.string(),
62+
// data6: z.string(),
63+
// data7: z.string(),
64+
// data8: z.string(),
65+
// data9: z.string(),
66+
// data10: z.string(),
67+
}
68+
const z4Union = z.union([
69+
z.object({
70+
type: z.literal("a"),
71+
...z4fields
72+
}),
73+
z.object({
74+
type: z.literal("b"),
75+
...z4fields
76+
}),
77+
z.object({
78+
type: z.literal("c"),
79+
...z4fields
80+
}),
81+
z.object({
82+
type: z.literal("d"),
83+
...z4fields
84+
}),
85+
z.object({
86+
type: z.literal("e"),
87+
...z4fields
88+
}),
89+
z.object({
90+
type: z.literal("f"),
91+
...z4fields
92+
}),
93+
z.object({
94+
type: z.literal("g"),
95+
...z4fields
96+
}),
97+
]);
98+
return z4Union;
9999

100100
}
101101

102102
const z4Union = makeSchema(z4);
103103
const z4LibUnion = makeSchema(z4lib as any);
104-
const z4LibDiscUnion = z4lib.discriminatedUnion( z4LibUnion._def.options);
105-
104+
const z4LibDiscUnion = z4lib.discriminatedUnion("type", z4LibUnion._def.options as any);
106105
const z4DiscUnion = z4.discriminatedUnion("type", z4Union.def.options);
107106

108107
const DATA = makeData(100, () => ({
@@ -137,7 +136,7 @@ console.dir(z4LibDiscUnion.parse(DATA[0]), {depth: null});
137136

138137

139138
const args= {jitless: true}
140-
const bench = metabench("z.disriminatedUnion().parse", {
139+
const bench = metabench("z.discriminatedUnion().parse", {
141140
// z3() {
142141
// for (const item of DATA) {
143142
// z3Union.parse(item);

packages/bench/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ async function run() {
66
const files = process.argv[2].split(",").map((file) => import.meta.resolve(`./${file}`).replace("file://", ""));
77

88
for (const file of files) {
9-
await $`pnpm tsx --conditions @zod/source ${file}`;
9+
await $`pnpm tsx ${file}`;
1010
}
1111
}
1212

packages/bench/jit-union.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ const DATA = makeData(100, () => ({
7474
const args= {jitless: true}
7575
console.dir(z4Union.parse(DATA[0]), {depth: null});
7676
console.dir(z4Union.parse(DATA[0], args), {depth: null});
77-
const bench = metabench("z.disriminatedUnion().parse", {
77+
const bench = metabench("z.discriminatedUnion().parse", {
7878

7979
v4_jit() {
8080
for (const item of DATA) {

packages/bench/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"arktype": "^2.1.19",
77
"valibot": "^1.0.0",
88
"zod": "workspace:*",
9-
"zodnext": "npm:zod@next",
9+
"zodnext": "npm:zod@^3.25.0",
1010
"zod3": "npm:zod@^3.23.7"
1111
},
1212
"scripts": {

packages/docs/content/v4/index.mdx

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ Replace `errorMap` with `error` (function syntax):
827827
828828
## Upgraded `z.discriminatedUnion()`
829829
830-
Discriminated unions now support a number of schema types not previously supported, including unions, pipes, and nested objects:
830+
Discriminated unions now support a number of schema types not previously supported, including unions and pipes:
831831
832832
```ts
833833
const MyResult = z.discriminatedUnion("status", [
@@ -836,23 +836,22 @@ const MyResult = z.discriminatedUnion("status", [
836836
// union discriminator
837837
z.object({ status: z.union([z.literal("bbb"), z.literal("ccc")]) }),
838838
// pipe discriminator
839-
z.object({ status: z.object({ value: z.literal("fail") }) }),
839+
z.object({ status: z.literal("fail").transform(val => val.toUpperCase()) }),
840840
]);
841841
```
842842
843843
Perhaps most importantly, discriminated unions now *compose*—you can use one discriminated union as a member of another.
844844
845845
```ts
846846
const BaseError = z.object({ status: z.literal("failed"), message: z.string() });
847-
const MyErrors = z.discriminatedUnion("code", [
848-
BaseError.extend({ code: z.literal(400) }),
849-
BaseError.extend({ code: z.literal(401) }),
850-
BaseError.extend({ code: z.literal(500) })
851-
]);
852847
853848
const MyResult = z.discriminatedUnion("status", [
854849
z.object({ status: z.literal("success"), data: z.string() }),
855-
MyErrors
850+
z.discriminatedUnion("code", [
851+
BaseError.extend({ code: z.literal(400) }),
852+
BaseError.extend({ code: z.literal(401) }),
853+
BaseError.extend({ code: z.literal(500) })
854+
])
856855
]);
857856
```
858857

0 commit comments

Comments
 (0)