Get rid of 'sharedModel' in 'Connection' class

This commit is contained in:
Florian Reimair 2019-02-26 13:31:51 +01:00
parent 3909feec37
commit 7fa5c190aa
No known key found for this signature in database
GPG Key ID: 7EA8CA324B6E5633

View File

@ -96,7 +96,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
* All handlers are called on User thread.
*/
@Slf4j
public class Connection implements MessageListener {
public class Connection extends Capabilities implements MessageListener {
///////////////////////////////////////////////////////////////////////////////////////////
// Enums
@ -139,7 +139,6 @@ public class Connection implements MessageListener {
private final String uid;
private final ExecutorService singleThreadExecutor = Executors.newSingleThreadExecutor();
// holder of state shared between InputHandler and Connection
private final SharedModel sharedModel;
private final Statistic statistic;
private final int msgThrottlePer10Sec;
private final int msgThrottlePerSec;
@ -182,8 +181,6 @@ public class Connection implements MessageListener {
addMessageListener(messageListener);
sharedModel = new SharedModel(this, socket);
if (socket.getLocalPort() == 0)
portInfo = "port=" + socket.getPort();
else
@ -203,7 +200,7 @@ public class Connection implements MessageListener {
protoOutputStream = new SynchronizedProtoOutputStream(socket.getOutputStream(), statistic);
InputStream protoInputStream = socket.getInputStream();
// We create a thread for handling inputStream data
inputHandler = new InputHandler(sharedModel, protoInputStream, portInfo, this, networkProtoResolver);
inputHandler = new InputHandler(this, protoInputStream, portInfo, this, networkProtoResolver);
singleThreadExecutor.submit(inputHandler);
// Use Peer as default, in case of other types they will set it as soon as possible.
@ -218,12 +215,6 @@ public class Connection implements MessageListener {
}
}
private void handleException(Throwable e) {
if (sharedModel != null)
sharedModel.handleConnectionException(e);
}
///////////////////////////////////////////////////////////////////////////////////////////
// API
///////////////////////////////////////////////////////////////////////////////////////////
@ -283,7 +274,7 @@ public class Connection implements MessageListener {
handleException(t);
}
} else {
log.info("We did not send the message because the peer does not support our required capabilities. message={}, peers supportedCapabilities={}", networkEnvelope, sharedModel.getSupportedCapabilities());
log.info("We did not send the message because the peer does not support our required capabilities. message={}, peers supportedCapabilities={}", networkEnvelope, getSupportedCapabilities());
}
} else {
log.debug("called sendMessage but was already stopped");
@ -293,12 +284,12 @@ public class Connection implements MessageListener {
public boolean noCapabilityRequiredOrCapabilityIsSupported(Proto msg) {
if (msg instanceof AddDataMessage) {
final ProtectedStoragePayload protectedStoragePayload = (((AddDataMessage) msg).getProtectedStorageEntry()).getProtectedStoragePayload();
return protectedStoragePayload instanceof CapabilityRequiringPayload && sharedModel.isCapabilitySupported(((CapabilityRequiringPayload) protectedStoragePayload).getRequiredCapabilities());
return protectedStoragePayload instanceof CapabilityRequiringPayload && isCapabilitySupported(((CapabilityRequiringPayload) protectedStoragePayload).getRequiredCapabilities());
} else if (msg instanceof AddPersistableNetworkPayloadMessage) {
final PersistableNetworkPayload persistableNetworkPayload = ((AddPersistableNetworkPayloadMessage) msg).getPersistableNetworkPayload();
return persistableNetworkPayload instanceof CapabilityRequiringPayload && sharedModel.isCapabilitySupported(((CapabilityRequiringPayload) persistableNetworkPayload).getRequiredCapabilities());
return persistableNetworkPayload instanceof CapabilityRequiringPayload && isCapabilitySupported(((CapabilityRequiringPayload) persistableNetworkPayload).getRequiredCapabilities());
} else if(msg instanceof CapabilityRequiringPayload) {
return sharedModel.isCapabilitySupported(((CapabilityRequiringPayload) msg).getRequiredCapabilities());
return isCapabilitySupported(((CapabilityRequiringPayload) msg).getRequiredCapabilities());
} else {
return true;
}
@ -306,7 +297,7 @@ public class Connection implements MessageListener {
@Nullable
public Capabilities getSupportedCapabilities() {
return sharedModel.getSupportedCapabilities();
return new Capabilities(capabilities);
}
public void addMessageListener(MessageListener messageListener) {
@ -328,7 +319,7 @@ public class Connection implements MessageListener {
@SuppressWarnings({"unused", "UnusedReturnValue"})
public boolean reportIllegalRequest(RuleViolation ruleViolation) {
return sharedModel.reportInvalidRequest(ruleViolation);
return reportInvalidRequest(ruleViolation);
}
// TODO either use the argument or delete it
@ -412,7 +403,7 @@ public class Connection implements MessageListener {
if (BanList.isBanned(peerNodeAddress)) {
log.warn("We detected a connection to a banned peer. We will close that connection. (setPeersNodeAddress)");
sharedModel.reportInvalidRequest(RuleViolation.PEER_BANNED);
reportInvalidRequest(RuleViolation.PEER_BANNED);
}
}
@ -446,7 +437,7 @@ public class Connection implements MessageListener {
}
public RuleViolation getRuleViolation() {
return sharedModel.getRuleViolation();
return ruleViolation;
}
public Statistic getStatistic() {
@ -478,7 +469,7 @@ public class Connection implements MessageListener {
Thread.currentThread().setName("Connection:SendCloseConnectionMessage-" + this.uid);
try {
String reason = closeConnectionReason == CloseConnectionReason.RULE_VIOLATION ?
sharedModel.getRuleViolation().name() : closeConnectionReason.name();
getRuleViolation().name() : closeConnectionReason.name();
sendMessage(new CloseConnectionMessage(reason));
setStopFlags();
@ -505,7 +496,6 @@ public class Connection implements MessageListener {
private void setStopFlags() {
stopped = true;
sharedModel.stop();
if (inputHandler != null)
inputHandler.stop();
}
@ -514,7 +504,7 @@ public class Connection implements MessageListener {
// Use UserThread.execute as its not clear if that is called from a non-UserThread
UserThread.execute(() -> connectionListener.onDisconnect(closeConnectionReason, this));
try {
sharedModel.getSocket().close();
socket.close();
} catch (SocketException e) {
log.trace("SocketException at shutdown might be expected " + e.getMessage());
} catch (IOException e) {
@ -563,7 +553,10 @@ public class Connection implements MessageListener {
", peerType=" + peerType +
", portInfo=" + portInfo +
", uid='" + uid + '\'' +
", sharedSpace=" + sharedModel.toString() +
", closeConnectionReason=" + closeConnectionReason +
", ruleViolation=" + ruleViolation +
", ruleViolations=" + ruleViolations +
", supportedCapabilities=" + capabilities +
", stopped=" + stopped +
'}';
}
@ -572,130 +565,83 @@ public class Connection implements MessageListener {
///////////////////////////////////////////////////////////////////////////////////////////
// SharedSpace
///////////////////////////////////////////////////////////////////////////////////////////
/**
* Holds all shared data between Connection and InputHandler
* Runs in same thread as Connection
*/
private static class SharedModel extends Capabilities {
private static final Logger log = LoggerFactory.getLogger(SharedModel.class);
private RuleViolation ruleViolation;
private final ConcurrentHashMap<RuleViolation, Integer> ruleViolations = new ConcurrentHashMap<>();
private final Connection connection;
private final Socket socket;
private final ConcurrentHashMap<RuleViolation, Integer> ruleViolations = new ConcurrentHashMap<>();
// mutable
private volatile boolean stopped;
private CloseConnectionReason closeConnectionReason;
private RuleViolation ruleViolation;
public boolean reportInvalidRequest(RuleViolation ruleViolation) {
log.warn("We got reported the ruleViolation {} at connection {}", ruleViolation, this);
int numRuleViolations;
numRuleViolations = ruleViolations.getOrDefault(ruleViolation, 0);
SharedModel(Connection connection, Socket socket) {
this.connection = connection;
this.socket = socket;
}
numRuleViolations++;
ruleViolations.put(ruleViolation, numRuleViolations);
public boolean reportInvalidRequest(RuleViolation ruleViolation) {
log.warn("We got reported the ruleViolation {} at connection {}", ruleViolation, connection);
int numRuleViolations;
numRuleViolations = ruleViolations.getOrDefault(ruleViolation, 0);
numRuleViolations++;
ruleViolations.put(ruleViolation, numRuleViolations);
if (numRuleViolations >= ruleViolation.maxTolerance) {
log.warn("We close connection as we received too many corrupt requests.\n" +
"numRuleViolations={}\n\t" +
"corruptRequest={}\n\t" +
"corruptRequests={}\n\t" +
"connection={}", numRuleViolations, ruleViolation, ruleViolations.toString(), connection);
this.ruleViolation = ruleViolation;
if (ruleViolation == RuleViolation.PEER_BANNED) {
log.warn("We close connection due RuleViolation.PEER_BANNED. peersNodeAddress={}", connection.getPeersNodeAddressOptional());
shutDown(CloseConnectionReason.PEER_BANNED);
} else if (ruleViolation == RuleViolation.INVALID_CLASS) {
log.warn("We close connection due RuleViolation.INVALID_CLASS");
shutDown(CloseConnectionReason.INVALID_CLASS_RECEIVED);
} else {
log.warn("We close connection due RuleViolation.RULE_VIOLATION");
shutDown(CloseConnectionReason.RULE_VIOLATION);
}
return true;
if (numRuleViolations >= ruleViolation.maxTolerance) {
log.warn("We close connection as we received too many corrupt requests.\n" +
"numRuleViolations={}\n\t" +
"corruptRequest={}\n\t" +
"corruptRequests={}\n\t" +
"connection={}", numRuleViolations, ruleViolation, ruleViolations.toString(), this);
this.ruleViolation = ruleViolation;
if (ruleViolation == RuleViolation.PEER_BANNED) {
log.warn("We close connection due RuleViolation.PEER_BANNED. peersNodeAddress={}", getPeersNodeAddressOptional());
shutDown(CloseConnectionReason.PEER_BANNED);
} else if (ruleViolation == RuleViolation.INVALID_CLASS) {
log.warn("We close connection due RuleViolation.INVALID_CLASS");
shutDown(CloseConnectionReason.INVALID_CLASS_RECEIVED);
} else {
return false;
log.warn("We close connection due RuleViolation.RULE_VIOLATION");
shutDown(CloseConnectionReason.RULE_VIOLATION);
}
}
@Nullable
public Capabilities getSupportedCapabilities() {
return new Capabilities(capabilities);
return true;
} else {
return false;
}
}
public void handleConnectionException(Throwable e) {
if (e instanceof SocketException) {
if (socket.isClosed())
closeConnectionReason = CloseConnectionReason.SOCKET_CLOSED;
else
closeConnectionReason = CloseConnectionReason.RESET;
private void handleException(Throwable e) {
if (e instanceof SocketException) {
if (socket.isClosed())
closeConnectionReason = CloseConnectionReason.SOCKET_CLOSED;
else
closeConnectionReason = CloseConnectionReason.RESET;
log.info("SocketException (expected if connection lost). closeConnectionReason={}; connection={}", closeConnectionReason, connection);
} else if (e instanceof SocketTimeoutException || e instanceof TimeoutException) {
closeConnectionReason = CloseConnectionReason.SOCKET_TIMEOUT;
log.info("Shut down caused by exception {} on connection={}", e.toString(), connection);
} else if (e instanceof EOFException) {
closeConnectionReason = CloseConnectionReason.TERMINATED;
log.warn("Shut down caused by exception {} on connection={}", e.toString(), connection);
} else if (e instanceof OptionalDataException || e instanceof StreamCorruptedException) {
closeConnectionReason = CloseConnectionReason.CORRUPTED_DATA;
log.warn("Shut down caused by exception {} on connection={}", e.toString(), connection);
} else {
// TODO sometimes we get StreamCorruptedException, OptionalDataException, IllegalStateException
closeConnectionReason = CloseConnectionReason.UNKNOWN_EXCEPTION;
log.warn("Unknown reason for exception at socket: {}\n\t" +
"peer={}\n\t" +
"Exception={}",
socket.toString(),
connection.peersNodeAddressOptional,
e.toString());
e.printStackTrace();
}
shutDown(closeConnectionReason);
log.info("SocketException (expected if connection lost). closeConnectionReason={}; connection={}", closeConnectionReason, this);
} else if (e instanceof SocketTimeoutException || e instanceof TimeoutException) {
closeConnectionReason = CloseConnectionReason.SOCKET_TIMEOUT;
log.info("Shut down caused by exception {} on connection={}", e.toString(), this);
} else if (e instanceof EOFException) {
closeConnectionReason = CloseConnectionReason.TERMINATED;
log.warn("Shut down caused by exception {} on connection={}", e.toString(), this);
} else if (e instanceof OptionalDataException || e instanceof StreamCorruptedException) {
closeConnectionReason = CloseConnectionReason.CORRUPTED_DATA;
log.warn("Shut down caused by exception {} on connection={}", e.toString(), this);
} else {
// TODO sometimes we get StreamCorruptedException, OptionalDataException, IllegalStateException
closeConnectionReason = CloseConnectionReason.UNKNOWN_EXCEPTION;
log.warn("Unknown reason for exception at socket: {}\n\t" +
"peer={}\n\t" +
"Exception={}",
socket.toString(),
this.peersNodeAddressOptional,
e.toString());
e.printStackTrace();
}
shutDown(closeConnectionReason);
}
private CloseConnectionReason closeConnectionReason;
public void shutDown(CloseConnectionReason closeConnectionReason) {
if (!stopped) {
stopped = true;
connection.shutDown(closeConnectionReason);
}
}
public Socket getSocket() {
return socket;
}
public void stop() {
stopped = true;
}
RuleViolation getRuleViolation() {
return ruleViolation;
}
@Override
public String toString() {
return "SharedModel{" +
"\n connection=" + connection +
",\n socket=" + socket +
",\n ruleViolations=" + ruleViolations +
",\n stopped=" + stopped +
",\n closeConnectionReason=" + closeConnectionReason +
",\n ruleViolation=" + ruleViolation +
",\n supportedCapabilities=" + capabilities +
"\n}";
}
}
///////////////////////////////////////////////////////////////////////////////////////////
// InputHandler
///////////////////////////////////////////////////////////////////////////////////////////
@ -706,7 +652,7 @@ public class Connection implements MessageListener {
private static class InputHandler implements Runnable {
private static final Logger log = LoggerFactory.getLogger(InputHandler.class);
private final SharedModel sharedModel;
private final Connection sharedModel;
private final InputStream protoInputStream;
private final String portInfo;
private final MessageListener messageListener;
@ -716,11 +662,11 @@ public class Connection implements MessageListener {
private long lastReadTimeStamp;
private boolean threadNameSet;
InputHandler(SharedModel sharedModel,
InputStream protoInputStream,
String portInfo,
MessageListener messageListener,
NetworkProtoResolver networkProtoResolver) {
public InputHandler(Connection sharedModel,
InputStream protoInputStream,
String portInfo,
MessageListener messageListener,
NetworkProtoResolver networkProtoResolver) {
this.sharedModel = sharedModel;
this.protoInputStream = protoInputStream;
this.portInfo = portInfo;
@ -746,9 +692,9 @@ public class Connection implements MessageListener {
try {
Thread.currentThread().setName("InputHandler");
while (!stopped && !Thread.currentThread().isInterrupted()) {
if (!threadNameSet && sharedModel.connection != null &&
sharedModel.connection.getPeersNodeAddressOptional().isPresent()) {
Thread.currentThread().setName("InputHandler-" + sharedModel.connection.getPeersNodeAddressOptional().get().getFullAddress());
if (!threadNameSet && sharedModel != null &&
sharedModel.getPeersNodeAddressOptional().isPresent()) {
Thread.currentThread().setName("InputHandler-" + sharedModel.getPeersNodeAddressOptional().get().getFullAddress());
threadNameSet = true;
}
try {
@ -759,8 +705,8 @@ public class Connection implements MessageListener {
return;
}
Connection connection = checkNotNull(sharedModel.connection, "connection must not be null");
log.trace("InputHandler waiting for incoming network_messages.\n\tConnection={}", connection);
Connection connection = checkNotNull(sharedModel, "connection must not be null");
log.trace("InputHandler waiting for incoming network_messages.\n\tConnection=" + connection);
// Throttle inbound network_messages
long now = System.currentTimeMillis();
@ -923,11 +869,11 @@ public class Connection implements MessageListener {
e.printStackTrace();
reportInvalidRequest(RuleViolation.INVALID_DATA_TYPE);
} catch (Throwable t) {
handleException(t);
sharedModel.handleException(t);
}
}
} catch (Throwable t) {
handleException(t);
sharedModel.handleException(t);
}
}
@ -936,13 +882,6 @@ public class Connection implements MessageListener {
sharedModel.shutDown(reason);
}
private void handleException(Throwable e) {
stop();
if (sharedModel != null)
sharedModel.handleConnectionException(e);
}
private boolean reportInvalidRequest(RuleViolation ruleViolation) {
boolean causedShutDown = sharedModel.reportInvalidRequest(ruleViolation);
if (causedShutDown)