-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into milestone-2-fix
- Loading branch information
Showing
5 changed files
with
104 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
def validate_column_names(wine, correct_columns): | ||
""" | ||
This function validates that the column names of the provided DataFrame match the expected column names. | ||
Parameters: | ||
---------- | ||
wine : pandas.DataFrame | ||
The DataFrame to validate. | ||
correct_columns : set | ||
A set of expected column names. | ||
Raises: | ||
------ | ||
ValueError | ||
If the column names in the DataFrame don't match the expected column names. | ||
The error will specify: | ||
- Unexpected columns (columns present in the DataFrame but not in the expected set). | ||
- Missing columns (columns expected but not present in the DataFrame). | ||
Returns: | ||
------- | ||
None | ||
If the column names match the expected set, the function will print "Column name test passed!". | ||
Example: | ||
------- | ||
>>> import pandas as pd | ||
>>> wine_df = pd.DataFrame(columns=["feature1", "feature2", "target"]) | ||
>>> expected_columns = {"feature1", "feature2", "target"} | ||
>>> validate_column_names(wine_df, expected_columns) | ||
Column name test passed! | ||
>>> incorrect_df = pd.DataFrame(columns=["feature1", "feature3"]) | ||
>>> validate_column_names(incorrect_df, expected_columns) | ||
ValueError: Unexpected columns: ['feature3'], missing columns: ['feature2', 'target'] | ||
""" | ||
extracted_columns = set(wine.columns) | ||
if extracted_columns != correct_columns: | ||
wrong_columns = extracted_columns.difference(correct_columns) | ||
missing_columns = correct_columns.difference(extracted_columns) | ||
if wrong_columns and missing_columns: | ||
raise ValueError(f"Unexpected columns: {list(wrong_columns)}, missing columns: {list(missing_columns)}") | ||
elif wrong_columns: | ||
raise ValueError(f"Unexpected columns: {list(wrong_columns)}") | ||
elif missing_columns: | ||
raise ValueError(f"Missing columns: {list(missing_columns)}") | ||
else: | ||
print("Column name test passed!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pytest -v test_random_search.py \ | ||
--train_data=../data/proc/wine_train.csv \ | ||
--test_data=../data/proc/wine_test.csv \ | ||
--pipeline_path=../results/models/wine_pipeline.pickle | ||
|
||
pytest test_validate_column_names.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
# Random Search Test | ||
pytest -v test_random_search.py \ | ||
--train_data=../data/proc/wine_train.csv \ | ||
--test_data=../data/proc/wine_test.csv \ | ||
--pipeline_path=../results/models/wine_pipeline.pickle | ||
|
||
pytest test_validate_column_names.py | ||
|
||
pytest test_split_data.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import pytest | ||
import os | ||
import pandas as pd | ||
import sys | ||
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) | ||
from src.validate_column_names import validate_column_names | ||
|
||
# Correct columns | ||
correct_columns = {"column1", "column2", "column3"} | ||
|
||
# Test data | ||
correct_df = pd.DataFrame(columns=["column1", "column2", "column3"]) | ||
extra_column_df = pd.DataFrame(columns=["column1", "column2", "column3", "extra_column"]) | ||
missing_column_df = pd.DataFrame(columns=["column1", "column2"]) | ||
wrong_column_df = pd.DataFrame(columns=["wrong_column1", "wrong_column2", "wrong_column3"]) | ||
|
||
# Test for correct column names | ||
def test_validate_column_names_correct(): | ||
try: | ||
validate_column_names(correct_df, correct_columns) | ||
except ValueError: | ||
pytest.fail("validate_column_names raised ValueError unexpectedly for correct columns.") | ||
|
||
# Test for extra column | ||
def test_validate_column_names_extra_column(): | ||
with pytest.raises(ValueError, match="Unexpected columns:"): | ||
validate_column_names(extra_column_df, correct_columns) | ||
|
||
# Test for missing column | ||
def test_validate_column_names_missing_column(): | ||
with pytest.raises(ValueError, match="Missing columns:"): | ||
validate_column_names(missing_column_df, correct_columns) | ||
|
||
# Test for completely wrong columns | ||
def test_validate_column_names_wrong_column(): | ||
with pytest.raises(ValueError, match="Unexpected columns:"): | ||
validate_column_names(wrong_column_df, correct_columns) | ||
|
||
# Test for both extra and missing columns | ||
def test_validate_column_names_extra_and_missing(): | ||
mixed_df = pd.DataFrame(columns=["column1", "extra_column"]) | ||
with pytest.raises(ValueError, match="Unexpected columns:.*missing columns:"): | ||
validate_column_names(mixed_df, correct_columns) |