Skip to content

Commit d057a7e

Browse files
authored
[xds] Implement A114: WRR support for custom backend metrics (#12645)
### Description This PR implements [gRFC A114: WRR Support for Custom Backend Metrics](grpc/proposal#536). It updates the `weighted_round_robin` policy to allow users to configure which backend metrics drive the load balancing weights. ### Key Changes * **Configuration**: Supports the new `metric_names_for_computing_utilization` field in `WeightedRoundRobinLbConfig`. * **Weight Calculation**: Implements logic to resolve custom metrics (including map lookups like `named_metrics.foo`) when `application_utilization` is absent. * **Refactor**: Centralizes the complex metric lookup and validation logic (checking for NaN, <= 0, etc.) into a new internal utility `MetricReportUtils`. * **Testing**: Verifies correct precedence: `application_utilization` > `custom_metrics` (max valid value) > `cpu_utilization`.
1 parent 842636f commit d057a7e

8 files changed

Lines changed: 614 additions & 74 deletions

xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class LoadBalancerConfigFactory {
9191
static final String SHUFFLE_ADDRESS_LIST_FIELD_NAME = "shuffleAddressList";
9292

9393
static final String ERROR_UTILIZATION_PENALTY = "errorUtilizationPenalty";
94+
static final String METRIC_NAMES_FOR_COMPUTING_UTILIZATION = "metricNamesForComputingUtilization";
9495

9596
/**
9697
* Factory method for creating a new {link LoadBalancerConfigConverter} for a given xDS {@link
@@ -134,11 +135,9 @@ class LoadBalancerConfigFactory {
134135
* the given config values.
135136
*/
136137
private static ImmutableMap<String, ?> buildWrrConfig(String blackoutPeriod,
137-
String weightExpirationPeriod,
138-
String oobReportingPeriod,
139-
Boolean enableOobLoadReport,
140-
String weightUpdatePeriod,
141-
Float errorUtilizationPenalty) {
138+
String weightExpirationPeriod, String oobReportingPeriod, Boolean enableOobLoadReport,
139+
String weightUpdatePeriod, Float errorUtilizationPenalty,
140+
ImmutableList<String> metricNamesForComputingUtilization) {
142141
ImmutableMap.Builder<String, Object> configBuilder = ImmutableMap.builder();
143142
if (blackoutPeriod != null) {
144143
configBuilder.put(BLACK_OUT_PERIOD, blackoutPeriod);
@@ -158,6 +157,10 @@ class LoadBalancerConfigFactory {
158157
if (errorUtilizationPenalty != null) {
159158
configBuilder.put(ERROR_UTILIZATION_PENALTY, errorUtilizationPenalty);
160159
}
160+
if (metricNamesForComputingUtilization != null
161+
&& !metricNamesForComputingUtilization.isEmpty()) {
162+
configBuilder.put(METRIC_NAMES_FOR_COMPUTING_UTILIZATION, metricNamesForComputingUtilization);
163+
}
161164
return ImmutableMap.of(WeightedRoundRobinLoadBalancerProvider.SCHEME,
162165
configBuilder.buildOrThrow());
163166
}
@@ -284,7 +287,7 @@ static class LoadBalancingPolicyConverter {
284287
}
285288

286289
private static ImmutableMap<String, ?> convertWeightedRoundRobinConfig(
287-
ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException {
290+
ClientSideWeightedRoundRobin wrr) throws ResourceInvalidException {
288291
try {
289292
return buildWrrConfig(
290293
wrr.hasBlackoutPeriod() ? Durations.toString(wrr.getBlackoutPeriod()) : null,
@@ -293,7 +296,8 @@ static class LoadBalancingPolicyConverter {
293296
wrr.hasOobReportingPeriod() ? Durations.toString(wrr.getOobReportingPeriod()) : null,
294297
wrr.hasEnableOobLoadReport() ? wrr.getEnableOobLoadReport().getValue() : null,
295298
wrr.hasWeightUpdatePeriod() ? Durations.toString(wrr.getWeightUpdatePeriod()) : null,
296-
wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null);
299+
wrr.hasErrorUtilizationPenalty() ? wrr.getErrorUtilizationPenalty().getValue() : null,
300+
ImmutableList.copyOf(wrr.getMetricNamesForComputingUtilizationList()));
297301
} catch (IllegalArgumentException ex) {
298302
throw new ResourceInvalidException("Invalid duration in weighted round robin config: "
299303
+ ex.getMessage());

xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import io.grpc.services.MetricReport;
4141
import io.grpc.util.ForwardingSubchannel;
4242
import io.grpc.util.MultiChildLoadBalancer;
43+
import io.grpc.xds.internal.MetricReportUtils;
4344
import io.grpc.xds.orca.OrcaOobUtil;
4445
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
4546
import io.grpc.xds.orca.OrcaPerRequestUtil;
@@ -49,6 +50,7 @@
4950
import java.util.HashSet;
5051
import java.util.List;
5152
import java.util.Objects;
53+
import java.util.OptionalDouble;
5254
import java.util.Random;
5355
import java.util.Set;
5456
import java.util.concurrent.ScheduledExecutorService;
@@ -189,7 +191,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
189191
this.backendService = "";
190192
}
191193
config =
192-
(WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
194+
(WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
193195

194196
if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
195197
weightUpdateTimer.cancel();
@@ -236,7 +238,8 @@ protected void updateOverallBalancingState() {
236238

237239
private SubchannelPicker createReadyPicker(Collection<ChildLbState> activeList) {
238240
WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
239-
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence);
241+
config.enableOobLoadReport, config.errorUtilizationPenalty, sequence,
242+
config.metricNamesForComputingUtilization);
240243
updateWeight(picker);
241244
return picker;
242245
}
@@ -325,12 +328,16 @@ public void addSubchannel(WrrSubchannel wrrSubchannel) {
325328
subchannels.add(wrrSubchannel);
326329
}
327330

328-
public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) {
331+
public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty,
332+
ImmutableList<String> metricNamesForComputingUtilization) {
329333
if (orcaReportListener != null
330-
&& orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) {
334+
&& orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty
335+
&& orcaReportListener.metricNamesForComputingUtilization
336+
.equals(metricNamesForComputingUtilization)) {
331337
return orcaReportListener;
332338
}
333-
orcaReportListener = new OrcaReportListener(errorUtilizationPenalty);
339+
orcaReportListener =
340+
new OrcaReportListener(errorUtilizationPenalty, metricNamesForComputingUtilization);
334341
return orcaReportListener;
335342
}
336343

@@ -355,18 +362,19 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne
355362

356363
final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
357364
private final float errorUtilizationPenalty;
365+
private final ImmutableList<String> metricNamesForComputingUtilization;
358366

359-
OrcaReportListener(float errorUtilizationPenalty) {
367+
OrcaReportListener(float errorUtilizationPenalty,
368+
ImmutableList<String> metricNamesForComputingUtilization) {
360369
this.errorUtilizationPenalty = errorUtilizationPenalty;
370+
this.metricNamesForComputingUtilization = metricNamesForComputingUtilization;
361371
}
362372

363373
@Override
364374
public void onLoadReport(MetricReport report) {
375+
double utilization = getUtilization(report, metricNamesForComputingUtilization);
376+
365377
double newWeight = 0;
366-
// Prefer application utilization and fallback to CPU utilization if unset.
367-
double utilization =
368-
report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
369-
: report.getCpuUtilization();
370378
if (utilization > 0 && report.getQps() > 0) {
371379
double penalty = 0;
372380
if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
@@ -383,6 +391,40 @@ public void onLoadReport(MetricReport report) {
383391
lastUpdated = ticker.nanoTime();
384392
weight = newWeight;
385393
}
394+
395+
/**
396+
* Returns the utilization value computed from the specified metric names. If the custom
397+
* metrics are present and valid, the maximum of the custom metrics is returned. Otherwise,
398+
* if application utilization is > 0, it is returned. If neither are present, the CPU
399+
* utilization is returned.
400+
*/
401+
private double getUtilization(MetricReport report, ImmutableList<String> metricNames) {
402+
OptionalDouble customUtil = getCustomMetricUtilization(report, metricNames);
403+
if (customUtil.isPresent()) {
404+
return customUtil.getAsDouble();
405+
}
406+
double appUtil = report.getApplicationUtilization();
407+
if (appUtil > 0) {
408+
return appUtil;
409+
}
410+
return report.getCpuUtilization();
411+
}
412+
413+
/**
414+
* Returns the maximum utilization value among the specified metric names.
415+
* Returns OptionalDouble.empty() if NONE of the specified metrics are present in the report,
416+
* or if all present metrics are NaN.
417+
* Returns OptionalDouble.of(maxUtil) if at least one non-NaN metric is present.
418+
*/
419+
private OptionalDouble getCustomMetricUtilization(MetricReport report,
420+
ImmutableList<String> metricNames) {
421+
return metricNames.stream()
422+
.map(name -> MetricReportUtils.getMetric(report, name))
423+
.filter(OptionalDouble::isPresent)
424+
.mapToDouble(OptionalDouble::getAsDouble)
425+
.filter(d -> !Double.isNaN(d) && d > 0)
426+
.max();
427+
}
386428
}
387429
}
388430

@@ -403,10 +445,10 @@ private void createAndApplyOrcaListeners() {
403445
for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
404446
if (config.enableOobLoadReport) {
405447
OrcaOobUtil.setListener(weightedSubchannel,
406-
wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty),
448+
wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty,
449+
config.metricNamesForComputingUtilization),
407450
OrcaOobUtil.OrcaReportingConfig.newBuilder()
408-
.setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
409-
.build());
451+
.setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS).build());
410452
} else {
411453
OrcaOobUtil.setListener(weightedSubchannel, null, null);
412454
}
@@ -473,7 +515,8 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker {
473515
private volatile StaticStrideScheduler scheduler;
474516

475517
WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
476-
float errorUtilizationPenalty, AtomicInteger sequence) {
518+
float errorUtilizationPenalty, AtomicInteger sequence,
519+
ImmutableList<String> metricNamesForComputingUtilization) {
477520
checkNotNull(children, "children");
478521
Preconditions.checkArgument(!children.isEmpty(), "empty child list");
479522
this.children = children;
@@ -482,7 +525,8 @@ static final class WeightedRoundRobinPicker extends SubchannelPicker {
482525
for (ChildLbState child : children) {
483526
WeightedChildLbState wChild = (WeightedChildLbState) child;
484527
pickers.add(wChild.getCurrentPicker());
485-
reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
528+
reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty,
529+
metricNamesForComputingUtilization));
486530
}
487531
this.pickers = pickers;
488532
this.reportListeners = reportListeners;
@@ -723,23 +767,23 @@ static final class WeightedRoundRobinLoadBalancerConfig {
723767
final long oobReportingPeriodNanos;
724768
final long weightUpdatePeriodNanos;
725769
final float errorUtilizationPenalty;
770+
final ImmutableList<String> metricNamesForComputingUtilization;
726771

727772
public static Builder newBuilder() {
728773
return new Builder();
729774
}
730775

731776
private WeightedRoundRobinLoadBalancerConfig(long blackoutPeriodNanos,
732-
long weightExpirationPeriodNanos,
733-
boolean enableOobLoadReport,
734-
long oobReportingPeriodNanos,
735-
long weightUpdatePeriodNanos,
736-
float errorUtilizationPenalty) {
777+
long weightExpirationPeriodNanos, boolean enableOobLoadReport, long oobReportingPeriodNanos,
778+
long weightUpdatePeriodNanos, float errorUtilizationPenalty,
779+
ImmutableList<String> metricNamesForComputingUtilization) {
737780
this.blackoutPeriodNanos = blackoutPeriodNanos;
738781
this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
739782
this.enableOobLoadReport = enableOobLoadReport;
740783
this.oobReportingPeriodNanos = oobReportingPeriodNanos;
741784
this.weightUpdatePeriodNanos = weightUpdatePeriodNanos;
742785
this.errorUtilizationPenalty = errorUtilizationPenalty;
786+
this.metricNamesForComputingUtilization = metricNamesForComputingUtilization;
743787
}
744788

745789
@Override
@@ -754,27 +798,26 @@ public boolean equals(Object o) {
754798
&& this.oobReportingPeriodNanos == that.oobReportingPeriodNanos
755799
&& this.weightUpdatePeriodNanos == that.weightUpdatePeriodNanos
756800
// Float.compare considers NaNs equal
757-
&& Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0;
801+
&& Float.compare(this.errorUtilizationPenalty, that.errorUtilizationPenalty) == 0
802+
&& Objects.equals(this.metricNamesForComputingUtilization,
803+
that.metricNamesForComputingUtilization);
758804
}
759805

760806
@Override
761807
public int hashCode() {
762-
return Objects.hash(
763-
blackoutPeriodNanos,
764-
weightExpirationPeriodNanos,
765-
enableOobLoadReport,
766-
oobReportingPeriodNanos,
767-
weightUpdatePeriodNanos,
768-
errorUtilizationPenalty);
808+
return Objects.hash(blackoutPeriodNanos, weightExpirationPeriodNanos, enableOobLoadReport,
809+
oobReportingPeriodNanos, weightUpdatePeriodNanos, errorUtilizationPenalty,
810+
metricNamesForComputingUtilization);
769811
}
770812

771813
static final class Builder {
772814
long blackoutPeriodNanos = 10_000_000_000L; // 10s
773-
long weightExpirationPeriodNanos = 180_000_000_000L; //3min
815+
long weightExpirationPeriodNanos = 180_000_000_000L; // 3min
774816
boolean enableOobLoadReport = false;
775817
long oobReportingPeriodNanos = 10_000_000_000L; // 10s
776818
long weightUpdatePeriodNanos = 1_000_000_000L; // 1s
777819
float errorUtilizationPenalty = 1.0F;
820+
ImmutableList<String> metricNamesForComputingUtilization = ImmutableList.of();
778821

779822
private Builder() {
780823

@@ -812,10 +855,17 @@ Builder setErrorUtilizationPenalty(float errorUtilizationPenalty) {
812855
return this;
813856
}
814857

858+
Builder setMetricNamesForComputingUtilization(
859+
List<String> metricNamesForComputingUtilization) {
860+
this.metricNamesForComputingUtilization =
861+
ImmutableList.copyOf(metricNamesForComputingUtilization);
862+
return this;
863+
}
864+
815865
WeightedRoundRobinLoadBalancerConfig build() {
816866
return new WeightedRoundRobinLoadBalancerConfig(blackoutPeriodNanos,
817-
weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
818-
weightUpdatePeriodNanos, errorUtilizationPenalty);
867+
weightExpirationPeriodNanos, enableOobLoadReport, oobReportingPeriodNanos,
868+
weightUpdatePeriodNanos, errorUtilizationPenalty, metricNamesForComputingUtilization);
819869
}
820870
}
821871
}

xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import io.grpc.LoadBalancerProvider;
2525
import io.grpc.NameResolver.ConfigOrError;
2626
import io.grpc.Status;
27+
import io.grpc.internal.GrpcUtil;
2728
import io.grpc.internal.JsonUtil;
2829
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
30+
import java.util.List;
2931
import java.util.Map;
3032

3133
/**
@@ -73,14 +75,16 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map<String, ?> rawConfig) {
7375
private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map<String, ?> rawConfig) {
7476
Long blackoutPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "blackoutPeriod");
7577
Long weightExpirationPeriodNanos =
76-
JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod");
78+
JsonUtil.getStringAsDuration(rawConfig, "weightExpirationPeriod");
7779
Long oobReportingPeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "oobReportingPeriod");
7880
Boolean enableOobLoadReport = JsonUtil.getBoolean(rawConfig, "enableOobLoadReport");
7981
Long weightUpdatePeriodNanos = JsonUtil.getStringAsDuration(rawConfig, "weightUpdatePeriod");
8082
Float errorUtilizationPenalty = JsonUtil.getNumberAsFloat(rawConfig, "errorUtilizationPenalty");
83+
List<String> metricNamesForComputingUtilization = JsonUtil.getListOfStrings(rawConfig,
84+
LoadBalancerConfigFactory.METRIC_NAMES_FOR_COMPUTING_UTILIZATION);
8185

8286
WeightedRoundRobinLoadBalancerConfig.Builder configBuilder =
83-
WeightedRoundRobinLoadBalancerConfig.newBuilder();
87+
WeightedRoundRobinLoadBalancerConfig.newBuilder();
8488
if (blackoutPeriodNanos != null) {
8589
configBuilder.setBlackoutPeriodNanos(blackoutPeriodNanos);
8690
}
@@ -102,6 +106,10 @@ private ConfigOrError parseLoadBalancingPolicyConfigInternal(Map<String, ?> rawC
102106
if (errorUtilizationPenalty != null) {
103107
configBuilder.setErrorUtilizationPenalty(errorUtilizationPenalty);
104108
}
109+
if (metricNamesForComputingUtilization != null
110+
&& GrpcUtil.getFlag("GRPC_EXPERIMENTAL_WRR_CUSTOM_METRICS", false)) {
111+
configBuilder.setMetricNamesForComputingUtilization(metricNamesForComputingUtilization);
112+
}
105113
return ConfigOrError.fromConfig(configBuilder.build());
106114
}
107115
}

0 commit comments

Comments
 (0)