Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
fix default handling of sample_weight (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri authored Mar 6, 2023
1 parent 64f04d6 commit 2723e12
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
21 changes: 8 additions & 13 deletions src/equisolve/numpy/models/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def fit(
X: TensorMap,
y: TensorMap,
alpha: Union[float, TensorMap] = 1.0,
sample_weight: Union[float, TensorMap] = 1.0,
sample_weight: Union[float, TensorMap] = None,
rcond: float = 1e-13,
) -> None:
"""Fit a regression model to each block in `X`.
Expand All @@ -187,7 +187,8 @@ def fit(
Values must be non-negative floats i.e. in [0, inf). α can be different for
each column in `X` to regularize each property differently.
:param sample_weight:
sample weights
Individual weights for each sample. For `None` or a float, every sample will
have the same weight of 1 or the float, respectively.
:param rcond:
Cut-off ratio for small singular values during the fit. For the purposes of
rank determination, singular values are treated as zero if they are smaller
Expand All @@ -207,18 +208,12 @@ def fit(
elif type(alpha) is not TensorMap:
raise ValueError("alpha must either be a float or a TensorMap")

if type(sample_weight) is float:
sw_tensor = ones_like(X)

properties = Labels(
names=X.property_names,
values=np.zeros([1, len(X.property_names)], dtype=int),
)

sw_tensor = slice(sw_tensor, properties=properties)
sample_weight = multiply(sw_tensor, sample_weight)
if sample_weight is None:
sample_weight = ones_like(y)
elif type(sample_weight) is float:
sample_weight = multiply(ones_like(y), sample_weight)
elif type(sample_weight) is not TensorMap:
raise ValueError("sample_weight must either be a float or a TensorMap")
raise ValueError("sample_weight must either be a float or a TensorMap.")

self._validate_data(X, y)
self._validate_params(X, alpha, sample_weight)
Expand Down
23 changes: 10 additions & 13 deletions tests/equisolve_tests/numpy/models/test_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,25 @@ def numpy_solver(self, X, y, sample_weights, regularizations):

return w_solver

@pytest.mark.parametrize("num_properties", num_properties)
@pytest.mark.parametrize("num_targets", num_targets)
def test_ridge(self, num_properties, num_targets):
"""Test if ridge is working and all shapes are converted correctly.
@pytest.mark.parametrize("alpha", [0.0, 1.0])
@pytest.mark.parametrize("sample_weight", [None, 1.0])
def test_ridge(self, alpha, sample_weight):
"""Test if ridge is working.
Test is performed for two blocks.
"""
num_targets = 50
num_properties = 5

# Create input values
X_arr = self.rng.random([2, num_targets, num_properties])
y_arr = self.rng.random([2, num_targets, 1])
alpha_arr = np.ones([2, 1, num_properties])
sw_arr = np.ones([2, num_targets, 1])

X = tensor_to_tensormap(X_arr)
y = tensor_to_tensormap(y_arr)
alpha = tensor_to_tensormap(alpha_arr)
sw = tensor_to_tensormap(sw_arr)

clf = Ridge(parameter_keys="values")
clf.fit(X=X, y=y, alpha=alpha, sample_weight=sw)
clf.fit(X=X, y=y, alpha=alpha, sample_weight=sample_weight)

assert len(clf.weights) == 2
assert clf.weights.block(0).values.shape[1] == num_properties
Expand All @@ -115,15 +114,13 @@ def test_double_fit_call(self):

X_arr = self.rng.random([num_blocks, num_targets, num_properties])
y_arr = self.rng.random([num_blocks, num_targets, 1])
alpha_arr = np.ones([num_blocks, 1, num_properties])

X = tensor_to_tensormap(X_arr)
y = tensor_to_tensormap(y_arr)
alpha = tensor_to_tensormap(alpha_arr)

clf = Ridge(parameter_keys="values")
clf.fit(X=X, y=y, alpha=alpha)
clf.fit(X=X, y=y, alpha=alpha)
clf.fit(X=X, y=y, alpha=1.0)
clf.fit(X=X, y=y, alpha=1.0)

assert len(clf.weights) == num_blocks

Expand Down

0 comments on commit 2723e12

Please sign in to comment.