|
34 | 34 | ResourceEnum,
|
35 | 35 | Identifier,
|
36 | 36 | )
|
| 37 | +from sagemaker.feature_store.dataset_builder import ( |
| 38 | + JoinTypeEnum, |
| 39 | +) |
37 | 40 | from sagemaker.session import get_execution_role, Session
|
38 | 41 | from tests.integ.timeout import timeout
|
39 | 42 |
|
@@ -787,6 +790,193 @@ def test_create_dataset_with_feature_group_base(
|
787 | 790 | )
|
788 | 791 |
|
789 | 792 |
|
def test_create_dataset_with_feature_group_base_with_additional_params(
    feature_store_session,
    region_name,
    role,
    base_name,
    feature_group_name,
    offline_store_s3_uri,
    base_dataframe,
    feature_group_dataframe,
):
    """Check ``create_dataset`` with a FULL JOIN and a recent-records limit.

    Creates two feature groups (``base`` and ``feature_group``), ingests the
    fixture dataframes into them, then builds a dataset that joins them on
    ``base_time`` = ``fg_time`` using ``JoinTypeEnum.FULL_JOIN`` and
    ``with_number_of_recent_records_by_record_identifier(4)``.

    The test asserts that:
      * the returned dataframe equals an outer merge of the two fixture
        dataframes (with the joined feature group's columns suffixed ``.1``);
      * the generated query string matches the expected text exactly.

    Args:
        feature_store_session: session passed to the feature groups, the
            ``FeatureStore`` and the cleanup helpers.
        region_name: fixture requested by the signature; not referenced in
            the body.
        role: passed through to ``_create_feature_group_and_ingest_data``.
        base_name: name of the base feature group.
        feature_group_name: name of the feature group joined onto the base.
        offline_store_s3_uri: S3 URI used for the offline store and as the
            dataset ``output_path``.
        base_dataframe: data ingested into the base feature group.
        feature_group_dataframe: data ingested into the joined feature group.
    """
    base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session)
    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
    with cleanup_feature_group(base), cleanup_feature_group(feature_group):
        _create_feature_group_and_ingest_data(
            base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
        )
        _create_feature_group_and_ingest_data(
            feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
        )
        base_table_name = _get_athena_table_name_after_data_replication(
            feature_store_session, base, offline_store_s3_uri
        )
        feature_group_table_name = _get_athena_table_name_after_data_replication(
            feature_store_session, feature_group, offline_store_s3_uri
        )

        # The context managers must be separated by commas. Chaining them with
        # ``and`` evaluates to the last operand only, so the timeout and the
        # first cleanup would be constructed but never entered.
        with timeout(minutes=10), cleanup_offline_store(
            base_table_name, feature_store_session
        ), cleanup_offline_store(feature_group_table_name, feature_store_session):
            feature_store = FeatureStore(sagemaker_session=feature_store_session)
            df, query_string = (
                feature_store.create_dataset(base=base, output_path=offline_store_s3_uri)
                .with_number_of_recent_records_by_record_identifier(4)
                .with_feature_group(
                    feature_group,
                    target_feature_name_in_base="base_time",
                    feature_name_in_target="fg_time",
                    join_type=JoinTypeEnum.FULL_JOIN,
                )
                .to_dataframe()
            )
            # Sort both frames by every column so row order does not affect
            # the equality check.
            sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True)
            merged_df = base_dataframe.merge(
                feature_group_dataframe, left_on="base_time", right_on="fg_time", how="outer"
            )

            expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True)

            # The generated query aliases the joined feature group's columns
            # with a ".1" suffix; mirror that in the expected frame.
            expect_df.rename(
                columns={
                    "fg_id": "fg_id.1",
                    "fg_time": "fg_time.1",
                    "fg_feature_1": "fg_feature_1.1",
                    "fg_feature_2": "fg_feature_2.1",
                },
                inplace=True,
            )

            assert sorted_df.equals(expect_df)
            assert (
                query_string
                == "WITH fg_base AS (WITH table_base AS (\n"
                + "SELECT *\n"
                + "FROM (\n"
                + "SELECT *, row_number() OVER (\n"
                + 'PARTITION BY origin_base."base_id", origin_base."base_time"\n'
                + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n'
                + ") AS dedup_row_base\n"
                + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n'
                + ")\n"
                + "WHERE dedup_row_base = 1\n"
                + "),\n"
                + "deleted_base AS (\n"
                + "SELECT *\n"
                + "FROM (\n"
                + "SELECT *, row_number() OVER (\n"
                + 'PARTITION BY origin_base."base_id"\n'
                + 'ORDER BY origin_base."base_time" DESC,'
                ' origin_base."api_invocation_time" DESC,'
                ' origin_base."write_time" DESC\n'
                + ") AS deleted_row_base\n"
                + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n'
                + "WHERE is_deleted\n"
                + ")\n"
                + "WHERE deleted_row_base = 1\n"
                + ")\n"
                + 'SELECT table_base."base_id", table_base."base_time",'
                ' table_base."base_feature_1", table_base."base_feature_2"\n'
                + "FROM (\n"
                + 'SELECT table_base."base_id", table_base."base_time",'
                ' table_base."base_feature_1", table_base."base_feature_2",'
                ' table_base."write_time"\n'
                + "FROM table_base\n"
                + "LEFT JOIN deleted_base\n"
                + 'ON table_base."base_id" = deleted_base."base_id"\n'
                + 'WHERE deleted_base."base_id" IS NULL\n'
                + "UNION ALL\n"
                + 'SELECT table_base."base_id", table_base."base_time",'
                ' table_base."base_feature_1", table_base."base_feature_2",'
                ' table_base."write_time"\n'
                + "FROM deleted_base\n"
                + "JOIN table_base\n"
                + 'ON table_base."base_id" = deleted_base."base_id"\n'
                + "AND (\n"
                + 'table_base."base_time" > deleted_base."base_time"\n'
                + 'OR (table_base."base_time" = deleted_base."base_time" AND'
                ' table_base."api_invocation_time" >'
                ' deleted_base."api_invocation_time")\n'
                + 'OR (table_base."base_time" = deleted_base."base_time" AND'
                ' table_base."api_invocation_time" ='
                ' deleted_base."api_invocation_time" AND'
                ' table_base."write_time" > deleted_base."write_time")\n'
                + ")\n"
                + ") AS table_base\n"
                + "),\n"
                + "fg_0 AS (WITH table_0 AS (\n"
                + "SELECT *\n"
                + "FROM (\n"
                + "SELECT *, row_number() OVER (\n"
                + 'PARTITION BY origin_0."fg_id", origin_0."fg_time"\n'
                + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n'
                + ") AS dedup_row_0\n"
                + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n'
                + ")\n"
                + "WHERE dedup_row_0 = 1\n"
                + "),\n"
                + "deleted_0 AS (\n"
                + "SELECT *\n"
                + "FROM (\n"
                + "SELECT *, row_number() OVER (\n"
                + 'PARTITION BY origin_0."fg_id"\n'
                + 'ORDER BY origin_0."fg_time" DESC, origin_0."api_invocation_time" DESC,'
                ' origin_0."write_time" DESC\n'
                + ") AS deleted_row_0\n"
                + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n'
                + "WHERE is_deleted\n"
                + ")\n"
                + "WHERE deleted_row_0 = 1\n"
                + ")\n"
                + 'SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1",'
                ' table_0."fg_feature_2"\n'
                + "FROM (\n"
                + 'SELECT table_0."fg_id", table_0."fg_time",'
                ' table_0."fg_feature_1", table_0."fg_feature_2",'
                ' table_0."write_time"\n'
                + "FROM table_0\n"
                + "LEFT JOIN deleted_0\n"
                + 'ON table_0."fg_id" = deleted_0."fg_id"\n'
                + 'WHERE deleted_0."fg_id" IS NULL\n'
                + "UNION ALL\n"
                + 'SELECT table_0."fg_id", table_0."fg_time",'
                ' table_0."fg_feature_1", table_0."fg_feature_2",'
                ' table_0."write_time"\n'
                + "FROM deleted_0\n"
                + "JOIN table_0\n"
                + 'ON table_0."fg_id" = deleted_0."fg_id"\n'
                + "AND (\n"
                + 'table_0."fg_time" > deleted_0."fg_time"\n'
                + 'OR (table_0."fg_time" = deleted_0."fg_time" AND'
                ' table_0."api_invocation_time" >'
                ' deleted_0."api_invocation_time")\n'
                + 'OR (table_0."fg_time" = deleted_0."fg_time" AND'
                ' table_0."api_invocation_time" ='
                ' deleted_0."api_invocation_time" AND table_0."write_time" >'
                ' deleted_0."write_time")\n'
                + ")\n"
                + ") AS table_0\n"
                + ")\n"
                + "SELECT base_id, base_time, base_feature_1, base_feature_2,"
                ' "fg_id.1", "fg_time.1", "fg_feature_1.1",'
                ' "fg_feature_2.1"\n'
                + "FROM (\n"
                + "SELECT fg_base.base_id, fg_base.base_time,"
                " fg_base.base_feature_1, fg_base.base_feature_2,"
                ' fg_0."fg_id" as "fg_id.1", fg_0."fg_time" as "fg_time.1",'
                ' fg_0."fg_feature_1" as "fg_feature_1.1",'
                ' fg_0."fg_feature_2" as "fg_feature_2.1", row_number()'
                " OVER (\n"
                + 'PARTITION BY fg_base."base_id"\n'
                + 'ORDER BY fg_base."base_time" DESC, fg_0."fg_time" DESC\n'
                + ") AS row_recent\n"
                + "FROM fg_base\n"
                + "FULL JOIN fg_0\n"
                + 'ON fg_base."base_time" = fg_0."fg_time"\n'
                + ")\n"
                + "WHERE row_recent <= 4"
            )
| 978 | + |
| 979 | + |
790 | 980 | def _create_feature_group_and_ingest_data(
|
791 | 981 | feature_group: FeatureGroup,
|
792 | 982 | dataframe: DataFrame,
|
|
0 commit comments