Skip to content

Commit

Permalink
rem feats name arg (add a `--rem-feats-name` argument to remove training features by name)
Browse files: browse the repository at this point in the history
  • Loading branch information
rkansal47 committed Apr 6, 2024
1 parent c515f43 commit 82159af
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/HHbbVV/postprocessing/TrainBDT.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def main(args):

if args.rem_feats:
bdtVars = bdtVars[: -args.rem_feats]
if len(args.rem_feats_name):
bdtVars = [var for var in bdtVars if var not in args.rem_feats_name]

print("BDT features:\n", bdtVars)

Expand Down Expand Up @@ -557,6 +559,7 @@ def plot_train_test_shapes(
)


# def plot_mass_shapes(model, bdtVars, multiclass, train, test, sig_keys, model_dir, training_keys):
def plot_mass_shapes(train, test, sig_keys, model_dir, training_keys):
cuts = [0, 0.1, 0.5, 0.9, 0.95]

Expand Down Expand Up @@ -599,13 +602,27 @@ def plot_mass_shapes(train, test, sig_keys, model_dir, training_keys):
key,
f"BDTScore{sig_key}",
cuts,
year,
"all",
weight_key,
plot_dir=save_model_dir,
name=f"{label}_{key}_BDT{sig_key}Cuts_AllYears",
show=False,
)

# make plots with artificially increased qcd statistics
# if key == qcd_key:
# nsampling = 100
# X = get_X(data_dict, bdtVars)
# X = X.iloc[np.tile(np.arange(len(X)), nsampling)]

# rng = np.random.default_rng(seed=42)
# X["VVFatJetParticleNetMass"] = rng.random(len(X)) * 15 + 125
# X["VVFatJetParTMD_probHWW3q"] = rng.random(len(X)) * 0.05 + 0.95
# X["VVFatJetParTMD_probQCD"] = rng.random(len(X)) * 0.025 + 0.025

# preds = model.predict_proba(get_X(data, bdtVars))
# preds = _get_bdt_scores(preds, sig_keys, multiclass)


def evaluate_model(
model: xgb.XGBClassifier,
Expand Down Expand Up @@ -876,7 +893,12 @@ def do_inference(
"--n-estimators", default=10000, help="max number of trees to keep adding", type=int
)

parser.add_argument("--rem-feats", default=0, help="remove N lowest importance feats", type=int)
parser.add_argument(
"--rem-feats", default=0, help="remove N lowest importance training feats", type=int
)
parser.add_argument(
"--rem-feats-name", default=[], help="remove training features by name", type=str, nargs="*"
)

"""
Slightly worse to use a single tagger score
Expand Down

0 comments on commit 82159af

Please sign in to comment.