diff --git a/.github/workflows/install.yaml b/.github/workflows/install.yaml index 577ca39b..27535ebd 100644 --- a/.github/workflows/install.yaml +++ b/.github/workflows/install.yaml @@ -12,7 +12,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python_version: ["3.8", "3.11"] + python_version: ["3.8", "3.12"] runs-on: ${{ matrix.os }} steps: - name: Set up python ${{ matrix.python_version }} diff --git a/.github/workflows/lint_check.yaml b/.github/workflows/lint_check.yaml index 4979df94..866d5bdf 100644 --- a/.github/workflows/lint_check.yaml +++ b/.github/workflows/lint_check.yaml @@ -2,14 +2,14 @@ name: Lint Check on: [pull_request] jobs: lint_check: - name: 3.11 lint check + name: 3.12 lint check runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.11 + - name: Set up Python 3.12 uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.12' - name: Install package with dev deps run: | python -m pip install -e ".[dev]" diff --git a/.github/workflows/test_without_dev_deps.yaml b/.github/workflows/test_without_dev_deps.yaml index bca2d557..308e51ec 100644 --- a/.github/workflows/test_without_dev_deps.yaml +++ b/.github/workflows/test_without_dev_deps.yaml @@ -11,14 +11,14 @@ on: workflow_dispatch: jobs: tests: - name: 3.11 test + name: 3.12 test runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up python 3.11 + - name: Set up python 3.12 uses: actions/setup-python@v4 with: - python-version: 3.11 + python-version: 3.12 cache: 'pip' cache-dependency-path: "pyproject.toml" - name: Build trane and install latest requirements diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 8392cd06..67f7beb3 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -15,11 +15,8 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.11"] - type-of-tests: ["unit", "integration"] - exclude: - - python-version: 
"3.8" - type-of-tests: "integration" + python-version: ["3.8", "3.12"] + type-of-tests: ["unit"] steps: - uses: actions/checkout@v3 - name: Set up python ${{ matrix.python-version }} @@ -40,12 +37,11 @@ jobs: if: (steps.cache.outputs.cache-hit == 'true') && ( github.event.pull_request.title != 'Automated Latest Dependency Updates') run: python -m pip install --no-dependencies . - name: Run unit tests - if: ${{ matrix.type-of-tests != 'integration' }} run: | make clean make unit-tests - name: Upload code coverage report - if: ${{ matrix.type-of-tests == 'unit' && matrix.python-version == '3.11' }} + if: ${{ matrix.type-of-tests == 'unit' && matrix.python-version == '3.12' }} uses: codecov/codecov-action@v3 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45dee235..a8372603 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,18 +17,15 @@ repos: hooks: - id: add-trailing-comma name: Add trailing comma - - repo: https://github.com/python/black - rev: 23.10.1 - hooks: - - id: black - additional_dependencies: [".[jupyter]"] - types_or: [python, jupyter] - args: - - --config=./pyproject.toml - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.1.4 + rev: v0.1.8 hooks: - id: ruff + types_or: [ python, pyi, jupyter ] args: - --fix - --config=./pyproject.toml + - id: ruff-format + types_or: [ python, pyi, jupyter ] + args: + - --config=./pyproject.toml diff --git a/Makefile b/Makefile index 3c9e2a94..324d5a54 100755 --- a/Makefile +++ b/Makefile @@ -6,15 +6,16 @@ clean: find . -name '*~' -delete find . 
-name '.coverage.*' -delete +LINT_CONFIG = trane/ tests/ --config=./pyproject.toml .PHONY: lint lint: - black trane/ tests/ --check --config=./pyproject.toml - ruff trane/ tests/ --config=./pyproject.toml + ruff check $(LINT_CONFIG) + ruff format --check $(LINT_CONFIG) .PHONY: lint-fix lint-fix: - black trane/ tests/ --config=./pyproject.toml - ruff trane/ tests/ --fix --config=./pyproject.toml + ruff check --fix $(LINT_CONFIG) + ruff format $(LINT_CONFIG) .PHONY: installdeps-dev installdeps-dev: @@ -38,11 +39,7 @@ tests: .PHONY: unit-tests unit-tests: - $(PYTEST) tests/ --ignore=tests/integration_tests $(COVERAGE) - -.PHONY: integration-tests -integration-tests: - $(PYTEST) tests/integration_tests + $(PYTEST) tests/ $(COVERAGE) .PHONY: upgradepip upgradepip: diff --git a/docs/changelog.md b/docs/changelog.md index b85802cb..52f486ee 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,17 @@ Changelog --------- +v0.8.0 (December X, X) +========================= +* Enhancements + * Update LLM helper to support updated GPT-4 models [#174][#174] + * Update ruff to latest and remove black as a development dependency [#174][#174] + * Add Python 3.12 markers and CI testing [#174][#174] +* Fixes + * + + [#174]: + v0.7.0 (October 21, 2023) ========================= * Enhancements diff --git a/pyproject.toml b/pyproject.toml index aea6adcc..8d88732d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX", "Operating System :: Unix", @@ -55,13 +56,12 @@ test = [ "pytest-runner >= 2.11.1", ] dev = [ - "ruff >= 0.1.0" , - "black[jupyter] >= 22.12.0", - "pre-commit == 2.20.0", - "toml >= 0.10.2", + "ruff >= 0.1.8" , + "pre-commit >= 3.6.0", ] llm = [ - "openai >= 0.28.1", + "openai >= 1.3.7", +
"anthropic >= 0.7.7", "tiktoken >= 0.5.1", ] @@ -128,17 +128,10 @@ requires = [ ] build-backend = "setuptools.build_meta" -[tool.black] -line-length = 88 -target-version = ["py311"] - [tool.ruff] preview = true line-length = 88 ignore = ["E501"] -exclude = [ - "Examples", -] select = [ # Pyflakes "F", @@ -151,9 +144,15 @@ select = [ "I001" ] src = ["trane"] +target-version = "py312" [tool.ruff.isort] known-first-party = ["trane"] [tool.ruff.per-file-ignores] "__init__.py" = ["F401", "E402", "F403", "F405", "E501", "I001"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +preview = true diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration_tests/test_examples.py b/tests/integration_tests/test_examples.py deleted file mode 100644 index 8a5db981..00000000 --- a/tests/integration_tests/test_examples.py +++ /dev/null @@ -1,29 +0,0 @@ -import pandas as pd -from tqdm import tqdm - -from trane.core.problem_generator import ProblemGenerator -from trane.datasets.load_functions import load_airbnb - - -def test_airbnb_reviews(): - data, metadata = load_airbnb(nrows=1000) - assert data["id"].is_unique - window_size = "1m" - problem_generator = ProblemGenerator( - metadata=metadata, - window_size=window_size, - ) - problems = problem_generator.generate() - print(f"generated {len(problems)} problems from {data.shape} columns") - num_target_values_created = 0 - for p in tqdm(problems): - if p.has_parameters_set() is True: - labels = p.create_target_values(data) - num_target_values_created += 1 - if "target" not in labels.columns: - continue - if pd.api.types.is_bool_dtype(labels["target"].dtype): - assert p.get_problem_type() == "classification" - else: - assert p.get_problem_type() == "regression" - print(f"created {num_target_values_created} target values") diff --git a/tests/integration_tests/utils.py b/tests/integration_tests/utils.py deleted 
file mode 100644 index f4c1de76..00000000 --- a/tests/integration_tests/utils.py +++ /dev/null @@ -1,152 +0,0 @@ -import trane -from trane.ops.aggregation_ops import ( - AggregationOpBase, - AvgAggregationOp, - CountAggregationOp, - ExistsAggregationOp, - FirstAggregationOp, - LastAggregationOp, - MajorityAggregationOp, - MaxAggregationOp, - MinAggregationOp, - SumAggregationOp, -) -from trane.ops.filter_ops import ( - AllFilterOp, - EqFilterOp, - FilterOpBase, - GreaterFilterOp, - LessFilterOp, - NeqFilterOp, -) -from trane.ops.transformation_ops import IdentityOp, OrderByOp, TransformationOpBase -from trane.utils import multiprocess_prediction_problem - -agg_op_str_dict = { - SumAggregationOp: " the total <{}> in all related records", - AvgAggregationOp: " the average <{}> in all related records", - MaxAggregationOp: " the maximum <{}> in all related records", - MinAggregationOp: " the minimum <{}> in all related records", - MajorityAggregationOp: " the majority <{}> in all related records", - CountAggregationOp: " the number of records", - ExistsAggregationOp: " if there exists a record", - FirstAggregationOp: " the first <{}> in all related records", - LastAggregationOp: " the last <{}> in all related records", -} - -transform_op_str_dict = { - IdentityOp: "", - OrderByOp: " sorted by <{}>", -} - -transform_op_str_dict = { - IdentityOp: "", - OrderByOp: " sorted by <{}>", -} - -transform_op_str_dict = { - IdentityOp: "", - OrderByOp: " sorted by <{}>", -} - -filter_op_str_dict = { - GreaterFilterOp: "greater than", - EqFilterOp: "equal to", - NeqFilterOp: "not equal to", - LessFilterOp: "less than", - # TODO: figure out the string for this operation - AllFilterOp: "", -} - - -def generate_and_verify_prediction_problem( - df, - meta, - entity_col, - time_col, - cutoff_strategy, - sample=None, - use_multiprocess=False, -): - prediction_problem_to_label_times = {} - cutoff = cutoff_strategy.window_size - problem_generator = trane.PredictionProblemGenerator( - 
df=df, - table_meta=meta, - entity_col=entity_col, - cutoff_strategy=cutoff_strategy, - time_col=time_col, - ) - problems = problem_generator.generate(df, generate_thresholds=True) - unique_entity_ids = df[entity_col].nunique() - for p in problems: - assert p.entity_col == entity_col - assert p.time_col == time_col - expected_problem_pre = f"For each <{entity_col}> predict " - expected_problem_end = f"in next {cutoff} days" - p_str = str(p) - assert p_str.startswith(expected_problem_pre) - assert p_str.endswith(expected_problem_end) - agg_column_name = None # noqa - filter_column_name = None # noqa - filter_threshold_value = None # noqa - for op in p.operations: - if isinstance(op, AggregationOpBase): - expected_agg_str = agg_op_str_dict[op.__class__] - expected_agg_str = expected_agg_str.replace( - "<{}>", - f"<{op.column_name}>", - ) - assert expected_agg_str in p_str - if op.column_name: - # agg_column_name - _ = op.column_name - elif isinstance(op, TransformationOpBase): - expected_transform_str = transform_op_str_dict[op.__class__] - expected_transform_str = expected_transform_str.replace( - "<{}>", - f"<{op.column_name}>", - ) - assert expected_transform_str in p_str - if op.column_name: - # filter_column_name - _ = op.column_name - elif isinstance(op, FilterOpBase): - expected_filter_str = filter_op_str_dict[op.__class__] - threshold = op.threshold - assert expected_filter_str in p_str - if op.column_name: - # filter_column_name - _ = op.column_name - if threshold: - # filter_threshold_value - _ = threshold - else: - raise ValueError( - f"Unexpected prediction problem generated: {p_str}: {p.operations}", - ) - if not use_multiprocess: - label_times = p.execute(df, -1) - assert label_times.target_dataframe_index == entity_col - # TODO: fix bug with Filter Operation results in labels that has target == 0 - # Below is not an ideal way to check the prediction problems - # (because it has less than, rather than exact number of unique instances) - if not 
label_times.empty: - assert label_times[entity_col].nunique() <= unique_entity_ids - prediction_problem_to_label_times[p_str] = label_times - - if use_multiprocess: - prediction_problem_to_label_times = multiprocess_prediction_problem( - problems, - df, - ) - return prediction_problem_to_label_times - - -def check_label_times(label_times, entity_col, unique_entity_ids): - assert label_times.target_dataframe_index == entity_col - # TODO: fix bug with Filter Operation results in labels that has target == 0 - # Below is not an ideal way to check the prediction problems - # (because it has less than, rather than exact number of unique instances) - if not label_times.empty: - assert label_times[entity_col].nunique() <= unique_entity_ids diff --git a/tests/test_calculate_target_values.py b/tests/test_calculate_target_values.py index 7fcb6736..99a08586 100644 --- a/tests/test_calculate_target_values.py +++ b/tests/test_calculate_target_values.py @@ -2,34 +2,55 @@ import pytest from trane.core.utils import ( - determine_gap_size, generate_data_slices, set_dataframe_index, ) from trane.metadata import SingleTableMetadata -from trane.utils.testing_utils import ( - create_mock_data, - create_mock_data_metadata, -) +from trane.utils import create_mock_data, create_mock_data_metadata -@pytest.fixture() -def data(): - df = create_mock_data(return_single_dataframe=True) - return df +def test_set_dataframe_index(): + df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + df = set_dataframe_index(df, "A") + assert df.index.name == "A" + df = set_dataframe_index(df, "B") + assert df.index.name == "B" -@pytest.fixture() -def cutoff_df(): +@pytest.mark.parametrize( + "window_size, gap, expected_dataslices", + [ + ("1d", "1d", [["A"], ["B"], ["C"], ["D"], ["E"], ["F"], ["G"], ["H"]]), + ("2d", "2d", [["A", "B"], ["C", "D"], ["E", "F"], ["G", "H"]]), + ("3d", "3d", [["A", "B", "C"], ["D", "E", "F"], ["G", "H"]]), + ("3d", "3d", [["A", "B", "C"], ["D", "E", "F"], ["G", "H"]]), + ], +) 
+def test_generate_data_slices(window_size, gap, expected_dataslices): + # Timestamps: t0 t1 t2 t3 t4 t5 t6 t7 + # Data: A B C D E F G H + df = pd.DataFrame( - {"A": [1, 2, 3, 4, 5]}, - index=pd.date_range("2022-01-01", periods=5, freq="D"), + { + "data": list("ABCDEFGH"), + "timestamp": pd.date_range(start="2022-01-01", end="2022-01-08", freq="D"), + }, ) - df.index = pd.to_datetime(df.index) - return df - - -def test_create_mock_data_metadata(data): + df = set_dataframe_index(df, "timestamp") + # start_times = ["2022-01-01", "2022-01-03", "2022-01-05", "2022-01-07"] + # end_times = ["2022-01-02", "2022-01-04", "2022-01-06", "2022-01-08"] + for dataslice, metadata in generate_data_slices( + df=df, + window_size=window_size, + gap=gap, + ): + assert dataslice["data"].tolist() == expected_dataslices.pop(0) + # TODO: Verify start times + assert len(expected_dataslices) == 0 + + +def test_create_mock_data_metadata(): + data = create_mock_data(return_single_dataframe=True) metadata = create_mock_data_metadata(single_table=True) assert isinstance(metadata, SingleTableMetadata) for col, ml_type in metadata.ml_types.items(): @@ -42,195 +63,3 @@ def test_create_mock_data_metadata(data): assert col in metadata.ml_types assert len(metadata.ml_types) == len(data.columns) assert data["transaction_id"].is_unique - - -def test_set_dataframe_index(): - df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - df = set_dataframe_index(df, "A") - assert df.index.name == "A" - - df = set_dataframe_index(df, "B") - assert df.index.name == "B" - - -def test_determine_gap_size(): - assert determine_gap_size("1d") == pd.Timedelta("1d") - assert determine_gap_size(2) == 2 - - -def test_integer_window_size(): - # Data: A B C D E F G H I J K L - - # Slice 1: | A B C | D E F G H I J K L - # Slice 2: A B C | D E F | G H I J K L - # Slice 3: A B C D E F | G H I | J K L - # Slice 4: A B C D E F G H I | J K L | - # |...| - Data window of size 3 - # Space between |...| - Gap of 0 - # A gap of 0 
means that the windows are contiguous, i.e., there is no spacing between the end of one window and the start of the next. - - df = pd.DataFrame( - { - "data": list("ABCDEFGHIJKL"), - }, - ) - slices = list(generate_data_slices(df, gap=0, window_size=3)) - expected_slices = [ - ["A", "B", "C"], - ["D", "E", "F"], - ["G", "H", "I"], - ["J", "K", "L"], - ] - check_data_slices(slices, expected_slices) - # Slice 1: | A B C | D E F G H I J K L - # Slice 2: A B C D | E F G | H I J K L - # Slice 3: A B C D E F G H | I J K | L - # |...| - Data window of size 3 - # Space between |...| - Gap of 1 - # A gap of 1 means that you skip one element in the data sequence when starting a new slice. - slices = list(generate_data_slices(df, gap=1, window_size=3)) - expected_slices = [ - ["A", "B", "C"], - ["E", "F", "G"], - ["I", "J", "K"], - ] - check_data_slices(slices, expected_slices) - - -def test_timedelta_window_size(): - # Timestamps: t0 t1 t2 t3 t4 t5 t6 t7 - # Data: A B C D E F G H - - # Slice 1: | A B | C D E F G H (t0 to t1) - # Slice 2: A B | C D | E F G H (t2 to t3) - # Slice 3: A B C D | E F | G H (t4 to t5) - # Slice 4: A B C D E F | G H | (t6 to t7) - # |...| - Data window of timedelta size - # Space between |...| - Gap of 0 timedleta - - df = pd.DataFrame( - {"data": list("ABCDEFGH")}, - index=pd.date_range("2022-01-01", periods=8, freq="D"), - ) - - slices = list( - generate_data_slices( - df, - window_size=pd.Timedelta(days=2), - gap=0, - ), - ) - expected_slices = [ - ["A", "B"], - ["C", "D"], - ["E", "F"], - ["G", "H"], - ] - check_data_slices(slices, expected_slices) - - # Timestamps: t0 t1 t2 t3 t4 t5 t6 t7 - # Data: A B C D E F G H - - # Slice 1: | A B | C D E F G H (t0 to t1) - # Slice 2: A B C | D E | F G H (t3 to t4) - # Slice 3: A B C D E F | G H | (t6 to t7) - # |...| - Data window of timedelta size - # Space between |...| - Gap of 1 delta - slices = list( - generate_data_slices( - df, - window_size=pd.Timedelta(days=2), - gap=1, - ), - ) - expected_slices 
= [ - ["A", "B"], - ["D", "E"], - ["G", "H"], - ] - check_data_slices(slices, expected_slices) - - -def test_gap_timedelta_window_size_timedelta(): - # Timestamps: t0 t1 t2 t3 t4 t5 t6 t7 - # Data: A B C D E F G H - # Slice 1: | A | B C D E F G H (t0 to t0) - # Slice 2: A B | C | D E F G H (t3 to t3) - # Slice 3: A B C D | E | F G H (t4 to t4) - # Slice 4: A B C D E F | G | H (t5 to t5) - - df = pd.DataFrame( - {"value": list("ABCDEFGH")}, - index=pd.date_range(start="2022-01-01", end="2022-01-08", freq="D"), - ) - slices = list( - generate_data_slices( - df, - window_size=pd.Timedelta(days=1), - gap=pd.Timedelta(days=1), - ), - ) - expected_values = [ - ["A"], - ["C"], - ["E"], - ["G"], - ] - for i, (dataslice, meta) in enumerate(slices): - assert dataslice["value"].tolist() == expected_values[i] - - -def test_with_gap_larger_than_window_size(): - # Index: 0 1 2 3 4 5 6 7 - # Data: A B C D E F G H - - # Slice 1: | A B | C D E F G H (t0 to t1) - # Slice 2: A B C D | E F | G H (t4 to t5) - - # Legend: - # |...| - Data window of size 2 - # Space between |...| - Gap of 4 units - # We add the gap of 4, which brings us to index 5. - # The second data slice should thus be ['F', 'G'] if we keep the window size of 2. 
- df = pd.DataFrame({"data": list("ABCDEFGH")}) - slices = list(generate_data_slices(df, window_size=2, gap=3)) - expected_slices = [["A", "B"], ["F", "G"]] - check_data_slices(slices, expected_slices) - - -def check_data_slices(slices, expected_slices): - for slice_data, expected_data in zip(slices, expected_slices): - slice_vals = slice_data[0]["data"].tolist() - assert ( - slice_vals == expected_data - ), f"Expected {expected_data}, but got {slice_vals}" - - -def sum_amount(dataslice): - total = dataslice["amount"].sum() - return total - - -# def test_create_target_values(data): -# # ensure each row is included in the target values -# # make sure each dataslice gets 1 row of data (and only 1 row) -# window_size = pd.Timedelta(days=5) -# data["transaction_id"] = 1 -# data.sort_values(by=["transaction_time"], inplace=True) -# target_values = calculate_target_values( -# df=data.copy(), -# target_dataframe_index="transaction_id", -# labeling_function=sum_amount, -# time_index="transaction_time", -# window_size=window_size, -# ) -# assert target_values["transaction_id"].tolist() == [1, 1, 1] -# assert target_values["cutoff_time"].tolist() == [ -# pd.Timestamp("2022-01-01 00:27:57"), -# pd.Timestamp("2022-01-06 00:48:50"), -# pd.Timestamp("2022-01-11 00:46:09"), -# ] -# assert np.allclose( -# target_values["sum_amount"].tolist(), -# [13837.71, 12990.74, 11213.85], -# ) diff --git a/tests/test_check_operations.py b/tests/test_check_operations.py index bb056d42..d10b60ad 100644 --- a/tests/test_check_operations.py +++ b/tests/test_check_operations.py @@ -175,12 +175,12 @@ def test_check_operations_boolean(metadata): def test_check_operations_cat(metadata): # For each predict the number of records with equal to operations = [EqFilterOp("card_type"), IdentityOp(None), CountAggregationOp(None)] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is True # For each predict the 
number of records with not equal to operations = [NeqFilterOp("card_type"), IdentityOp(None), CountAggregationOp(None)] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is True # For each predict the majority in all related records with equal to NY @@ -189,7 +189,7 @@ def test_check_operations_cat(metadata): IdentityOp(None), MajorityAggregationOp("card_type"), ] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is True # Not a valid operation @@ -199,28 +199,28 @@ def test_check_operations_cat(metadata): IdentityOp(None), CountAggregationOp(None), ] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is False # Not a valid operation # cannot do SumAggregation on categorical operations = [AllFilterOp(None), IdentityOp(None), SumAggregationOp("card_type")] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is False # For each predict if there exists a record in all related records with equal to NY operations = [AllFilterOp(None), IdentityOp(None), ExistsAggregationOp("card_type")] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is True # For each predict the first in all related records operations = [AllFilterOp(None), IdentityOp(None), FirstAggregationOp("card_type")] - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is True # For each predict the last in all related records operations = [AllFilterOp(None), IdentityOp(None), LastAggregationOp("card_type")] - result, modified_meta = 
_check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is True @@ -231,7 +231,7 @@ def test_foreign_key(metadata): SumAggregationOp("card_type"), ] metadata.ml_types["card_type"].add_tags({"foreign_key"}) - result, modified_meta = _check_operations_valid(operations, metadata) + result, _ = _check_operations_valid(operations, metadata) assert result is False diff --git a/tests/test_denormalize_dataframes.py b/tests/test_denormalize_dataframes.py index 65be6358..6ccc131e 100755 --- a/tests/test_denormalize_dataframes.py +++ b/tests/test_denormalize_dataframes.py @@ -8,9 +8,9 @@ def test_denormalize_two_tables(): """ - Products + products / - Logs + logs """ ( dataframes, @@ -58,9 +58,9 @@ def test_denormalize_two_tables(): def test_denormalize_two_tables_change(): """ - Products + products / - Logs + logs """ ( dataframes, @@ -88,9 +88,9 @@ def test_denormalize_two_tables_change(): def test_denormalize_three_tables(): """ - S P Sessions, Products + S P sessions, products \\ / . 
- L Logs + L logs """ ( dataframes, @@ -131,12 +131,12 @@ def test_denormalize_three_tables(): def test_denormalize_four_tables(): """ - C Customers + C customers | ||| - S P Sessions, Products + S P sessions, products \\ // - L Logs + L logs """ ( dataframes, @@ -172,13 +172,13 @@ def test_denormalize_four_tables(): def test_denormalize_change_target(): """ - C Customers + C customers | ||| - S P Sessions, Products + S P sessions, products ||| ||| || - L Logs + L logs """ ( dataframes, diff --git a/tests/test_denormalize_metadata_only.py b/tests/test_denormalize_metadata_only.py index 0d7c73f1..bb7b3468 100644 --- a/tests/test_denormalize_metadata_only.py +++ b/tests/test_denormalize_metadata_only.py @@ -30,7 +30,7 @@ def test_denormalize_two_tables(): relationships=relationships, time_indices=time_indices, ) - dataframes, normalized_metadata = denormalize( + _, normalized_metadata = denormalize( metadata=multi_metadata, target_table="logs", ) @@ -58,7 +58,7 @@ def test_denormalize_two_tables_change_target(): relationships=relationships, time_indices=time_indices, ) - dataframes, normalized_metadata = denormalize( + _, normalized_metadata = denormalize( metadata=multi_metadata, target_table="products", ) @@ -69,9 +69,9 @@ def test_denormalize_two_tables_change_target(): def test_denormalize_three_tables(): """ - S P Sessions, Products + S P sessions, products \\ / . 
- L Log + L log """ _, ml_types, relationships, primary_keys, time_indices = generate_mock_data( tables=["products", "logs", "sessions"], @@ -82,7 +82,7 @@ def test_denormalize_three_tables(): time_indices=time_indices, relationships=relationships, ) - dataframes, normalized_metadata = denormalize( + _, normalized_metadata = denormalize( metadata=multi_metadata, target_table="logs", ) @@ -102,14 +102,14 @@ def test_denormalize_three_tables(): def test_denormalize_four_tables(four_table_metadata): """ - C Customers + C customers | ||| - S P Sessions, Products + S P sessions, products \\ // - L Log + L log """ - dataframes, normalized_metadata = denormalize( + _, normalized_metadata = denormalize( metadata=four_table_metadata, target_table="logs", ) @@ -131,15 +131,15 @@ def test_denormalize_four_tables(four_table_metadata): def test_denormalize_change_target(four_table_metadata): """ - C Customers + C customers | ||| - S P Sessions, Products + S P sessions, products \\ // - L Log + L log """ - dataframes, normalized_metadata = denormalize( + _, normalized_metadata = denormalize( metadata=four_table_metadata, target_table="sessions", ) diff --git a/tests/test_llm.py b/tests/test_llm.py new file mode 100644 index 00000000..e2b1e0ca --- /dev/null +++ b/tests/test_llm.py @@ -0,0 +1,61 @@ +import os + +import pytest + +from trane import MultiTableMetadata, ProblemGenerator +from trane.llm import analyze +from trane.utils.testing_utils import generate_mock_data + + +@pytest.fixture +def metadata(): + tables = ["customers", "sessions", "products", "logs"] + ( + _, + ml_types, + relationships, + primary_keys, + time_indices, + ) = generate_mock_data( + tables=tables, + ) + metadata = MultiTableMetadata( + ml_types=ml_types, + primary_keys=primary_keys, + relationships=relationships, + time_indices=time_indices, + ) + return metadata + + +@pytest.fixture +def problems(metadata): + problem_generator = ProblemGenerator( + metadata=metadata, + entity_column=["product_id"], + 
target_table="products", + ) + problems = problem_generator.generate() + return problems + + +@pytest.mark.parametrize( + "model", + [ + ("gpt-4-1106-preview"), + ], +) +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPEN AI API KEY not found in environment variables", +) +def test_open_ai(problems, model): + instructions = "determine 5 most relevant problems about products" + context = "a fake dataset of ecommerce data" + relevant_problems = analyze( + problems=problems, + instructions=instructions, + context=context, + model=model, + ) + relevant_problems diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 47ce0fac..9321ee44 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -105,9 +105,9 @@ def test_from_dataframes_multi(): ( dataframes, ml_types, - relationships, - primary_keys, - time_indices, + _, + _, + _, ) = generate_mock_data( tables=["products", "logs"], ) diff --git a/tests/test_problem_generator.py b/tests/test_problem_generator.py index 77b10d2d..4e07856b 100644 --- a/tests/test_problem_generator.py +++ b/tests/test_problem_generator.py @@ -11,7 +11,7 @@ def test_problem_generator_single_table(): tables = ["products"] target_table = "products" - dataframe, ml_types, _, primary_key, time_index = generate_mock_data( + dataframe, ml_types, _, _, time_index = generate_mock_data( tables=tables, ) dataframe = dataframe[target_table] @@ -37,6 +37,8 @@ def test_problem_generator_single_table(): for p in problems: if p.has_parameters_set() is True: labels = p.create_target_values(dataframe) + if labels.empty: + raise ValueError("labels should not be empty") check_problem_type(labels, p.get_problem_type()) else: thresholds = p.get_recommended_thresholds(dataframe) diff --git a/tests/test_target_values.py b/tests/test_target_values.py new file mode 100644 index 00000000..5ae700a6 --- /dev/null +++ b/tests/test_target_values.py @@ -0,0 +1,62 @@ +import numpy as np +import pandas as pd +import pytest + +from 
trane import SingleTableMetadata +from trane.core.problem import Problem +from trane.ops.aggregation_ops import ExistsAggregationOp +from trane.ops.filter_ops import AllFilterOp, GreaterFilterOp +from trane.ops.transformation_ops import IdentityOp + + +@pytest.fixture() +def data(): + num_rows = 100 + data = { + "building_id": np.random.randint(0, 100, num_rows), + "timestamp": pd.date_range(start="2016-01-01", periods=num_rows, freq="H"), + "meter_reading": np.random.uniform(0, 100, num_rows), + } + df = pd.DataFrame(data) + return df + + +@pytest.fixture +def metadata(): + ml_types = { + "building_id": "Integer", + "timestamp": "Datetime", + "meter_reading": "Double", + } + metadata = SingleTableMetadata( + ml_types=ml_types, + primary_key="building_id", + time_index="timestamp", + ) + return metadata + + +def test_greater_than(data, metadata): + operations = [ + GreaterFilterOp("meter_reading"), + IdentityOp(None), + ExistsAggregationOp(None), + ] + problem = Problem( + metadata=metadata, + operations=operations, + entity_column="building_id", + window_size="2d", + ) + problem.create_target_values(data) + + +def test_exists(data, metadata): + operations = [AllFilterOp(None), IdentityOp(None), ExistsAggregationOp(None)] + problem = Problem( + metadata=metadata, + operations=operations, + entity_column="building_id", + window_size="2d", + ) + problem.create_target_values(data) diff --git a/tests/test_typing_utils.py b/tests/test_typing_utils.py new file mode 100644 index 00000000..648520c2 --- /dev/null +++ b/tests/test_typing_utils.py @@ -0,0 +1,76 @@ +import pandas as pd +import pytest + +from trane.metadata import SingleTableMetadata +from trane.typing.utils import set_dataframe_dtypes + + +@pytest.fixture +def sample_dataframe(): + df = pd.DataFrame({ + "num_col": [1, 2, 3], + "float_col": [1.1, 2.2, 3.3], + "bool_col": [True, False, True], + "date_col": ["2021-01-01", "2021-01-02", "2021-01-03"], + "cat_col": ["a", "b", "a"], + "str_col": ["12345", "22222", 
"33333"], + }) + for col in df: + df[col] = df[col].astype("string") + return df + + +@pytest.fixture +def sample_single_metadata(): + ml_types = { + "num_col": "Integer", + "float_col": "Double", + "bool_col": "Boolean", + "date_col": "Datetime", + "cat_col": "Categorical", + "str_col": "PostalCode", + } + single_metadata = SingleTableMetadata( + ml_types=ml_types, + primary_key="num_col", + time_index="date_col", + ) + return single_metadata + + +def test_numeric_conversion(sample_dataframe, sample_single_metadata): + converted_df = set_dataframe_dtypes(sample_dataframe, sample_single_metadata) + assert converted_df["num_col"].dtype == "int64[pyarrow]" + assert all( + converted_df["num_col"] + == pd.to_numeric(sample_dataframe["num_col"], downcast="integer"), + ) + + +def test_float_conversion(sample_dataframe, sample_single_metadata): + converted_df = set_dataframe_dtypes(sample_dataframe, sample_single_metadata) + assert converted_df["float_col"].dtype == "float64[pyarrow]" + + +def test_boolean_conversion(sample_dataframe, sample_single_metadata): + converted_df = set_dataframe_dtypes(sample_dataframe, sample_single_metadata) + assert converted_df["bool_col"].dtype == "bool[pyarrow]" + + +def test_datetime_conversion(sample_dataframe, sample_single_metadata): + converted_df = set_dataframe_dtypes(sample_dataframe, sample_single_metadata) + assert converted_df["date_col"].dtype == "datetime64[ns]" + + +def test_categorical_conversion(sample_dataframe, sample_single_metadata): + converted_df = set_dataframe_dtypes(sample_dataframe, sample_single_metadata) + assert converted_df["cat_col"].dtype == "category" + + +def test_empty_dataframe(): + df = pd.DataFrame() + metadata = metadata = SingleTableMetadata( + ml_types={}, + ) + converted_df = set_dataframe_dtypes(df, metadata) + assert converted_df.empty diff --git a/trane/__init__.py b/trane/__init__.py index 6374f68d..700186b6 100755 --- a/trane/__init__.py +++ b/trane/__init__.py @@ -1,6 +1,7 @@ from trane.core 
import * from trane.datasets import load_airbnb, load_store from trane.parsing import * +from trane.metadata import SingleTableMetadata, MultiTableMetadata from trane.typing import * from trane.utils import * from trane.version import __version__ diff --git a/trane/core/problem.py b/trane/core/problem.py index 35d48503..dff7d074 100644 --- a/trane/core/problem.py +++ b/trane/core/problem.py @@ -25,6 +25,7 @@ def __init__( ): self.operations = operations self.metadata = metadata + self.entity_column = entity_column self.window_size = window_size self.reasoning = reasoning @@ -125,11 +126,12 @@ def get_recommended_thresholds(self, dataframes, n_quantiles=10): ) return thresholds - def create_target_values(self, dataframes): + def create_target_values(self, dataframes, verbose=False): # Won't this always be normalized? normalized_dataframe = self.get_normalized_dataframe(dataframes) if self.has_parameters_set() is False: - print("Filter operation's parameters are not set, setting them now") + if verbose: + print("Filter operation's parameters are not set, setting them now") thresholds = self.get_recommended_thresholds(dataframes) self.set_parameters(thresholds[-1]) @@ -145,8 +147,8 @@ def create_target_values(self, dataframes): labeling_function=self._execute_operations_on_df, time_index=self.metadata.time_index, window_size=self.window_size, + verbose=verbose, ) - if "__identity__" in normalized_dataframe.columns: normalized_dataframe.drop(columns=["__identity__"], inplace=True) lt.drop(columns=["__identity__"], inplace=True) diff --git a/trane/core/problem_generator.py b/trane/core/problem_generator.py index 09e3a85f..32c3d65c 100644 --- a/trane/core/problem_generator.py +++ b/trane/core/problem_generator.py @@ -60,7 +60,7 @@ def generate(self, verbose=True): problems = [] valid_entity_columns = self.entity_columns if self.entity_columns is None: - # TODO: add logic to check entity_columns + # TODO: add logic to check entity_column valid_entity_columns = 
get_valid_entity_columns(single_metadata) # Force create with no entity column to generate problems "Predict X" valid_entity_columns.append(None) diff --git a/trane/core/utils.py b/trane/core/utils.py index 8821ab51..cc4fa2ea 100644 --- a/trane/core/utils.py +++ b/trane/core/utils.py @@ -1,67 +1,31 @@ -from datetime import datetime - import pandas as pd -def set_dataframe_index(df, time_index): - if df.index.name != time_index: - df = df.set_index(time_index) +def set_dataframe_index(df, index): + if df.index.name != index: + df = df.set_index(index, inplace=False) return df -def determine_gap_size(gap): - if isinstance(gap, str): - return pd.Timedelta(gap) - elif isinstance(gap, int) or isinstance(gap, pd.Timedelta): - return gap - elif not gap: - return 1 - return int(gap) - - -def generate_data_slices( - df, - window_size, - gap=1, - drop_empty=True, -): - start_idx = 0 - end_idx = len(df) - 1 - - gap = determine_gap_size(gap) - window_size = determine_gap_size(window_size) - - while start_idx < end_idx: - if isinstance(window_size, pd.Timedelta): - timestamp_at_start = df.index[start_idx] - slice_end_timestamp = timestamp_at_start + window_size - slice_end = df.index[df.index >= slice_end_timestamp].min() - if pd.isna(slice_end): - break - slice_end_idx = df.index.get_loc(slice_end) - if isinstance(df.index.get_loc(slice_end), slice): - # multiple matching indices, so we want the first one (for now) - # TODO: handle this better - slice_end_idx = slice_end_idx.start - dataslice = df.iloc[start_idx:slice_end_idx] - if isinstance(gap, pd.Timedelta): - start_idx_timestamp = slice_end_timestamp + gap - nearest_idx = df.index[df.index >= start_idx_timestamp].min() - if pd.isna(nearest_idx): - break - start_idx = df.index.get_loc(nearest_idx) - else: - start_idx = slice_end_idx + gap - else: - slice_end = start_idx + window_size - if slice_end > end_idx: - break - dataslice = df.iloc[start_idx:slice_end] - start_idx = slice_end + gap - - # Make sure slice_end is 
exclusive - if not drop_empty or not dataslice.empty: - yield dataslice, {"start": start_idx, "end": slice_end} +def generate_data_slices(df, window_size, gap, drop_empty=True): + # valid for a specify group of id + # so we need to groupby id (before this function) + window_size = pd.to_timedelta(window_size) + gap = pd.to_timedelta(gap) + if window_size != gap: + raise NotImplementedError("window_size != gap is not supported yet") + for start_ts, dataslice in df.resample( + window_size, + closed="left", + label="left", + kind="timestamp", + origin="start", + offset=gap, + ): + # inclusive start_ts and inclusive end_ts + end_ts = dataslice.index[-1] if not dataslice.empty else start_ts + if drop_empty is True and not dataslice.empty: + yield dataslice, {"start": start_ts, "end": end_ts} def calculate_target_values( @@ -70,7 +34,6 @@ def calculate_target_values( labeling_function, time_index, window_size, - gap=1, drop_empty=True, verbose=False, ): @@ -79,11 +42,12 @@ def calculate_target_values( label_name = labeling_function.__name__ for group_key, df_by_index in df.groupby(target_dataframe_index, observed=True): - for dataslice, metadata in generate_data_slices( - df_by_index, - window_size, - gap, - drop_empty, + # TODO: support gap + for dataslice, _ in generate_data_slices( + df=df_by_index, + window_size=window_size, + gap=window_size, + drop_empty=drop_empty, ): record = labeling_function(dataslice) records.append( @@ -93,15 +57,5 @@ def calculate_target_values( label_name: record, }, ) - - if verbose: - print(f"Processed label for group {group_key}") - records = pd.DataFrame.from_records(records, index=None) return records - - -def clean_date(date): - if isinstance(date, str): - return pd.Timestamp(datetime.strptime(date, "%Y-%m-%d")) - return date diff --git a/trane/llm/helpers.py b/trane/llm/helpers.py index 8b24e368..b5a82ebd 100644 --- a/trane/llm/helpers.py +++ b/trane/llm/helpers.py @@ -1,12 +1,13 @@ import json +import os import re -from 
IPython.display import Markdown, display - from trane.utils.library_utils import import_or_none openai = import_or_none("openai") tiktoken = import_or_none("tiktoken") +anthropic = import_or_none("anthropic") +ipython = import_or_none("IPython") system_context = ( @@ -84,7 +85,7 @@ def analyze( ) response = openai_gpt(prompt, model) if jupyter: - display(Markdown(response)) + ipython.display.display(ipython.display.Markdown(response)) else: print(response) @@ -93,7 +94,6 @@ def analyze( relevant_problems = [] for id_ in relevant_ids: relevant_problems.append(problems[int(id_) - 1]) - reasonsings = extract_reasonings_from_response(response) for idx, reason in enumerate(reasonsings): relevant_problems[idx].set_reasoning(reason) @@ -106,12 +106,16 @@ def extract_problems_from_response(response, model): f"Extract the IDs in the following text." f"## The constraints of your response:\n" f" Return your response as JSON only.\n" + f" Return the IDs in same order as they appear in the text.\n" f"## The text:\n" f"{response}\n" "{{ Insert your response here }}" ) response = openai_gpt(prompt, model) - response = json.loads(response).values() + if model in ["gpt-3.5-turbo-1106", "gpt-4-1106-preview"]: + response = re.findall(r"\d+", response) + else: + response = json.loads(response).values() response = list(flatten(response)) return response @@ -142,26 +146,34 @@ def flatten(container): def openai_gpt(prompt: str, model: str, temperature: float = 0.7) -> str: + client = openai.OpenAI( + api_key=os.environ.get("OPENAI_API_KEY"), + ) messages = [ {"role": "system", "content": system_context}, {"role": "user", "content": prompt}, ] - response = openai.ChatCompletion.create( - model=model, + chat_completion = client.chat.completions.create( messages=messages, + model=model, temperature=temperature, ) - return response["choices"][0]["message"]["content"].strip() + return chat_completion.choices[0].message.content def get_token_limit(model: str) -> int: models = { - 
"gpt-3.5-turbo": 4000, - "gpt-3.5-turbo-16k": 16000, - "gpt-4": 8000, - "gpt-4-32k": 32000, + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-1106": 16385, + "gpt-3.5-turbo-16k-0613": 16385, + "gpt-4": 8192, + "gpt-4-0613": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0613": 32768, + "gpt-4-1106-preview": 128000, } - return models.get(model) + return models.get(model.strip()) def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"): @@ -169,12 +181,12 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"): try: encoding = tiktoken.encoding_for_model(model) except KeyError: + print("Warning: model not found. Using cl100k_base encoding.") encoding = tiktoken.get_encoding("cl100k_base") if model in { "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4-0314", - "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613", }: @@ -186,8 +198,14 @@ def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"): ) tokens_per_name = -1 # if there's a name, the role is omitted elif "gpt-3.5-turbo" in model: + print( + "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.", + ) return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") elif "gpt-4" in model: + print( + "Warning: gpt-4 may update over time. 
Returning num tokens assuming gpt-4-0613.", + ) return num_tokens_from_messages(messages, model="gpt-4-0613") else: raise NotImplementedError( diff --git a/trane/metadata/metadata.py b/trane/metadata/metadata.py index cd5613b5..1c924d48 100644 --- a/trane/metadata/metadata.py +++ b/trane/metadata/metadata.py @@ -1,5 +1,7 @@ from collections import defaultdict +import pandas as pd + from trane.typing.inference import infer_ml_types from trane.typing.ml_types import Datetime, MLType @@ -41,6 +43,15 @@ def __init__( self.set_time_key(time_index) self.original_multi_table_metadata = original_multi_table_metadata + def __repr__(self): + result = "SingleTableMetadata\n" + result += "primary key: " + str(self.primary_key) + "\n" + result += "time index: " + str(self.time_index) + "\n\n" + df = pd.DataFrame.from_dict(data=self.ml_types, orient="index") + df.columns = ["ML Type"] + result += df.to_string() + return result + def set_primary_key(self, primary_key): if primary_key not in self.ml_types: raise ValueError("Index does not exist in ml_types") diff --git a/trane/parsing/denormalize.py b/trane/parsing/denormalize.py index 76fd4988..7208a05c 100644 --- a/trane/parsing/denormalize.py +++ b/trane/parsing/denormalize.py @@ -147,7 +147,7 @@ def child_relationships(parent_table, relationships): def reorder_relationships(target_table, relationships): reordered_relationships = [] for relationship in relationships: - parent_table_name, parent_key, child_table_name, child_key = relationship + _, _, child_table_name, _ = relationship if child_table_name == target_table: reordered_relationships.append(relationship) else: diff --git a/trane/typing/inference_functions.py b/trane/typing/inference_functions.py index d5fd7ffc..c278f282 100644 --- a/trane/typing/inference_functions.py +++ b/trane/typing/inference_functions.py @@ -141,9 +141,7 @@ def boolean_nullable_func(series): set(boolean_list) for boolean_list in boolean_inference_strings ]: return True - except ( - TypeError - ): 
# Necessary to check for non-hashable values because of object dtype consideration + except TypeError: # Necessary to check for non-hashable values because of object dtype consideration return False elif pdtypes.is_integer_dtype(series.dtype) and len( boolean_inference_ints, diff --git a/trane/typing/ml_types.py b/trane/typing/ml_types.py index 91255526..ebe503d6 100644 --- a/trane/typing/ml_types.py +++ b/trane/typing/ml_types.py @@ -80,6 +80,10 @@ def is_boolean(self): def is_datetime(self): False + @property + def is_categorical(self): + return False + class Boolean(MLType): dtype = "boolean[pyarrow]" @@ -101,6 +105,10 @@ class Categorical(MLType): def inference_func(series): return categorical_func(series) + @property + def is_categorical(self): + return True + class Datetime(MLType): dtype = "datetime64[ns]" diff --git a/trane/typing/utils.py b/trane/typing/utils.py index e69de29b..5463ca19 100644 --- a/trane/typing/utils.py +++ b/trane/typing/utils.py @@ -0,0 +1,30 @@ +import pandas as pd +from pandas.api.types import ( + is_bool_dtype, + is_datetime64_any_dtype, + is_integer_dtype, + is_numeric_dtype, +) + + +def set_dataframe_dtypes(dataframe, metadata): + for col, ml_type in metadata.ml_types.items(): + actual_dtype = dataframe[col].dtype + expected_dtype = ml_type.dtype + if ml_type.is_numeric and not is_numeric_dtype(actual_dtype): + if is_integer_dtype(expected_dtype): + dataframe[col] = dataframe[col].astype(expected_dtype) + else: + dataframe[col] = dataframe[col].astype(expected_dtype) + elif ml_type.is_datetime and not is_datetime64_any_dtype(actual_dtype): + dataframe[col] = pd.to_datetime(dataframe[col]) + elif ml_type.is_categorical and not isinstance( + actual_dtype, + pd.CategoricalDtype, + ): + dataframe[col] = dataframe[col].astype(expected_dtype) + elif ml_type.is_boolean and not is_bool_dtype(actual_dtype): + dataframe[col] = dataframe[col].astype(expected_dtype) + elif actual_dtype != expected_dtype: + dataframe[col] = 
dataframe[col].astype(expected_dtype) + return dataframe diff --git a/trane/utils/testing_utils.py b/trane/utils/testing_utils.py index 880b6d61..dc5147f4 100644 --- a/trane/utils/testing_utils.py +++ b/trane/utils/testing_utils.py @@ -9,12 +9,12 @@ # v1 version def generate_mock_data(tables): """ - C Customers + C customers | ||| - S P Sessions, Products + S P sessions, products \\ // - L Logs + L logs """ ml_types = {} dataframes = {}