Simplified user interface #46

Open
pbenner opened this issue Jul 11, 2023 · 1 comment
Labels: discussion, enhancement, help wanted

Comments

pbenner (Collaborator) commented Jul 11, 2023

What about providing a simplified user interface for training and testing models? This would be a simple example:

import pandas as pd

from matbench_discovery.data import DATA_FILES, df_wbm
from pymatgen.core import Structure
from sklearn.metrics import r2_score

class MatbenchDiscovery:
    def __init__(self, task_type = "IS2RE"):
        if task_type not in ['IS2RE', 'RS2RE']:
            raise ValueError(f'Invalid task_type {task_type}')
        self.task_type = task_type
    
    def get_test_data(self):
        id_col = "material_id"
        input_col = {"IS2RE": "initial_structure", "RS2RE": "relaxed_structure"}[self.task_type]
        target_col = "e_form_per_atom_mp2020_corrected"

        data_path = {
            "IS2RE": DATA_FILES.wbm_initial_structures,
            "RS2RE": DATA_FILES.wbm_computed_structure_entries,
        }[self.task_type]

        df_in = pd.read_json(data_path).set_index(id_col)

        X = pd.Series([Structure.from_dict(x) for x in df_in[input_col]], index = df_in.index)
        y = pd.Series(df_wbm[target_col])

        return X[y.index], y

    def get_train_data(self):
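        if self.task_type != "IS2RE":
            # TODO: training data for RS2RE is not implemented yet
            raise NotImplementedError(f"get_train_data not implemented for {self.task_type}")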

        target_col = "formation_energy_per_atom"
        input_col = "structure"
        id_col = "material_id"

        df_cse = pd.read_json(DATA_FILES.mp_computed_structure_entries).set_index(id_col)
        df_eng = pd.read_csv(DATA_FILES.mp_energies).set_index(id_col)

        X = pd.Series([ Structure.from_dict(cse[input_col]) for cse in df_cse.entry ], index = df_cse.index)
        y = pd.Series(df_eng[target_col], index = df_eng.index)

        return X[y.index], y

    def evaluate_predictions(self, y_pred, apply_correction = False):

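        # isinstance also accepts pd.Series subclasses, unlike an exact type comparison
        assert isinstance(y_pred, pd.Series), "y_pred must be a pandas Series indexed by material_id"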

        target_col = "e_form_per_atom_mp2020_corrected"

        y_pred = y_pred.dropna()
        y_true = df_wbm[target_col][y_pred.index]

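        if apply_correction:
            # swap the legacy MP correction for the MP2020 one, aligned on y_pred's index
            y_pred = (
                y_pred
                - df_wbm.e_correction_per_atom_mp_legacy[y_pred.index]
                + df_wbm.e_correction_per_atom_mp2020[y_pred.index]
            )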

        mae = (y_true - y_pred).abs().mean()
        r2 = r2_score(y_true, y_pred)

        return {'mae': mae, 'r2': r2, 'y_true': y_true, 'y_pred': y_pred}
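
To make the evaluation step above easier to check without downloading the WBM data, here is a minimal, dependency-free sketch of what `evaluate_predictions` computes. It is not part of the proposal: plain dicts keyed by material ID stand in for the pandas Series, and `E_FORM_TRUE`, `CORR_LEGACY`, `CORR_MP2020` and all numbers in them are made-up placeholders for the corresponding `df_wbm` columns.

```python
import math

# Hypothetical stand-ins for df_wbm columns, keyed by material ID (values are made up).
E_FORM_TRUE = {"wbm-1-1": -1.20, "wbm-1-2": 0.35, "wbm-1-3": -0.80, "wbm-1-4": 0.10}
CORR_LEGACY = {"wbm-1-1": -0.10, "wbm-1-2": 0.00, "wbm-1-3": -0.05, "wbm-1-4": 0.00}
CORR_MP2020 = {"wbm-1-1": -0.15, "wbm-1-2": 0.00, "wbm-1-3": -0.05, "wbm-1-4": 0.00}


def evaluate_predictions(y_pred, apply_correction=False):
    """Mirror the MAE/R2 logic of the class above on plain dicts."""
    # Drop missing predictions (the pandas version uses dropna()).
    y_pred = {k: v for k, v in y_pred.items() if not math.isnan(v)}
    if apply_correction:
        # Swap the legacy correction for the MP2020 one, ID by ID.
        y_pred = {k: v - CORR_LEGACY[k] + CORR_MP2020[k] for k, v in y_pred.items()}

    # Align targets with the remaining predictions by material ID.
    ids = list(y_pred)
    y_true = [E_FORM_TRUE[k] for k in ids]
    preds = [y_pred[k] for k in ids]

    mae = sum(abs(t - p) for t, p in zip(y_true, preds)) / len(ids)

    # Coefficient of determination: 1 - SS_res / SS_tot.
    mean_true = sum(y_true) / len(ids)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, preds))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    r2 = 1 - ss_res / ss_tot

    return {"mae": mae, "r2": r2, "n": len(ids)}


# One prediction is missing (NaN) and is excluded from the metrics.
preds = {"wbm-1-1": -1.10, "wbm-1-2": 0.30, "wbm-1-3": float("nan"), "wbm-1-4": 0.20}
result = evaluate_predictions(preds)
corrected = evaluate_predictions(preds, apply_correction=True)
print(result)
print(corrected)
```

The sketch assumes at least two non-missing predictions with distinct targets; otherwise the MAE or R2 division would fail. The pandas version gets the same alignment by indexing `df_wbm` with `y_pred.index`.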

janosh (Owner) commented Jul 25, 2023

Super slow reply (sorry), but I think this is a great idea! It'll make even more sense if future model evaluations become more complex by including other stability predictions such as elastic/kinetic instability.

@matthewkuner might be interested in this as well. It would be great to hear as many opinions as possible on API design.

janosh added the enhancement, help wanted and discussion labels Jul 25, 2023