Skip to content

Commit

Permalink
Refactor K-Means output: rename rows -> size.
Browse files Browse the repository at this point in the history
  • Loading branch information
aboyoun committed Dec 31, 2014
1 parent 78daa77 commit 2b18b50
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 36 deletions.
24 changes: 12 additions & 12 deletions h2o-algos/src/main/java/hex/kmeans/KMeans.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ private class KMeansDriver extends H2OCountedCompleter<KMeansDriver> {
// per iteration ('cause we only tracked the 1 worst row)
boolean badrow=false;
for( int clu=0; clu<_parms._k; clu++ ) {
if (task._rows[clu] == 0) {
if (task._size[clu] == 0) {
// If we see 2 or more bad rows, just re-run Lloyds to get the
// next-worst row. We don't count this as an iteration, because
// we're not really adjusting the centers, we're trying to get
Expand All @@ -178,20 +178,20 @@ private class KMeansDriver extends H2OCountedCompleter<KMeansDriver> {
long row = task._worst_row;
Log.warn("KMeans: Re-initializing cluster " + clu + " to row " + row);
data(centers[clu] = task._cMeans[clu], vecs, row, means, mults);
task._rows[clu] = 1;
task._size[clu] = 1;
badrow = true;
}
}

// Fill in the model; destandardized centers
model._output._names = _train.names();
model._output._centers = destandardize(task._cMeans, _ncats, means, mults);
model._output._rows = task._rows;
model._output._size = task._size;
model._output._withinmse = task._cSqr;
double ssq = 0; // sum squared error
for( int i=0; i<_parms._k; i++ ) {
ssq += model._output._withinmse[i]; // sum squared error all clusters
model._output._withinmse[i] /= task._rows[i]; // mse within-cluster
model._output._withinmse[i] /= task._size[i]; // mse within-cluster
}
model._output._avgwithinss = ssq/_train.numRows(); // mse total

Expand All @@ -218,7 +218,7 @@ private class KMeansDriver extends H2OCountedCompleter<KMeansDriver> {
StringBuilder sb = new StringBuilder();
sb.append("KMeans: iter: ").append(model._output._iters).append(", MSE=").append(model._output._avgwithinss);
for( int i=0; i<_parms._k; i++ )
sb.append(", ").append(task._cSqr[i]).append("/").append(task._rows[i]);
sb.append(", ").append(task._cSqr[i]).append("/").append(task._size[i]);
Log.info(sb);
}

Expand Down Expand Up @@ -333,7 +333,7 @@ private static class Lloyds extends MRTask<Lloyds> {
double[][] _cMeans; // Means for each cluster
long[/*k*/][/*ncats*/][] _cats; // Histogram of cat levels
double[] _cSqr; // Sum of squares for each cluster
long[] _rows; // Rows per cluster
long[] _size; // Number of rows in each cluster
long _worst_row; // Row with max err
double _worst_err; // Max-err-row's max-err

Expand All @@ -350,7 +350,7 @@ private static class Lloyds extends MRTask<Lloyds> {
assert _centers[0].length==N;
_cMeans = new double[_k][N];
_cSqr = new double[_k];
_rows = new long[_k];
_size = new long[_k];
// Space for cat histograms
_cats = new long[_k][_ncats][];
for( int clu=0; clu< _k; clu++ )
Expand All @@ -373,29 +373,29 @@ private static class Lloyds extends MRTask<Lloyds> {
_cats[clu][col][(int)values[col]]++; // Histogram the cats
for( int col = _ncats; col < N; col++ )
_cMeans[clu][col] += values[col];
_rows[clu]++;
_size[clu]++;
// Track worst row
if( cd._dist > _worst_err) { _worst_err = cd._dist; _worst_row = cs[0].start()+row; }
}
// Scale back down to local mean
for( int clu = 0; clu < _k; clu++ )
if( _rows[clu] != 0 ) ArrayUtils.div(_cMeans[clu],_rows[clu]);
if( _size[clu] != 0 ) ArrayUtils.div(_cMeans[clu], _size[clu]);
_centers = null;
_means = _mults = null;
}

@Override public void reduce(Lloyds mr) {
for( int clu = 0; clu < _k; clu++ ) {
long ra = _rows[clu];
long rb = mr._rows[clu];
long ra = _size[clu];
long rb = mr._size[clu];
double[] ma = _cMeans[clu];
double[] mb = mr._cMeans[clu];
for( int c = 0; c < ma.length; c++ ) // Recursive mean
if( ra+rb > 0 ) ma[c] = (ma[c] * ra + mb[c] * rb) / (ra + rb);
}
ArrayUtils.add(_cats, mr._cats);
ArrayUtils.add(_cSqr, mr._cSqr);
ArrayUtils.add(_rows, mr._rows);
ArrayUtils.add(_size, mr._size);
// track global worst-row
if( _worst_err < mr._worst_err) { _worst_err = mr._worst_err; _worst_row = mr._worst_row; }
}
Expand Down
4 changes: 2 additions & 2 deletions h2o-algos/src/main/java/hex/kmeans/KMeansModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public static class KMeansOutput extends Model.Output {
// is used during the building process, the *builders* cluster centers are standardized).
public double[/*k*/][/*features*/] _centers;

// Rows per cluster
public long[/*k*/] _rows;
// Cluster size. Defined as the number of rows in each cluster.
public long[/*k*/] _size;

// Sum squared distance between each point and its cluster center, divided by total observations in cluster.
public double[/*k*/] _withinmse; // Within-cluster MSE, variance
Expand Down
4 changes: 2 additions & 2 deletions h2o-algos/src/main/java/hex/schemas/KMeansModelV2.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ public static final class KMeansModelOutputV2 extends ModelOutputSchema<KMeansMo
@API(help="Cluster Centers[k][features]")
public double[/*k*/][/*features*/] centers;

@API(help="Rows[k]")
public long[/*k*/] rows;
@API(help="Cluster Size[k]")
public long[/*k*/] size;

@API(help="Within cluster Mean Square Error per cluster")
public double[/*k*/] withinmse; // Within-cluster MSE, variance
Expand Down
4 changes: 2 additions & 2 deletions h2o-algos/src/test/java/hex/kmeans/KMeansRandomTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ public void run() {
Frame score = null;
try {
for (int j = 0; j < parms._k; j++)
Assert.assertTrue(m._output._rows[j] != 0);
Assert.assertTrue(m._output._size[j] != 0);

Assert.assertTrue(m._output._iters <= max_iter);
for (double d : m._output._withinmse) Assert.assertFalse(Double.isNaN(d));
Assert.assertFalse(Double.isNaN(m._output._avgwithinss));
for (long o : m._output._rows) Assert.assertTrue(o > 0); //have at least one point per centroid
for (long o : m._output._size) Assert.assertTrue(o > 0); //have at least one point per centroid
for (double[] dc : m._output._centers) for (double d : dc) Assert.assertFalse(Double.isNaN(d));

// make prediction (cluster assignment)
Expand Down
2 changes: 1 addition & 1 deletion h2o-algos/src/test/java/hex/kmeans/KMeansTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ private static KMeansModel doSeed( KMeansModel.KMeansParameters parms, long seed
if (job != null) job.remove();
}
for( int i=0; i<parms._k; i++ )
Assert.assertTrue( "Seed: "+seed, kmm._output._rows[i] != 0 );
Assert.assertTrue( "Seed: "+seed, kmm._output._size[i] != 0 );
return kmm;
}

Expand Down
18 changes: 9 additions & 9 deletions h2o-samples/src/main/java/droplets/KMeansDroplet.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ public static void main(String[] args) throws Exception {
task.doAll(frame);

for( int c = 0; c < centers.length; c++ ) {
if( task._counts[c] > 0 ) {
if( task._size[c] > 0 ) {
for( int v = 0; v < frame.vecs().length; v++ ) {
double value = task._sums[c][v] / task._counts[c];
double value = task._sums[c][v] / task._size[c];
centers[c][v] = value;
}
}
Expand Down Expand Up @@ -93,13 +93,13 @@ public static void main(String[] args) throws Exception {
public static class KMeans extends MRTask<KMeans> {
double[][] _centers; // IN: Centroids/cluster centers

double[][] _sums; // OUT: Sum of features in each cluster
int[] _counts; // OUT: Count of rows in cluster
double _error; // OUT: Total sqr distance
double[][] _sums; // OUT: Sum of features in each cluster
int[] _size; // OUT: Row counts in each cluster
double _error; // OUT: Total sqr distance

@Override public void map(Chunk[] chunks) {
_sums = new double[_centers.length][chunks.length];
_counts = new int[_centers.length];
_size = new int[_centers.length];

// Find nearest cluster for each row
for( int row = 0; row < chunks[0]._len; row++ ) {
Expand All @@ -121,16 +121,16 @@ public static class KMeans extends MRTask<KMeans> {
// Add values and increment counter for chosen cluster
for( int column = 0; column < chunks.length; column++ )
_sums[nearest][column] += chunks[column].at0(row);
_counts[nearest]++;
_size[nearest]++;
}
_centers = null;
}

@Override public void reduce(KMeans task) {
for( int cluster = 0; cluster < _counts.length; cluster++ ) {
for( int cluster = 0; cluster < _size.length; cluster++ ) {
for( int column = 0; column < _sums[0].length; column++ )
_sums[cluster][column] += task._sums[cluster][column];
_counts[cluster] += task._counts[cluster];
_size[cluster] += task._size[cluster];
}
_error += task._error;
}
Expand Down
14 changes: 7 additions & 7 deletions py2/h2o_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, kmeansResult, parameters, numRows, numCols, labels, noPrint=F
if 'python_elapsed' in kmeansResult:
self.python_elapsed = kmeansResult['python_elapsed']

rows = self.rows # [78, 5, 41, 76]
size = self.size # [78, 5, 41, 76]
model_category = self.model_category # Clustering
iters = self.iters # 11.0
domains = self.domains
Expand All @@ -49,12 +49,12 @@ def __init__(self, kmeansResult, parameters, numRows, numCols, labels, noPrint=F
avgss = self.avgss

if numRows:
assert numRows==sum(rows)
assert numRows==sum(size)

if 'k' in parameters:
k = parameters['k']
assert len(centers) == k
assert len(rows) == k
assert len(size) == k

if numCols:
assert len(names) == numCols, \
Expand Down Expand Up @@ -83,18 +83,18 @@ def __init__(self, kmeansResult, parameters, numRows, numCols, labels, noPrint=F
# create a tuple for each cluster result, then sort the tuples (by the sum of each cluster's centers) for easy comparison
# maybe should sort by centers?
# put a cluster index in there too, (leftmost) so we don't lose track
tuples = zip(range(len(centers)), centers, rows, withinmse)
tuples = zip(range(len(centers)), centers, size, withinmse)
# can we sort on the sum of the centers?
self.tuplesSorted = sorted(tuples, key=lambda tup: sum(tup[1]))

print "iters:", iters
# undo for printing what the caller will see
ids, centers, rows, withinmse = zip(*self.tuplesSorted)
ids, centers, size, withinmse = zip(*self.tuplesSorted)
for i,c in enumerate(centers):
print "cluster id %s (2 places):" % ids[i], h2o_util.twoDecimals(c)
print "rows_per_cluster[%s]: " % i, rows[i]
print "rows_per_cluster[%s]: " % i, size[i]
print "withinmse[%s]: " % i, withinmse[i]
print "rows[%s]:" % i, rows[i]
print "size[%s]:" % i, size[i]

print "KMeansObj created for:", "???"# vars(self)

Expand Down
2 changes: 1 addition & 1 deletion py2/testdir_single_jvm/kmeans.jsonschema
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@
1.736904633329732
]
],
"rows": [
"size": [
571753,
9259
],
Expand Down

0 comments on commit 2b18b50

Please sign in to comment.