Skip to content

Commit 8880f93

Browse files
unique inverse rewrite
1 parent 3dc6ee3 commit 8880f93

File tree

1 file changed

+11
-45
lines changed

1 file changed

+11
-45
lines changed

docs/tutorials/array-api.ipynb

Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
},
8585
{
8686
"cell_type": "code",
87-
"execution_count": 3,
87+
"execution_count": 4,
8888
"metadata": {},
8989
"outputs": [
9090
{
@@ -260,30 +260,24 @@
260260
" -> Int(2)\n",
261261
"unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n",
262262
"unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n",
263-
" -> unique_inverse(reshape(NDArray.var(\"y\"), TupleInt(Int(-1))))[Int(0)].shape[Int(0)]\n",
264-
" -> unique_inverse(NDArray.var(\"y\"))[Int(0)].shape[Int(0)]\n"
263+
" -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n",
264+
" -> Int(3)\n"
265265
]
266266
},
267267
{
268-
"ename": "EggSmolError",
269-
"evalue": "Not found: fake expression Int.to_py [Value { tag: \"Int\", bits: 133 }]",
268+
"ename": "AttributeError",
269+
"evalue": "module '__main__' has no attribute 'mean'",
270270
"output_type": "error",
271271
"traceback": [
272272
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
273-
"\u001b[0;31mEggSmolError\u001b[0m Traceback (most recent call last)",
274-
"Cell \u001b[0;32mIn[3], line 680\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 667\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 668\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 669\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 676\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 677\u001b[0m )\n\u001b[0;32m--> 680\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 682\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 683\u001b[0m \n\u001b[1;32m 684\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 685\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
273+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
274+
"Cell \u001b[0;32mIn[4], line 682\u001b[0m\n\u001b[1;32m 668\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 669\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 670\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 671\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 678\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 679\u001b[0m )\n\u001b[0;32m--> 682\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 684\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 685\u001b[0m \n\u001b[1;32m 686\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 687\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
275275
"Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[39mwith\u001b[39;00m config_context(array_api_dispatch\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[39m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[39m=\u001b[39m lda\u001b[39m.\u001b[39;49mfit(X, y)\u001b[39m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[39mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[39m=\u001b[39m iris\u001b[39m.\u001b[39mtarget_names\n",
276276
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[39m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[39mreturn\u001b[39;00m fit_method(estimator, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
277277
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcovariance estimator \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mis not supported \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mwith svd solver. Try another solver\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_solve_svd(X, y)\n\u001b[1;32m 630\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msolver \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlsqr\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n",
278278
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:501\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 498\u001b[0m n_samples, n_features \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 499\u001b[0m n_classes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclasses_\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m--> 501\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmeans_ \u001b[39m=\u001b[39m _class_means(X, y)\n\u001b[1;32m 502\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstore_covariance:\n\u001b[1;32m 503\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_ \u001b[39m=\u001b[39m _class_cov(X, y, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_)\n",
279-
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:120\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[39mif\u001b[39;00m is_array_api_compliant:\n\u001b[1;32m 119\u001b[0m \u001b[39mprint\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m])\n\u001b[0;32m--> 120\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39;49m(classes\u001b[39m.\u001b[39;49mshape[\u001b[39m0\u001b[39;49m]):\n\u001b[1;32m 121\u001b[0m means[i, :] \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mmean(X[y \u001b[39m==\u001b[39m i], axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[39m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# from a performance-wise\u001b[39;00m\n",
280-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:403\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 401\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[1;32m 402\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_typed_expr__\u001b[39m.\u001b[39mtp\u001b[39m.\u001b[39mname\u001b[39m}\u001b[39;00m\u001b[39m has no method \u001b[39m\u001b[39m{\u001b[39;00m__name\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m--> 403\u001b[0m \u001b[39mreturn\u001b[39;00m method(\u001b[39mself\u001b[39;49m)\n",
281-
"Cell \u001b[0;32mIn[3], line 198\u001b[0m, in \u001b[0;36mInt.__index__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[39m@egraph\u001b[39m\u001b[39m.\u001b[39mmethod(preserve\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 197\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__index__\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mint\u001b[39m:\n\u001b[0;32m--> 198\u001b[0m \u001b[39mreturn\u001b[39;00m extract_py(\u001b[39mself\u001b[39;49m)\n",
282-
"Cell \u001b[0;32mIn[3], line 33\u001b[0m, in \u001b[0;36mextract_py\u001b[0;34m(e)\u001b[0m\n\u001b[1;32m 31\u001b[0m egraph\u001b[39m.\u001b[39mrun((run(runtime_ruleset, limit\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m) \u001b[39m+\u001b[39m run(limit\u001b[39m=\u001b[39m\u001b[39m10\u001b[39m))\u001b[39m.\u001b[39msaturate())\n\u001b[1;32m 32\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m -> \u001b[39m\u001b[39m{\u001b[39;00megraph\u001b[39m.\u001b[39mextract(final_object)\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> 33\u001b[0m res \u001b[39m=\u001b[39m egraph\u001b[39m.\u001b[39mload_object(egraph\u001b[39m.\u001b[39;49mextract(final_object\u001b[39m.\u001b[39;49mto_py()))\n\u001b[1;32m 34\u001b[0m \u001b[39mreturn\u001b[39;00m res\n",
283-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:737\u001b[0m, in \u001b[0;36mEGraph.extract\u001b[0;34m(self, expr)\u001b[0m\n\u001b[1;32m 735\u001b[0m typed_expr \u001b[39m=\u001b[39m expr_parts(expr)\n\u001b[1;32m 736\u001b[0m egg_expr \u001b[39m=\u001b[39m typed_expr\u001b[39m.\u001b[39mto_egg(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mod_decls)\n\u001b[0;32m--> 737\u001b[0m extract_report \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_run_extract(egg_expr, \u001b[39m0\u001b[39;49m)\n\u001b[1;32m 738\u001b[0m new_typed_expr \u001b[39m=\u001b[39m TypedExprDecl\u001b[39m.\u001b[39mfrom_egg(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mod_decls, extract_report\u001b[39m.\u001b[39mexpr)\n\u001b[1;32m 739\u001b[0m \u001b[39mif\u001b[39;00m new_typed_expr\u001b[39m.\u001b[39mtp \u001b[39m!=\u001b[39m typed_expr\u001b[39m.\u001b[39mtp:\n",
284-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:754\u001b[0m, in \u001b[0;36mEGraph._run_extract\u001b[0;34m(self, expr, n)\u001b[0m\n\u001b[1;32m 753\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_run_extract\u001b[39m(\u001b[39mself\u001b[39m, expr: bindings\u001b[39m.\u001b[39m_Expr, n: \u001b[39mint\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m bindings\u001b[39m.\u001b[39mExtractReport:\n\u001b[0;32m--> 754\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_process_commands([bindings\u001b[39m.\u001b[39;49mExtract(n, expr)])\n\u001b[1;32m 755\u001b[0m extract_report \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_egraph\u001b[39m.\u001b[39mextract_report()\n\u001b[1;32m 756\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m extract_report:\n",
285-
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/egraph.py:634\u001b[0m, in \u001b[0;36mEGraph._process_commands\u001b[0;34m(self, commands)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_process_commands\u001b[39m(\u001b[39mself\u001b[39m, commands: Iterable[bindings\u001b[39m.\u001b[39m_Command]) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m--> 634\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_egraph\u001b[39m.\u001b[39;49mrun_program(\u001b[39m*\u001b[39;49mcommands)\n",
286-
"\u001b[0;31mEggSmolError\u001b[0m: Not found: fake expression Int.to_py [Value { tag: \"Int\", bits: 133 }]"
279+
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:121\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[39mprint\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m])\n\u001b[1;32m 120\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]):\n\u001b[0;32m--> 121\u001b[0m means[i, :] \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39;49mmean(X[y \u001b[39m==\u001b[39m i], axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[39m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# from a performance-wise\u001b[39;00m\n\u001b[1;32m 125\u001b[0m cnt \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mbincount(y)\n",
280+
"\u001b[0;31mAttributeError\u001b[0m: module '__main__' has no attribute 'mean'"
287281
]
288282
}
289283
],
@@ -897,6 +891,8 @@
897891
"def _unique_inverse(x: NDArray):\n",
898892
" return [\n",
899893
" rewrite(unique_inverse(x).length()).to(Int(2)),\n",
894+
" # Shape of unique_inverse first element is same as shape of unique_values\n",
895+
" rewrite(unique_inverse(x)[Int(0)].shape).to(unique_values(x).shape),\n",
900896
" ]\n",
901897
"\n",
902898
"@egraph.function\n",
@@ -975,36 +971,6 @@
975971
"# y_arr = NDArray(y_obj)"
976972
]
977973
},
978-
{
979-
"cell_type": "code",
980-
"execution_count": null,
981-
"metadata": {},
982-
"outputs": [],
983-
"source": [
984-
"x = unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n"
985-
]
986-
},
987-
{
988-
"cell_type": "code",
989-
"execution_count": null,
990-
"metadata": {},
991-
"outputs": [
992-
{
993-
"ename": "TypeError",
994-
"evalue": "'RuntimeExpr' object cannot be interpreted as an integer",
995-
"output_type": "error",
996-
"traceback": [
997-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
998-
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
999-
"Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mrange\u001b[39;49m(Int(\u001b[39m10\u001b[39;49m))\n",
1000-
"\u001b[0;31mTypeError\u001b[0m: 'RuntimeExpr' object cannot be interpreted as an integer"
1001-
]
1002-
}
1003-
],
1004-
"source": [
1005-
"range(Int(10))"
1006-
]
1007-
},
1008974
{
1009975
"cell_type": "code",
1010976
"execution_count": null,

0 commit comments

Comments
 (0)