Prevent excessive api calls

This change provides a gRPC CallRateMeteringInterceptor to help protect
the server and network against being overloaded by CLI scripting mistakes.

An interceptor instance can be configured on a gRPC service to set
individual method call rate limits on one or more of the the service's
methods. For example, the GrpcOffersService could be configured with
this interceptor to set the createoffer rate limit to 5/hour, and
the takeoffer call rate limit could be set to 20/day.  Whenever a
call rate limit is exceeded, the gRPC call is aborted and the client
recieves a "rate limit exceeded" error.

Below is a simple example showing how to set rate limits for one method
in GrpcVersionService.

    final ServerInterceptor[] interceptors() {
        return new ServerInterceptor[]{
                new CallRateMeteringInterceptor(new HashMap<>() {{
                    put("getVersion", new GrpcCallRateMeter(2, SECONDS));
                }})
        };
    }

It specifies a CLI can execute getversion 2 times / second.

This is not a throttling mechansim, there is no blocking nor locking
to slow call rates.  When call rates are exceeded, calls are
simply aborted.
This commit is contained in:
ghubstan 2020-12-17 12:33:45 -03:00
parent fa9ffa1fb2
commit 2148a4d958
No known key found for this signature in database
GPG key ID: E35592D6800A861E
2 changed files with 172 additions and 0 deletions

View file

@ -0,0 +1,107 @@
package bisq.daemon.grpc.interceptor;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.StatusRuntimeException;
import org.apache.commons.lang3.StringUtils;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import lombok.extern.slf4j.Slf4j;
import static io.grpc.Status.FAILED_PRECONDITION;
import static io.grpc.Status.PERMISSION_DENIED;
import static java.lang.String.format;
@Slf4j
public final class CallRateMeteringInterceptor implements ServerInterceptor {
// Maps the gRPC server method names to rate meters. This allows one interceptor
// instance to handle rate metering for any or all the methods in a Grpc*Service.
protected final Map<String, GrpcCallRateMeter> serviceCallRateMeters;
public CallRateMeteringInterceptor(Map<String, GrpcCallRateMeter> serviceCallRateMeters) {
this.serviceCallRateMeters = serviceCallRateMeters;
}
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall,
Metadata headers,
ServerCallHandler<ReqT, RespT> serverCallHandler) {
Optional<Map.Entry<String, GrpcCallRateMeter>> rateMeterKV = getRateMeterKV(serverCall);
rateMeterKV.ifPresentOrElse(
(kv) -> checkRateMeterAndMaybeCloseCall(kv, serverCall),
() -> handleInterceptorConfigErrorAndCloseCall(serverCall));
// We leave it to the gRPC framework to clean up if the server call was closed
// above. But we still have to invoke startCall here because the method must
// return a ServerCall.Listener<RequestT>.
return serverCallHandler.startCall(serverCall, headers);
}
private void checkRateMeterAndMaybeCloseCall(Map.Entry<String, GrpcCallRateMeter> rateMeterKV,
ServerCall<?, ?> serverCall) {
String methodName = rateMeterKV.getKey();
GrpcCallRateMeter rateMeter = rateMeterKV.getValue();
// The service method's rate meter doesn't start running until the 1st call.
if (!rateMeter.isRunning())
rateMeter.start();
rateMeter.incrementCallsCount();
if (rateMeter.isCallRateExceeded())
handlePermissionDeniedErrorAndCloseCall(methodName, rateMeter, serverCall);
else
log.info(rateMeter.getCallsCountProgress(methodName));
}
private void handleInterceptorConfigErrorAndCloseCall(ServerCall<?, ?> serverCall)
throws StatusRuntimeException {
String methodName = getRateMeterKey(serverCall);
String msg = format("%s's rate metering interceptor is incorrectly configured;"
+ " its rate meter cannot be found ",
methodName);
log.error(StringUtils.capitalize(msg) + ".");
serverCall.close(FAILED_PRECONDITION.withDescription(msg), new Metadata());
}
private void handlePermissionDeniedErrorAndCloseCall(String methodName,
GrpcCallRateMeter rateMeter,
ServerCall<?, ?> serverCall)
throws StatusRuntimeException {
String msg = getDefaultRateExceededError(methodName, rateMeter);
log.error(StringUtils.capitalize(msg) + ".");
serverCall.close(PERMISSION_DENIED.withDescription(msg), new Metadata());
}
private String getDefaultRateExceededError(String methodName,
GrpcCallRateMeter rateMeter) {
// The derived method name may not be an exact match to CLI's method name.
String timeUnitName = StringUtils.chop(rateMeter.getTimeUnit().name().toLowerCase());
return format("the maximum allowed number of %s calls (%d/%s) has been exceeded by %d calls",
methodName.toLowerCase(),
rateMeter.getAllowedCallsPerTimeUnit(),
timeUnitName,
rateMeter.getCallsCount() - rateMeter.getAllowedCallsPerTimeUnit());
}
private Optional<Map.Entry<String, GrpcCallRateMeter>> getRateMeterKV(ServerCall<?, ?> serverCall) {
String rateMeterKey = getRateMeterKey(serverCall);
return serviceCallRateMeters.entrySet().stream()
.filter((e) -> e.getKey().equals(rateMeterKey)).findFirst();
}
private String getRateMeterKey(ServerCall<?, ?> serverCall) {
// Get the rate meter map key from the full rpc service name. The key name
// is hard coded in the Grpc*Service interceptors() method.
String fullServiceName = serverCall.getMethodDescriptor().getServiceName();
return StringUtils.uncapitalize(Objects.requireNonNull(fullServiceName)
.substring("io.bisq.protobuffer.".length()));
}
}

View file

@ -0,0 +1,65 @@
package bisq.daemon.grpc.interceptor;
import bisq.common.Timer;
import bisq.common.UserThread;
import org.apache.commons.lang3.StringUtils;
import java.util.concurrent.TimeUnit;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import javax.annotation.Nullable;
import static java.lang.String.format;
@Slf4j
public final class GrpcCallRateMeter {
@Getter
private final long allowedCallsPerTimeUnit;
@Getter
private final TimeUnit timeUnit;
@Getter
private long callsCount = 0;
@Getter
private boolean isRunning;
@Nullable
private Timer timer;
public GrpcCallRateMeter(long allowedCallsPerTimeUnit, TimeUnit timeUnit) {
this.allowedCallsPerTimeUnit = allowedCallsPerTimeUnit;
this.timeUnit = timeUnit;
}
public void start() {
if (timer != null)
timer.stop();
timer = UserThread.runPeriodically(() -> callsCount = 0, 1, timeUnit);
isRunning = true;
}
public void incrementCallsCount() {
callsCount++;
}
public boolean isCallRateExceeded() {
return callsCount > allowedCallsPerTimeUnit;
}
public String getCallsCountProgress(String calledMethodName) {
String shortTimeUnitName = StringUtils.chop(timeUnit.name().toLowerCase());
return format("%s has been called %d time%s in the last %s; the rate limit is %d/%s.",
calledMethodName,
callsCount,
callsCount == 1 ? "" : "s",
shortTimeUnitName,
allowedCallsPerTimeUnit,
shortTimeUnitName);
}
}