Add support for balance_classes to Deep Water (h2oai#362)
* Add support for balance_classes to Deep Water
* Add balance_classes, max_after_balance_size and class_sampling_factors (same as for GBM/DRF/DL); a usage sketch follows below
* Fix stratified sampling for string chunks
* Use the same gpu and device_id as in training for MOJO scoring (by default); this fixes the imageURLs test (the CPU vs. GPU mismatch was the cause)

* Expose a bug in NA handling for categoricals - to be fixed.

* More cleanup: add checks for instability, modify the unstable test.
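
For context, a minimal usage sketch of the newly exposed Python parameters (illustrative only, not part of this commit): the data file, frame and column names below are made up, and a Deep Water backend such as mxnet is assumed to be installed.

import h2o
from h2o.estimators.deepwater import H2ODeepWaterEstimator

h2o.init()
train = h2o.import_file("train.csv")          # hypothetical dataset
train["label"] = train["label"].asfactor()    # classification target

model = H2ODeepWaterEstimator(
    epochs=1,
    balance_classes=True,                     # over/under-sample the training data
    class_sampling_factors=[3.0, 5.0],        # per-class ratios, lexicographic order
    max_after_balance_size=2.0)               # cap: at most 2x the original row count
model.train(x=train.names[:-1], y="label", training_frame=train)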
arnocandel authored Oct 18, 2016
1 parent 031b0bf commit e439c91
Showing 12 changed files with 193 additions and 31 deletions.
18 changes: 18 additions & 0 deletions h2o-algos/src/main/java/hex/deepwater/DeepWater.java
@@ -10,13 +10,16 @@
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.MRUtils;
import water.util.PrettyPrint;

import java.util.Arrays;

import static hex.deepwater.DeepWaterModel.makeDataInfo;
import static water.util.MRUtils.sampleFrame;
import static water.util.MRUtils.sampleFrameStratified;

/**
* Deep Learning Neural Net implementation based on MRTask
@@ -234,6 +237,21 @@ public final DeepWaterModel trainModel(DeepWaterModel model) {
Frame val_fr = _valid != null ? new Frame(mp._valid,_valid.names(), _valid.vecs()) : null;

train = tra_fr;
if (model._output.isClassifier() && mp._balance_classes) {
_job.update(0,"Balancing class distribution of training data...");
float[] trainSamplingFactors = new float[train.lastVec().domain().length]; //leave initialized to 0 -> will be filled up below
if (mp._class_sampling_factors != null) {
if (mp._class_sampling_factors.length != train.lastVec().domain().length)
throw new IllegalArgumentException("class_sampling_factors must have " + train.lastVec().domain().length + " elements");
trainSamplingFactors = mp._class_sampling_factors.clone(); //clone: don't modify the original
}
train = sampleFrameStratified(
train, train.lastVec(), train.vec(model._output.weightsName()), trainSamplingFactors, (long)(mp._max_after_balance_size*train.numRows()), mp._seed, true, false);
Vec l = train.lastVec();
Vec w = train.vec(model._output.weightsName());
MRUtils.ClassDist cd = new MRUtils.ClassDist(l);
model._output._modelClassDist = _weights != null ? cd.doAll(l, w).rel_dist() : cd.doAll(l).rel_dist();
}
model.training_rows = train.numRows();
model.actual_train_samples_per_iteration =
_parms._train_samples_per_iteration > 0 ? _parms._train_samples_per_iteration : //user-given value (>0)
5 changes: 3 additions & 2 deletions h2o-algos/src/main/java/hex/deepwater/DeepWaterModel.java
@@ -21,6 +21,7 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

import static hex.ModelMetrics.calcVarImp;
@@ -73,11 +74,11 @@ void set_model_info(DeepWaterModelInfo mi) {

Key actual_best_model_key;

private final String unstable_msg = technote(4,
static final String unstable_msg = technote(4,
"\n\nTrying to predict with an unstable model." +
"\nJob was aborted due to observed numerical instability (exponential growth)."
+ "\nEither the weights or the bias values are unreasonably large or lead to large activation values."
+ "\nTry a different initial distribution, a bounded activation function (Tanh), adding regularization"
+ "\nTry a different network architecture, a bounded activation function (tanh), adding regularization"
+ "\n(via dropout) or use a smaller learning rate and/or momentum.");

public DeepWaterScoringInfo last_scored() { return (DeepWaterScoringInfo) super.last_scored(); }
17 changes: 13 additions & 4 deletions h2o-algos/src/main/java/hex/deepwater/DeepWaterTask.java
@@ -9,11 +9,13 @@
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;

@@ -175,14 +177,21 @@ else if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemTy
_localmodel._backend.setParameter(_localmodel._model, "momentum", _localmodel.get_params().momentum((double) n));

//fork off GPU work, but let the iterator.Next() wait on completion before swapping again
//System.err.println("data: " + Arrays.toString(iter.getData()));
float[] preds = _localmodel._backend.predict(_localmodel._model, iter.getData());
if (Float.isNaN(ArrayUtils.sum(preds))) {
Log.err(DeepWaterModel.unstable_msg);
throw new UnsupportedOperationException(DeepWaterModel.unstable_msg);
}
// System.err.println("pred: " + Arrays.toString(preds));
ntt = new NativeTrainTask(_localmodel._backend, _localmodel._model, iter.getData(), iter.getLabel());
fs.add(H2O.submitTask(ntt));
_localmodel.add_processed_local(iter._batch_size);
}
fs.blockForPending();
// nativetime += ntt._timeInMillis;
} catch (IOException e) {
e.printStackTrace();
e.printStackTrace(); //gracefully continue if we can't find files etc.
}
// long end = System.currentTimeMillis();
// if (!_localmodel.get_params()._quiet_mode) {
@@ -195,14 +204,14 @@ else if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemTy
static private class NativeTrainTask extends H2O.H2OCountedCompleter<NativeTrainTask> {

long _timeInMillis;
final BackendTrain _it;
final BackendTrain _backend;
final BackendModel _model;

float[] _data;
float[] _labels;

NativeTrainTask(BackendTrain backend, BackendModel model, float[] data, float[] label) {
_it = backend;
_backend = backend;
_model = model;
_data = data;
_labels = label;
@@ -211,7 +220,7 @@ static private class NativeTrainTask extends H2O.H2OCountedCompleter<NativeTrain
@Override
public void compute2() {
long start = System.currentTimeMillis();
_it.train(_model, _data,_labels); //ignore predictions
_backend.train(_model, _data,_labels); //ignore predictions
long end = System.currentTimeMillis();
_timeInMillis += end-start;
tryComplete();
@@ -43,6 +43,8 @@ protected void writeModelData() throws IOException {
writekv("norm_resp_mul", _output._normRespMul);
writekv("norm_resp_sub", _output._normRespSub);
writekv("use_all_factor_levels", _output._useAllFactorLevels);
writekv("gpu", _parms._gpu);
writekv("device_id", _parms._device_id);

writeblob("model_network", _model_info._network);
writeblob("model_params", _model_info._modelparams);
33 changes: 31 additions & 2 deletions h2o-algos/src/main/java/hex/schemas/DeepWaterV3.java
@@ -16,6 +16,9 @@ public static final class DeepWaterParametersV3 extends ModelParametersSchemaV3<
"training_frame",
"validation_frame",
"nfolds",
"balance_classes",
"max_after_balance_size",
"class_sampling_factors",
"keep_cross_validation_predictions",
"keep_cross_validation_fold_assignment",
"fold_assignment",
@@ -336,8 +339,7 @@ public static final class DeepWaterParametersV3 extends ModelParametersSchemaV3<
* times the dataset size or larger).
*/
@API(level = API.Level.expert, direction = API.Direction.INOUT, gridable = true,
help = "Enable shuffling of training data (recommended if training data is replicated and " +
"train_samples_per_iteration is close to #nodes x #rows, of if using balance_classes).")
help = "Enable global shuffling of training data.")
public boolean shuffle_training_data;

@API(level = API.Level.expert, direction=API.Direction.INOUT, gridable = true,
@@ -391,5 +393,32 @@ public static final class DeepWaterParametersV3 extends ModelParametersSchemaV3<
@API(level = API.Level.secondary, direction = API.Direction.INOUT, gridable = true,
help = "If enabled, automatically standardize the data. If disabled, the user must provide properly scaled input data.")
public boolean standardize;

/**
* For imbalanced data, balance training data class counts via
* over/under-sampling. This can result in improved predictive accuracy.
*/
@API(level = API.Level.secondary, direction = API.Direction.INOUT, gridable = true,
help = "Balance training data class counts via over/under-sampling (for imbalanced data).")
public boolean balance_classes;

/**
* Desired over/under-sampling ratios per class (lexicographic order).
* Only when balance_classes is enabled.
* If not specified, they will be automatically computed to obtain class balance during training.
*/
@API(level = API.Level.expert, direction = API.Direction.INOUT, gridable = true,
help = "Desired over/under-sampling ratios per class (in lexicographic order). If not specified, sampling " +
"factors will be automatically computed to obtain class balance during training. Requires balance_classes.")
public float[] class_sampling_factors;

/**
* When classes are balanced, limit the resulting dataset size to the
* specified multiple of the original dataset size.
*/
@API(level = API.Level.expert, direction = API.Direction.INOUT, gridable = false,
help = "Maximum relative size of the training data after balancing class counts (can be less than 1.0). " +
"Requires balance_classes.")
public float max_after_balance_size;
}
}
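
A rough sketch of how the three new parameters interact, assuming the usual H2O semantics: per-class sampling factors scale each class's row count, auto-computed factors aim for equal class counts, and the total is capped at max_after_balance_size times the original row count. The helper below is hypothetical and only approximates, rather than reproduces, the sampleFrameStratified behavior.

def balanced_class_counts(class_counts, class_sampling_factors=None, max_after_balance_size=5.0):
    """class_counts: dict mapping class label -> row count (labels in lexicographic order)."""
    n = sum(class_counts.values())
    if class_sampling_factors is None:
        # auto mode: sample every class towards the size of the largest class
        biggest = max(class_counts.values())
        factors = {c: biggest / cnt for c, cnt in class_counts.items()}
    else:
        factors = dict(zip(sorted(class_counts), class_sampling_factors))
    target = {c: cnt * factors[c] for c, cnt in class_counts.items()}
    cap = max_after_balance_size * n
    if sum(target.values()) > cap:            # shrink proportionally to respect the cap
        scale = cap / sum(target.values())
        target = {c: t * scale for c, t in target.items()}
    return {c: int(round(t)) for c, t in target.items()}

# Example: balanced_class_counts({"cat": 900, "dog": 100}, max_after_balance_size=2.0)
# yields roughly {"cat": 900, "dog": 900}, well under the 2000-row cap.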
@@ -693,6 +693,10 @@ public void imageURLs() {
p._backend = getBackend();
p._train = (tr = parse_test_file("smalldata/deepwater/imagenet/binomial_image_urls.csv"))._key;
p._response_column = "C2";
p._balance_classes = true;
p._epochs = 1;
p._max_after_balance_size = 2f;
p._class_sampling_factors = new float[]{3,5};
DeepWater j = new DeepWater(p);
m = j.trainModel().get();
Assert.assertTrue((m._output._training_metrics).auc_obj()._auc > 0.90);
1 change: 1 addition & 0 deletions h2o-core/src/main/java/hex/ModelMetricsBinomial.java
@@ -40,6 +40,7 @@ public String toString() {
if (_auc != null) sb.append(" AUC: " + (float)_auc._auc + "\n");
sb.append(" logloss: " + (float)_logloss + "\n");
sb.append(" mean_per_class_error: " + (float)_mean_per_class_error + "\n");
sb.append(" default threshold: " + (_auc == null ? 0.5 : (float)_auc.defaultThreshold()) + "\n");
if (cm() != null) sb.append(" CM: " + cm().toASCII());
if (_gainsLift != null) sb.append(_gainsLift);
return sb.toString();
18 changes: 14 additions & 4 deletions h2o-core/src/main/java/water/util/MRUtils.java
@@ -75,11 +75,15 @@ public void map(Chunk[] cs, NewChunk[] ncs) {
ArrayUtils.shuffleArray(idx, getRNG(seed));
for (long anIdx : idx) {
for (int i = 0; i < ncs.length; i++) {
ncs[i].addNum(cs[i].atd((int) anIdx));
if (cs[i] instanceof CStrChunk) {
ncs[i].addStr(cs[i],cs[i].start()+anIdx);
} else {
ncs[i].addNum(cs[i].atd((int) anIdx));
}
}
}
}
}.doAll(fr.numCols(), Vec.T_NUM, fr).outputFrame(fr.names(), fr.domains());
}.doAll(fr.types(), fr).outputFrame(fr.names(), fr.domains());
}

/**
@@ -282,8 +286,14 @@ public void map(Chunk[] cs, NewChunk[] ncs) {
sampling_reps = (int)sampling_ratios[label] + (rng.nextFloat() < remainder ? 1 : 0);
}
for (int i = 0; i < ncs.length; i++) {
for (int j = 0; j < sampling_reps; ++j) {
ncs[i].addNum(cs[i].atd(r));
if (cs[i] instanceof CStrChunk) {
for (int j = 0; j < sampling_reps; ++j) {
ncs[i].addStr(cs[i],cs[0].start()+r);
}
} else {
for (int j = 0; j < sampling_reps; ++j) {
ncs[i].addNum(cs[i].atd(r));
}
}
}
}
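
The replication trick used by the stratified sampler above (emit each row int(ratio) times, plus one extra time with probability equal to the fractional remainder, so that in expectation every row appears ratio times) can be sketched in plain Python; replicate_row and its arguments are illustrative stand-ins, not part of MRUtils.

import random

def replicate_row(row, sampling_ratio, rng=random):
    """Emit a row int(ratio) times, plus once more with probability frac(ratio)."""
    reps = int(sampling_ratio)
    remainder = sampling_ratio - reps
    if rng.random() < remainder:
        reps += 1
    return [row] * reps

# Example: with sampling_ratio=2.3 each row is kept 2 or 3 times, 2.3 times on average;
# with sampling_ratio=0.4 each row is kept 0 or 1 times, i.e. under-sampled.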
@@ -61,6 +61,8 @@ public final double[] score0(double[] doubles, double offset, double[] preds) {
if (_nclasses > 1) {
for (int i = 0; i < predFloats.length; ++i)
preds[1 + i] = predFloats[i];
if (_balanceClasses)
GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, doubles, _defaultThreshold);
} else {
if (_normRespMul!=null && _normRespSub!=null)
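
The correctProbabilities call above undoes the effect of training on a rebalanced sample before the final prediction is chosen. A hedged sketch of the standard correction, assuming (as is usual) that each class probability is reweighted by the ratio of prior to model class frequency and then renormalized:

def correct_probabilities(preds, prior_dist, model_dist):
    """Reweight per-class probabilities by prior/model class frequency and renormalize."""
    corrected = [p * prior / model for p, prior, model in zip(preds, prior_dist, model_dist)]
    total = sum(corrected)
    return [c / total for c in corrected]

# Example: a model trained on a 50/50 balanced sample of data that is really 90/10:
# correct_probabilities([0.5, 0.5], [0.9, 0.1], [0.5, 0.5]) -> [0.9, 0.1]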
@@ -41,9 +41,9 @@ protected void readModelData() throws IOException {
_model._imageDataSet = new ImageDataSet(_model._width, _model._height, _model._channels);

_model._opts = new RuntimeOptions();
_model._opts.setSeed(0); // ignored
_model._opts.setUseGPU(false); // don't use a GPU for inference
_model._opts.setDeviceID(0); // ignored
_model._opts.setSeed(0); // ignored - not needed during scoring
_model._opts.setUseGPU((boolean)readkv("gpu"));
_model._opts.setDeviceID((int[])readkv("device_id"));

_model._backendParams = new BackendParams();
_model._backendParams.set("mini_batch_size", 1);
71 changes: 55 additions & 16 deletions h2o-py/h2o/estimators/deepwater.py
@@ -26,18 +26,19 @@ class H2ODeepWaterEstimator(H2OEstimator):
def __init__(self, **kwargs):
super(H2ODeepWaterEstimator, self).__init__()
self._parms = {}
names_list = {"model_id", "checkpoint", "training_frame", "validation_frame", "nfolds",
"keep_cross_validation_predictions", "keep_cross_validation_fold_assignment", "fold_assignment",
"fold_column", "response_column", "ignored_columns", "score_each_iteration",
"categorical_encoding", "overwrite_with_best_model", "epochs", "train_samples_per_iteration",
"target_ratio_comm_to_comp", "seed", "standardize", "learning_rate", "learning_rate_annealing",
"momentum_start", "momentum_ramp", "momentum_stable", "distribution", "score_interval",
"score_training_samples", "score_validation_samples", "score_duty_cycle", "stopping_rounds",
"stopping_metric", "stopping_tolerance", "max_runtime_secs", "ignore_const_cols",
"shuffle_training_data", "mini_batch_size", "clip_gradient", "network", "backend", "image_shape",
"channels", "gpu", "device_id", "network_definition_file", "network_parameters_file",
"mean_image_file", "export_native_parameters_prefix", "activation", "hidden",
"input_dropout_ratio", "hidden_dropout_ratios", "problem_type"}
names_list = {"model_id", "checkpoint", "training_frame", "validation_frame", "nfolds", "balance_classes",
"max_after_balance_size", "class_sampling_factors", "keep_cross_validation_predictions",
"keep_cross_validation_fold_assignment", "fold_assignment", "fold_column", "response_column",
"ignored_columns", "score_each_iteration", "categorical_encoding", "overwrite_with_best_model",
"epochs", "train_samples_per_iteration", "target_ratio_comm_to_comp", "seed", "standardize",
"learning_rate", "learning_rate_annealing", "momentum_start", "momentum_ramp", "momentum_stable",
"distribution", "score_interval", "score_training_samples", "score_validation_samples",
"score_duty_cycle", "stopping_rounds", "stopping_metric", "stopping_tolerance",
"max_runtime_secs", "ignore_const_cols", "shuffle_training_data", "mini_batch_size",
"clip_gradient", "network", "backend", "image_shape", "channels", "gpu", "device_id",
"network_definition_file", "network_parameters_file", "mean_image_file",
"export_native_parameters_prefix", "activation", "hidden", "input_dropout_ratio",
"hidden_dropout_ratios", "problem_type"}
if "Lambda" in kwargs: kwargs["lambda_"] = kwargs.pop("Lambda")
for pname, pvalue in kwargs.items():
if pname == 'model_id':
@@ -93,6 +94,47 @@ def nfolds(self, nfolds):
self._parms["nfolds"] = nfolds


@property
def balance_classes(self):
"""
bool: Balance training data class counts via over/under-sampling (for imbalanced data). (Default: False)
"""
return self._parms.get("balance_classes")

@balance_classes.setter
def balance_classes(self, balance_classes):
assert_is_type(balance_classes, None, bool)
self._parms["balance_classes"] = balance_classes


@property
def max_after_balance_size(self):
"""
float: Maximum relative size of the training data after balancing class counts (can be less than 1.0). Requires
balance_classes. (Default: 5.0)
"""
return self._parms.get("max_after_balance_size")

@max_after_balance_size.setter
def max_after_balance_size(self, max_after_balance_size):
assert_is_type(max_after_balance_size, None, float)
self._parms["max_after_balance_size"] = max_after_balance_size


@property
def class_sampling_factors(self):
"""
List[float]: Desired over/under-sampling ratios per class (in lexicographic order). If not specified, sampling
factors will be automatically computed to obtain class balance during training. Requires balance_classes.
"""
return self._parms.get("class_sampling_factors")

@class_sampling_factors.setter
def class_sampling_factors(self, class_sampling_factors):
assert_is_type(class_sampling_factors, None, [float])
self._parms["class_sampling_factors"] = class_sampling_factors


@property
def keep_cross_validation_predictions(self):
"""bool: Whether to keep the predictions of the cross-validation models. (Default: False)"""
Expand Down Expand Up @@ -450,10 +492,7 @@ def ignore_const_cols(self, ignore_const_cols):

@property
def shuffle_training_data(self):
"""
bool: Enable shuffling of training data (recommended if training data is replicated and
train_samples_per_iteration is close to #nodes x #rows, or if using balance_classes). (Default: True)
"""
"""bool: Enable global shuffling of training data. (Default: True)"""
return self._parms.get("shuffle_training_data")

@shuffle_training_data.setter
