Skip to content

Commit

Permalink
fix cost updating bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Ray Mattingly committed Jan 2, 2025
1 parent 421f571 commit 757173f
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ boolean shouldSkipSloppyServerEvaluation() {
return isConditionalBalancingEnabled();
}

/**
 * Reports whether conditional balancing is enabled, which is the case exactly when
 * {@code conditionalClasses} is non-empty.
 * @return {@code true} if {@code conditionalClasses} contains at least one entry, {@code false}
 *         otherwise
 */
boolean isConditionalBalancingEnabled() {
return !conditionalClasses.isEmpty();
}

void clearConditionalWeightCaches() {
conditionals.stream().map(RegionPlanConditional::getCandidateGenerator)
.flatMap(Optional::stream).forEach(RegionPlanConditionalCandidateGenerator::clearWeightCache);
Expand Down Expand Up @@ -180,10 +184,6 @@ private RegionPlanConditional createConditional(Class<? extends RegionPlanCondit
return null;
}

/**
 * Reports whether conditional balancing is enabled, which is the case exactly when
 * {@code conditionalClasses} is non-empty.
 * @return {@code true} if {@code conditionalClasses} contains at least one entry, {@code false}
 *         otherwise
 */
private boolean isConditionalBalancingEnabled() {
return !conditionalClasses.isEmpty();
}

@Override
public void setConf(Configuration conf) {
this.conf = conf;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
@InterfaceAudience.Private
abstract class CandidateGenerator {

static double MAX_WEIGHT = 100_000;
static double MAX_WEIGHT = 1.0;

abstract BalanceAction generate(BalancerClusterState cluster);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
*/
package org.apache.hadoop.hbase.master.balancer;

import com.google.common.base.Suppliers;
import com.google.errorprone.annotations.RestrictedApi;
import java.lang.reflect.Constructor;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.ClusterMetrics;
Expand Down Expand Up @@ -159,6 +162,13 @@ public class StochasticLoadBalancer extends BaseLoadBalancer {

protected Map<Class<? extends CandidateGenerator>, CandidateGenerator> candidateGenerators;
private Map<Class<? extends CandidateGenerator>, Double> weightsOfGenerators;
// Supplies the keys of candidateGenerators as a shuffled list. The computation copies the key set
// into a new ArrayList and shuffles that copy, so candidateGenerators itself is not reordered.
// The supplier is wrapped in Suppliers.memoizeWithExpiration(..., 5, TimeUnit.SECONDS), so the
// same shuffled order is handed out until the cached value expires and is then recomputed.
// NOTE(review): because the list is a snapshot of the key set, it presumably will not reflect
// generators added to or removed from candidateGenerators until the cached value expires --
// confirm that callers looking classes up in candidateGenerators tolerate a stale entry.
private final Supplier<List<Class<? extends CandidateGenerator>>> shuffledGeneratorClasses =
Suppliers.memoizeWithExpiration(() -> {
List<Class<? extends CandidateGenerator>> shuffled =
new ArrayList<>(candidateGenerators.keySet());
Collections.shuffle(shuffled);
return shuffled;
}, 5, TimeUnit.SECONDS);

private final BalancerConditionals balancerConditionals = BalancerConditionals.INSTANCE;

Expand Down Expand Up @@ -435,47 +445,50 @@ Pair<CandidateGenerator, BalanceAction> nextAction(BalancerClusterState cluster)
* all cost functions that benefit from it.
*/
protected CandidateGenerator getRandomGenerator(BalancerClusterState cluster) {
double sum = 0;
for (Class<? extends CandidateGenerator> clazz : candidateGenerators.keySet()) {
sum += weightsOfGenerators.getOrDefault(clazz, 0.0);
}
if (sum == 0) {
return candidateGenerators.values().stream().findAny().orElseThrow();
// Prefer conditional generators if they have moves to make
if (balancerConditionals.isConditionalBalancingEnabled()) {
for (RegionPlanConditional conditional : balancerConditionals.getConditionals()) {
Optional<RegionPlanConditionalCandidateGenerator> generator =
conditional.getCandidateGenerator();
if (generator.isPresent() && generator.get().getWeight(cluster) > 0) {
return generator.get();
}
}
}

for (Class<? extends CandidateGenerator> clazz : candidateGenerators.keySet()) {
weightsOfGenerators.put(clazz,
Math.min(CandidateGenerator.MAX_WEIGHT, weightsOfGenerators.get(clazz) / sum));
List<Class<? extends CandidateGenerator>> generatorClasses = shuffledGeneratorClasses.get();
List<Double> partialSums = new ArrayList<>(generatorClasses.size());
double sum = 0.0;
for (Class<? extends CandidateGenerator> clazz : generatorClasses) {
double weight = weightsOfGenerators.getOrDefault(clazz, 0.0);
sum += weight;
partialSums.add(sum);
}

for (RegionPlanConditional conditional : balancerConditionals.getConditionals()) {
Optional<RegionPlanConditionalCandidateGenerator> generator =
conditional.getCandidateGenerator();
if (generator.isPresent() && generator.get().getWeight(cluster) > 0) {
return generator.get();
// If the sum of all weights is zero, fall back to a default (e.g., first in the list)
if (sum == 0.0) {
// If no generators at all, fail fast or throw
if (generatorClasses.isEmpty()) {
throw new IllegalStateException("No candidate generators available");
}
return candidateGenerators.get(generatorClasses.get(0));
}

// Normalize partial sums so that the last one should be exactly 1.0
for (int i = 0; i < partialSums.size(); i++) {
partialSums.set(i, partialSums.get(i) / sum);
}

// Generate a random number and pick the first generator whose partial sum is >= rand
double rand = ThreadLocalRandom.current().nextDouble();
CandidateGenerator greatestWeightGenerator = null;
double greatestWeight = 0;
for (Map.Entry<Class<? extends CandidateGenerator>,
CandidateGenerator> entry : candidateGenerators.entrySet()) {
Class<? extends CandidateGenerator> clazz = entry.getKey();
double generatorWeight = weightsOfGenerators.get(clazz);
if (generatorWeight > greatestWeight) {
greatestWeight = generatorWeight;
greatestWeightGenerator = entry.getValue();
}
if (rand <= generatorWeight) {
return entry.getValue();
for (int i = 0; i < partialSums.size(); i++) {
if (rand <= partialSums.get(i)) {
return candidateGenerators.get(generatorClasses.get(i));
}
}

if (greatestWeightGenerator != null) {
return greatestWeightGenerator;
}
return candidateGenerators.values().stream().findAny().orElseThrow();
// Fallback: if for some reason we didn't return above, return the last generator
return candidateGenerators.get(generatorClasses.get(generatorClasses.size() - 1));
}

@RestrictedApi(explanation = "Should only be called in tests", link = "",
Expand Down Expand Up @@ -565,7 +578,7 @@ protected List<RegionPlan> balanceTable(TableName tableName,
// Perform a stochastic walk to see if we can get a good fit.
long step;

boolean improvedConditionals = false;
boolean planImprovedConditionals = false;
Map<Class<? extends CandidateGenerator>, Long> generatorToStepCount = new HashMap<>();
Map<Class<? extends CandidateGenerator>, Long> generatorToApprovedActionCount = new HashMap<>();
for (step = 0; step < computedMaxSteps; step++) {
Expand All @@ -576,41 +589,50 @@ protected List<RegionPlan> balanceTable(TableName tableName,
if (action.getType() == BalanceAction.Type.NULL) {
continue;
}

generatorToStepCount.merge(generator.getClass(), action.getStepCount(), Long::sum);
step += action.getStepCount() - 1;

// Always accept a conditional generator output. Sometimes conditional generators
// may need to make controversial moves in order to break what would otherwise
// be a deadlocked situation.
// Otherwise, for normal moves, evaluate the action.
int conditionalViolationsChange;
boolean isViolating = false;
if (RegionPlanConditionalCandidateGenerator.class.isAssignableFrom(generator.getClass())) {
conditionalViolationsChange = -1;
} else {
conditionalViolationsChange = balancerConditionals.getViolationCountChange(cluster, action);
isViolating = balancerConditionals.isViolating(cluster, action);
long additionalSteps = action.getStepCount() - 1;
if (additionalSteps > 0) {
step += additionalSteps;
}

int conditionalViolationsChange = 0;
boolean isViolatingConditionals = false;
boolean moveImprovedConditionals = false;
// Only check conditionals if they are enabled
if (balancerConditionals.isConditionalBalancingEnabled()) {
// Always accept a conditional generator output. Sometimes conditional generators
// may need to make controversial moves in order to break what would otherwise
// be a deadlocked situation.
// Otherwise, for normal moves, evaluate the action.
if (RegionPlanConditionalCandidateGenerator.class.isAssignableFrom(generator.getClass())) {
conditionalViolationsChange = -1;
} else {
conditionalViolationsChange =
balancerConditionals.getViolationCountChange(cluster, action);
isViolatingConditionals = balancerConditionals.isViolating(cluster, action);
}
moveImprovedConditionals = conditionalViolationsChange < 0;
if (moveImprovedConditionals) {
planImprovedConditionals = true;
}
}

// Change state and evaluate costs
cluster.doAction(action);
updateCostsAndWeightsWithAction(cluster, action);
newCost = computeCost(cluster, currentCost);

boolean conditionalsImproved = conditionalViolationsChange < 0;
if (conditionalsImproved) {
improvedConditionals = true;
}
boolean conditionalsSimilarCostsImproved =
(newCost < currentCost && conditionalViolationsChange == 0 && !isViolating);
(newCost < currentCost && conditionalViolationsChange == 0 && !isViolatingConditionals);
// Our first priority is to reduce conditional violations, regardless of cost change.
// Our second priority is to reduce balancer cost without changing the violation count.
if (conditionalsImproved || conditionalsSimilarCostsImproved) {
if (moveImprovedConditionals || conditionalsSimilarCostsImproved) {
currentCost = newCost;
generatorToApprovedActionCount.merge(generator.getClass(), action.getStepCount(),
Long::sum);
balancerConditionals.loadClusterState(cluster);
updateCostsAndWeightsWithAction(cluster, action);

// save for JMX
curOverallCost = currentCost;
Expand Down Expand Up @@ -641,7 +663,7 @@ protected List<RegionPlan> balanceTable(TableName tableName,

metricsBalancer.balanceCluster(endTime - startTime);

if (improvedConditionals || initCost > currentCost) {
if (planImprovedConditionals || initCost > currentCost) {
updateStochasticCosts(tableName, curOverallCost, curFunctionCosts);
List<RegionPlan> plans = createRegionPlans(cluster);
LOG.info(
Expand Down

0 comments on commit 757173f

Please sign in to comment.