|
84 | 84 | },
|
85 | 85 | {
|
86 | 86 | "cell_type": "code",
|
87 |
| - "execution_count": 6, |
| 87 | + "execution_count": 9, |
88 | 88 | "metadata": {},
|
89 | 89 | "outputs": [
|
90 | 90 | {
|
|
265 | 265 | ]
|
266 | 266 | },
|
267 | 267 | {
|
268 |
| - "ename": "AttributeError", |
269 |
| - "evalue": "Class NDArray does not have method __eq__", |
| 268 | + "ename": "TypeError", |
| 269 | + "evalue": "'RuntimeExpr' object does not support item assignment", |
270 | 270 | "output_type": "error",
|
271 | 271 | "traceback": [
|
272 | 272 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
273 |
| - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", |
274 |
| - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:388\u001b[0m, in \u001b[0;36m_special_method\u001b[0;34m(self, __name, *args)\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 388\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_decls__\u001b[39m.\u001b[39;49mget_class_decl(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_typed_expr__\u001b[39m.\u001b[39;49mtp\u001b[39m.\u001b[39;49mname)\u001b[39m.\u001b[39;49mpreserved_methods[__name]\n\u001b[1;32m 389\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n", |
275 |
| - "\u001b[0;31mKeyError\u001b[0m: '__eq__'", |
276 |
| - "\nDuring handling of the above exception, another exception occurred:\n", |
277 |
| - "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", |
278 |
| - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:318\u001b[0m, in \u001b[0;36mRuntimeMethod.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 318\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_fn_decl__ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_decls__\u001b[39m.\u001b[39;49mget_function_decl(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_callable_ref__)\n\u001b[1;32m 319\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n", |
279 |
| - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/declarations.py:192\u001b[0m, in \u001b[0;36mModuleDeclarations.get_function_decl\u001b[0;34m(self, ref)\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[39mpass\u001b[39;00m\n\u001b[0;32m--> 192\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mFunction \u001b[39m\u001b[39m{\u001b[39;00mref\u001b[39m}\u001b[39;00m\u001b[39m not found\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 193\u001b[0m \u001b[39melse\u001b[39;00m:\n", |
280 |
| - "\u001b[0;31mKeyError\u001b[0m: \"Function MethodRef(class_name='NDArray', method_name='__eq__') not found\"", |
281 |
| - "\nDuring handling of the above exception, another exception occurred:\n", |
282 |
| - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", |
283 |
| - "Cell \u001b[0;32mIn[6], line 700\u001b[0m\n\u001b[1;32m 686\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 687\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 688\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 689\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 696\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 697\u001b[0m )\n\u001b[0;32m--> 700\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 702\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 703\u001b[0m \n\u001b[1;32m 704\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 705\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n", |
| 273 | + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", |
| 274 | + "Cell \u001b[0;32mIn[9], line 711\u001b[0m\n\u001b[1;32m 697\u001b[0m \u001b[39m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 698\u001b[0m egraph\u001b[39m.\u001b[39mregister(\n\u001b[1;32m 699\u001b[0m rewrite(X_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(X\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[1;32m 700\u001b[0m rewrite(y_arr\u001b[39m.\u001b[39mdtype, runtime_ruleset)\u001b[39m.\u001b[39mto(convert(y\u001b[39m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 707\u001b[0m rewrite(unique_values(y_arr)\u001b[39m.\u001b[39mshape)\u001b[39m.\u001b[39mto(TupleInt(Int(\u001b[39m3\u001b[39m))),\n\u001b[1;32m 708\u001b[0m )\n\u001b[0;32m--> 711\u001b[0m res \u001b[39m=\u001b[39m fit(X_arr, y_arr)\n\u001b[1;32m 713\u001b[0m \u001b[39m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 714\u001b[0m \n\u001b[1;32m 715\u001b[0m \u001b[39m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 716\u001b[0m \u001b[39m# y_arr = NDArray(y_obj)\u001b[39;00m\n", |
284 | 275 | "Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[39mwith\u001b[39;00m config_context(array_api_dispatch\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[39m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[39m=\u001b[39m\u001b[39m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[39m=\u001b[39m lda\u001b[39m.\u001b[39;49mfit(X, y)\u001b[39m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[39mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[39m=\u001b[39m iris\u001b[39m.\u001b[39mtarget_names\n",
|
285 | 276 | "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[39m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[39mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[39m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[39mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[39mreturn\u001b[39;00m fit_method(estimator, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
|
286 | 277 | "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mcovariance estimator \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mis not supported \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mwith svd solver. Try another solver\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_solve_svd(X, y)\n\u001b[1;32m 630\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39msolver \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlsqr\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n",
|
287 | 278 | "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:501\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 498\u001b[0m n_samples, n_features \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape\n\u001b[1;32m 499\u001b[0m n_classes \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclasses_\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[0;32m--> 501\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmeans_ \u001b[39m=\u001b[39m _class_means(X, y)\n\u001b[1;32m 502\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstore_covariance:\n\u001b[1;32m 503\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcovariance_ \u001b[39m=\u001b[39m _class_cov(X, y, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpriors_)\n",
|
288 |
| - "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:121\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[39mprint\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m])\n\u001b[1;32m 120\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]):\n\u001b[0;32m--> 121\u001b[0m means[i, :] \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mmean(X[y \u001b[39m==\u001b[39;49m i], axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[39m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# from a performance-wise\u001b[39;00m\n\u001b[1;32m 125\u001b[0m cnt \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mbincount(y)\n", |
289 |
| - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:390\u001b[0m, in \u001b[0;36m_special_method\u001b[0;34m(self, __name, *args)\u001b[0m\n\u001b[1;32m 388\u001b[0m method \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_decls__\u001b[39m.\u001b[39mget_class_decl(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_typed_expr__\u001b[39m.\u001b[39mtp\u001b[39m.\u001b[39mname)\u001b[39m.\u001b[39mpreserved_methods[__name]\n\u001b[1;32m 389\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[0;32m--> 390\u001b[0m \u001b[39mreturn\u001b[39;00m RuntimeMethod(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_decls__, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__egg_typed_expr__, __name)(\u001b[39m*\u001b[39margs)\n\u001b[1;32m 391\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 392\u001b[0m \u001b[39mreturn\u001b[39;00m method(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs)\n", |
290 |
| - "File \u001b[0;32m<string>:6\u001b[0m, in \u001b[0;36m__init__\u001b[0;34m(self, __egg_decls__, __egg_typed_expr__, __egg_method_name__)\u001b[0m\n", |
291 |
| - "File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:320\u001b[0m, in \u001b[0;36mRuntimeMethod.__post_init__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_fn_decl__ \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_decls__\u001b[39m.\u001b[39mget_function_decl(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_callable_ref__)\n\u001b[1;32m 319\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m:\n\u001b[0;32m--> 320\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mClass \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mclass_name\u001b[39m}\u001b[39;00m\u001b[39m does not have method \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__egg_method_name__\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n", |
292 |
| - "\u001b[0;31mAttributeError\u001b[0m: Class NDArray does not have method __eq__" |
| 279 | + "File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:121\u001b[0m, in \u001b[0;36m_class_means\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[39mprint\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m])\n\u001b[1;32m 120\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(classes\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]):\n\u001b[0;32m--> 121\u001b[0m means[i, :] \u001b[39m=\u001b[39m xp\u001b[39m.\u001b[39mmean(X[y \u001b[39m==\u001b[39m i], axis\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 122\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 123\u001b[0m \u001b[39m# TODO: Explore the choice of using bincount + add.at as it seems sub optimal\u001b[39;00m\n\u001b[1;32m 124\u001b[0m \u001b[39m# from a performance-wise\u001b[39;00m\n\u001b[1;32m 125\u001b[0m cnt \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mbincount(y)\n", |
| 280 | + "\u001b[0;31mTypeError\u001b[0m: 'RuntimeExpr' object does not support item assignment" |
293 | 281 | ]
|
294 | 282 | }
|
295 | 283 | ],
|
|
592 | 580 | " ...\n",
|
593 | 581 | "\n",
|
594 | 582 | "\n",
|
| 583 | + "\n", |
595 | 584 | "converter(tuple, IndexKey, lambda x: IndexKey.tuple_int(convert(x, TupleInt)))\n",
|
596 | 585 | "converter(int, IndexKey, lambda x: IndexKey.int(Int(x)))\n",
|
597 | 586 | "converter(Int, IndexKey, lambda x: IndexKey.int(x))\n",
|
|
662 | 651 | "\n",
|
663 | 652 | " def __gt__(self, other: NDArray) -> NDArray:\n",
|
664 | 653 | " ...\n",
|
| 654 | + " \n", |
| 655 | + " def __eq__(self, other: NDArray) -> NDArray:\n", |
| 656 | + " ...\n", |
665 | 657 | "\n",
|
666 | 658 | " @classmethod\n",
|
667 | 659 | " def scalar_float(cls, other: Float) -> NDArray:\n",
|
|
675 | 667 | " def scalar_bool(cls, other: Bool) -> NDArray:\n",
|
676 | 668 | " ...\n",
|
677 | 669 | "\n",
|
| 670 | + "@egraph.function\n", |
| 671 | + "def ndarray_index(x: NDArray) -> IndexKey:\n", |
| 672 | + " ...\n", |
| 673 | + "\n", |
| 674 | + "converter(NDArray, IndexKey, ndarray_index)\n", |
| 675 | + "\n", |
| 676 | + "\n", |
678 | 677 | "\n",
|
679 | 678 | "converter(float, NDArray, lambda x: NDArray.scalar_float(Float(x)))\n",
|
680 | 679 | "converter(int, NDArray, lambda x: NDArray.scalar_int(Int(x)))\n",
|
|
0 commit comments