Skip to content

Commit cd44ef1

Browse files
committed
[ETCM-102] Fix race condition in frame encoder/decoder
1 parent 030870a commit cd44ef1

File tree

6 files changed

+204
-96
lines changed

6 files changed

+204
-96
lines changed

src/main/scala/io/iohk/ethereum/network/PeerActor.scala

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
package io.iohk.ethereum.network
22

3-
import java.net.{ InetSocketAddress, URI }
3+
import java.net.{InetSocketAddress, URI}
44

55
import akka.actor.SupervisorStrategy.Escalate
66
import akka.actor._
77
import akka.util.ByteString
88
import io.iohk.ethereum.network.PeerActor.Status._
9-
import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.{ MessageFromPeer, PeerHandshakeSuccessful }
9+
import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.{MessageFromPeer, PeerHandshakeSuccessful}
1010
import io.iohk.ethereum.network.PeerEventBusActor.Publish
1111
import io.iohk.ethereum.network.PeerManagerActor.PeerConfiguration
1212
import io.iohk.ethereum.network.handshaker.Handshaker
13-
import io.iohk.ethereum.network.handshaker.Handshaker.HandshakeComplete.{ HandshakeFailure, HandshakeSuccess }
14-
import io.iohk.ethereum.network.handshaker.Handshaker.{ HandshakeResult, NextMessage }
13+
import io.iohk.ethereum.network.handshaker.Handshaker.HandshakeComplete.{HandshakeFailure, HandshakeSuccess}
14+
import io.iohk.ethereum.network.handshaker.Handshaker.{HandshakeResult, NextMessage}
1515
import io.iohk.ethereum.network.p2p._
1616
import io.iohk.ethereum.network.p2p.messages.Versions
1717
import io.iohk.ethereum.network.p2p.messages.WireProtocol._
1818
import io.iohk.ethereum.network.rlpx.RLPxConnectionHandler.RLPxConfiguration
19-
import io.iohk.ethereum.network.rlpx.{ AuthHandshaker, RLPxConnectionHandler }
19+
import io.iohk.ethereum.network.rlpx.{AuthHandshaker, RLPxConnectionHandler}
2020
import org.bouncycastle.util.encoders.Hex
2121

2222

@@ -28,18 +28,18 @@ import org.bouncycastle.util.encoders.Hex
2828
* Once that's done it can send/receive messages with peer (HandshakedHandler.receive).
2929
*/
3030
class PeerActor[R <: HandshakeResult](
31-
peerAddress: InetSocketAddress,
32-
rlpxConnectionFactory: ActorContext => ActorRef,
33-
val peerConfiguration: PeerConfiguration,
34-
peerEventBus: ActorRef,
35-
knownNodesManager: ActorRef,
36-
incomingConnection: Boolean,
37-
externalSchedulerOpt: Option[Scheduler] = None,
38-
initHandshaker: Handshaker[R])
31+
peerAddress: InetSocketAddress,
32+
rlpxConnectionFactory: ActorContext => ActorRef,
33+
val peerConfiguration: PeerConfiguration,
34+
peerEventBus: ActorRef,
35+
knownNodesManager: ActorRef,
36+
incomingConnection: Boolean,
37+
externalSchedulerOpt: Option[Scheduler] = None,
38+
initHandshaker: Handshaker[R])
3939
extends Actor with ActorLogging with Stash {
4040

4141
import PeerActor._
42-
import context.{ dispatcher, system }
42+
import context.{dispatcher, system}
4343

4444
override val supervisorStrategy: OneForOneStrategy =
4545
OneForOneStrategy() {
@@ -73,6 +73,7 @@ class PeerActor[R <: HandshakeResult](
7373
context watch ref
7474
RLPxConnection(ref, remoteAddress, uriOpt)
7575
}
76+
7677
private def modifyOutGoingUri(remoteNodeId: ByteString, rlpxConnection: RLPxConnection, uri: URI): URI = {
7778
val host = getHostName(rlpxConnection.remoteAddress.getAddress)
7879
val port = rlpxConnection.remoteAddress.getPort
@@ -106,23 +107,14 @@ class PeerActor[R <: HandshakeResult](
106107

107108
def processingHandshaking(handshaker: Handshaker[R], rlpxConnection: RLPxConnection,
108109
timeout: Cancellable, numRetries: Int): Receive =
109-
handleTerminated(rlpxConnection, numRetries, Handshaking(numRetries)) orElse
110+
handleTerminated(rlpxConnection, numRetries, Handshaking(numRetries)) orElse
110111
handleDisconnectMsg(rlpxConnection, Handshaking(numRetries)) orElse
111112
handlePingMsg(rlpxConnection) orElse stashMessages orElse {
112113

113114
case RLPxConnectionHandler.MessageReceived(msg) =>
114-
115-
// We need to determine p2p version just after hello message as next messages in handshake
116-
// can be compressed.
117-
msg match {
118-
case Hello(p2pVersion, _, _, _, _) =>
119-
rlpxConnection.ref ! PeerP2pVersion(p2pVersion)
120-
case _ => ()
121-
}
122-
123115
// Processes the received message, cancels the timeout and processes a new message but only if the handshaker
124116
// handles the received message
125-
handshaker.applyMessage(msg).foreach{ newHandshaker =>
117+
handshaker.applyMessage(msg).foreach { newHandshaker =>
126118
timeout.cancel()
127119
processHandshakerNextMessage(newHandshaker, rlpxConnection, numRetries)
128120
}
@@ -143,7 +135,7 @@ class PeerActor[R <: HandshakeResult](
143135
*
144136
* @param handshaker
145137
* @param rlpxConnection
146-
* @param numRetries, number of connection retries done during RLPxConnection establishment
138+
* @param numRetries , number of connection retries done during RLPxConnection establishment
147139
*/
148140
private def processHandshakerNextMessage(handshaker: Handshaker[R],
149141
rlpxConnection: RLPxConnection,
@@ -155,7 +147,7 @@ class PeerActor[R <: HandshakeResult](
155147
context become processingHandshaking(handshaker, rlpxConnection, newTimeout, numRetries)
156148

157149
case Left(HandshakeSuccess(handshakeResult)) =>
158-
rlpxConnection.uriOpt.foreach { uri =>knownNodesManager ! KnownNodesManager.AddKnownNode(uri) }
150+
rlpxConnection.uriOpt.foreach { uri => knownNodesManager ! KnownNodesManager.AddKnownNode(uri) }
159151
context become new HandshakedPeer(rlpxConnection, handshakeResult).receive
160152
unstashAll()
161153

@@ -252,8 +244,8 @@ class PeerActor[R <: HandshakeResult](
252244
*/
253245
def receive: Receive =
254246
handlePingMsg(rlpxConnection) orElse
255-
handleDisconnectMsg(rlpxConnection, Handshaked) orElse
256-
handleTerminated(rlpxConnection, 0, Handshaked) orElse {
247+
handleDisconnectMsg(rlpxConnection, Handshaked) orElse
248+
handleTerminated(rlpxConnection, 0, Handshaked) orElse {
257249

258250
case RLPxConnectionHandler.MessageReceived(message) =>
259251
log.debug(s"Received message: {} from $peerId", message)
@@ -268,7 +260,7 @@ class PeerActor[R <: HandshakeResult](
268260
case GetStatus =>
269261
sender() ! StatusResponse(Handshaked)
270262

271-
}
263+
}
272264
}
273265

274266
}
@@ -304,8 +296,6 @@ object PeerActor {
304296
}
305297
}
306298

307-
case class PeerP2pVersion(p2pVersion: Long)
308-
309299
case class HandleConnection(connection: ActorRef, remoteAddress: InetSocketAddress)
310300

311301
case class IncomingConnectionHandshakeSuccess(peer: Peer)
@@ -321,16 +311,25 @@ object PeerActor {
321311
private case object ResponseTimeout
322312

323313
case object GetStatus
314+
324315
case class StatusResponse(status: Status)
325316

326317
case class DisconnectPeer(reason: Int)
327318

328319
sealed trait Status
320+
329321
object Status {
322+
330323
case object Idle extends Status
324+
331325
case object Connecting extends Status
326+
332327
case class Handshaking(numRetries: Int) extends Status
328+
333329
case object Handshaked extends Status
330+
334331
case object Disconnected extends Status
332+
335333
}
334+
336335
}

src/main/scala/io/iohk/ethereum/network/rlpx/MessageCodec.scala

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import java.util.concurrent.atomic.AtomicInteger
44

55
import akka.util.ByteString
66
import io.iohk.ethereum.network.handshaker.EtcHelloExchangeState
7+
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Hello
78
import io.iohk.ethereum.network.p2p.{Message, MessageDecoder, MessageSerializable}
89
import org.xerial.snappy.Snappy
10+
911
import scala.util.{Failure, Success, Try}
1012

1113
class MessageCodec(frameCodec: FrameCodec, messageDecoder: MessageDecoder, protocolVersion: Message.Version) {
@@ -17,19 +19,37 @@ class MessageCodec(frameCodec: FrameCodec, messageDecoder: MessageDecoder, proto
1719
// 16Mb in base 2
1820
val maxDecompressedLength = 16777216
1921

20-
def readMessages(data: ByteString, p2pVersion: Option[Long]): Seq[Try[Message]] = {
22+
// MessageCodec is only used from actor context so it can be var
23+
@volatile
24+
private var remotePeerP2pVersion: Option[Long] = None
25+
26+
private def setRemoteVersionBasedOnHelloMessage(m: Message): Unit = {
27+
if (remotePeerP2pVersion.isEmpty) {
28+
m match {
29+
case hello: Hello =>
30+
remotePeerP2pVersion = Some(hello.p2pVersion)
31+
case _ =>
32+
}
33+
}
34+
}
35+
36+
def readMessages(data: ByteString): Seq[Try[Message]] = {
2137
val frames = frameCodec.readFrames(data)
2238

2339
frames map { frame =>
2440
val frameData = frame.payload.toArray
2541
val payloadTry =
26-
if (p2pVersion.contains(EtcHelloExchangeState.P2pVersion)){
42+
if (remotePeerP2pVersion.exists(version => version >= EtcHelloExchangeState.P2pVersion)) {
2743
decompressData(frameData)
2844
} else {
2945
Success(frameData)
3046
}
3147

32-
payloadTry.map(payload => messageDecoder.fromBytes(frame.`type`, payload, protocolVersion))
48+
payloadTry.map { payload =>
49+
val m = messageDecoder.fromBytes(frame.`type`, payload, protocolVersion)
50+
setRemoteVersionBasedOnHelloMessage(m)
51+
m
52+
}
3353
}
3454
}
3555

@@ -42,14 +62,16 @@ class MessageCodec(frameCodec: FrameCodec, messageDecoder: MessageDecoder, proto
4262
}
4363
}
4464

45-
def encodeMessage(serializable: MessageSerializable, p2pVersion: Option[Long]): ByteString = {
65+
def encodeMessage(serializable: MessageSerializable): ByteString = {
4666
val encoded: Array[Byte] = serializable.toBytes
4767
val numFrames = Math.ceil(encoded.length / MaxFramePayloadSize.toDouble).toInt
4868
val contextId = contextIdCounter.incrementAndGet()
4969
val frames = (0 until numFrames) map { frameNo =>
5070
val framedPayload = encoded.drop(frameNo * MaxFramePayloadSize).take(MaxFramePayloadSize)
5171
val payload =
52-
if (p2pVersion.contains(EtcHelloExchangeState.P2pVersion)){
72+
if (
73+
remotePeerP2pVersion.exists(version => version >= EtcHelloExchangeState.P2pVersion) && serializable.code != Hello.code
74+
) {
5375
Snappy.compress(framedPayload)
5476
} else {
5577
framedPayload

src/main/scala/io/iohk/ethereum/network/rlpx/RLPxConnectionHandler.scala

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import akka.actor._
66
import akka.io.Tcp._
77
import akka.io.{IO, Tcp}
88
import akka.util.ByteString
9-
import io.iohk.ethereum.network.PeerActor.PeerP2pVersion
109
import io.iohk.ethereum.network.p2p.{Message, MessageDecoder, MessageSerializable}
1110
import io.iohk.ethereum.network.rlpx.RLPxConnectionHandler.RLPxConfiguration
1211
import io.iohk.ethereum.utils.ByteUtils
@@ -24,15 +23,15 @@ import scala.util.{Failure, Success, Try}
2423
* 1. when created it waits for initial command (either handle incoming connection or connect usin g uri)
2524
* 2. when new connection is requested the actor waits for the result (waitingForConnectionResult)
2625
* 3. once underlying connection is established it either waits for handshake init message or for response message
27-
* (depending on who initiated the connection)
26+
* (depending on who initiated the connection)
2827
* 4. once handshake is done (and secure connection established) actor can send/receive messages (`handshaked` state)
2928
*/
3029
class RLPxConnectionHandler(
31-
messageDecoder: MessageDecoder,
32-
protocolVersion: Message.Version,
33-
authHandshaker: AuthHandshaker,
34-
messageCodecFactory: (Secrets, MessageDecoder, Message.Version) => MessageCodec,
35-
rlpxConfiguration: RLPxConfiguration)
30+
messageDecoder: MessageDecoder,
31+
protocolVersion: Message.Version,
32+
authHandshaker: AuthHandshaker,
33+
messageCodecFactory: (Secrets, MessageDecoder, Message.Version) => MessageCodec,
34+
rlpxConfiguration: RLPxConfiguration)
3635
extends Actor with ActorLogging {
3736

3837
import AuthHandshaker.{InitiatePacketLength, ResponsePacketLength}
@@ -126,10 +125,10 @@ class RLPxConnectionHandler(
126125
/**
127126
* Decode V4 packet
128127
*
129-
* @param data, includes both the V4 packet with bytes from next messages
128+
* @param data , includes both the V4 packet with bytes from next messages
130129
* @return data of the packet and the remaining data
131130
*/
132-
private def decodeV4Packet(data: ByteString): (ByteString, ByteString) = {
131+
private def decodeV4Packet(data: ByteString): (ByteString, ByteString) = {
133132
val encryptedPayloadSize = ByteUtils.bigEndianToShort(data.take(2).toArray)
134133
val (packetData, remainingData) = data.splitAt(encryptedPayloadSize + 2)
135134
packetData -> remainingData
@@ -148,7 +147,7 @@ class RLPxConnectionHandler(
148147
log.debug(s"Auth handshake succeeded for peer $peerId")
149148
context.parent ! ConnectionEstablished(remotePubKey)
150149
val messageCodec = messageCodecFactory(secrets, messageDecoder, protocolVersion)
151-
val messagesSoFar = messageCodec.readMessages(remainingData ,None)
150+
val messagesSoFar = messageCodec.readMessages(remainingData)
152151
messagesSoFar foreach processMessage
153152
context become handshaked(messageCodec)
154153

@@ -170,59 +169,54 @@ class RLPxConnectionHandler(
170169
* Handles sending and receiving messages from the Akka TCP connection, while also handling acknowledgement of
171170
* messages sent. Messages are only sent when all Ack from previous messages were received.
172171
*
173-
* @param messageCodec, for encoding the messages sent
174-
* @param messagesNotSent, messages not yet sent
175-
* @param cancellableAckTimeout, timeout for the message sent for which we are awaiting an acknowledgement (if there is one)
176-
* @param seqNumber, sequence number for the next message to be sent
172+
* @param messageCodec , for encoding the messages sent
173+
* @param messagesNotSent , messages not yet sent
174+
* @param cancellableAckTimeout , timeout for the message sent for which we are awaiting an acknowledgement (if there is one)
175+
* @param seqNumber , sequence number for the next message to be sent
177176
*/
178177
def handshaked(messageCodec: MessageCodec,
179178
messagesNotSent: Queue[MessageSerializable] = Queue.empty,
180179
cancellableAckTimeout: Option[CancellableAckTimeout] = None,
181-
seqNumber: Int = 0,
182-
p2pVersion: Option[Long] = None): Receive =
180+
seqNumber: Int = 0): Receive =
183181
handleWriteFailed orElse handleConnectionClosed orElse {
184182
case sm: SendMessage =>
185-
if(cancellableAckTimeout.isEmpty)
186-
sendMessage(messageCodec, sm.serializable, seqNumber, messagesNotSent, p2pVersion)
183+
if (cancellableAckTimeout.isEmpty)
184+
sendMessage(messageCodec, sm.serializable, seqNumber, messagesNotSent)
187185
else
188-
context become handshaked(messageCodec, messagesNotSent :+ sm.serializable, cancellableAckTimeout, seqNumber, p2pVersion)
186+
context become handshaked(messageCodec, messagesNotSent :+ sm.serializable, cancellableAckTimeout, seqNumber)
189187

190188
case Received(data) =>
191-
val messages = messageCodec.readMessages(data, p2pVersion)
189+
val messages = messageCodec.readMessages(data)
192190
messages foreach processMessage
193191

194192
case Ack if cancellableAckTimeout.nonEmpty =>
195193
//Cancel pending message timeout
196194
cancellableAckTimeout.foreach(_.cancellable.cancel())
197195

198196
//Send next message if there is one
199-
if(messagesNotSent.nonEmpty)
200-
sendMessage(messageCodec, messagesNotSent.head, seqNumber, messagesNotSent.tail, p2pVersion)
197+
if (messagesNotSent.nonEmpty)
198+
sendMessage(messageCodec, messagesNotSent.head, seqNumber, messagesNotSent.tail)
201199
else
202-
context become handshaked(messageCodec, Queue.empty, None, seqNumber, p2pVersion)
200+
context become handshaked(messageCodec, Queue.empty, None, seqNumber)
203201

204202
case AckTimeout(ackSeqNumber) if cancellableAckTimeout.exists(_.seqNumber == ackSeqNumber) =>
205203
cancellableAckTimeout.foreach(_.cancellable.cancel())
206204
log.debug(s"[Stopping Connection] Write to $peerId failed")
207205
context stop self
208-
209-
case PeerP2pVersion(p2pVer) =>
210-
// We have peer p2p version based on hello message, if version is >= 5 next messages will be compressed.
211-
context.become(handshaked(messageCodec, messagesNotSent, cancellableAckTimeout, seqNumber, Some(p2pVer)))
212206
}
213207

214208
/**
215209
* Sends an encoded message through the TCP connection, an Ack will be received when the message was
216210
* successfully queued for delivery. A cancellable timeout is created for the Ack message.
217211
*
218-
* @param messageCodec, for encoding the messages sent
219-
* @param messageToSend, message to be sent
220-
* @param seqNumber, sequence number for the message to be sent
221-
* @param remainingMsgsToSend, messages not yet sent
212+
* @param messageCodec , for encoding the messages sent
213+
* @param messageToSend , message to be sent
214+
* @param seqNumber , sequence number for the message to be sent
215+
* @param remainingMsgsToSend , messages not yet sent
222216
*/
223217
private def sendMessage(messageCodec: MessageCodec, messageToSend: MessageSerializable,
224-
seqNumber: Int, remainingMsgsToSend: Queue[MessageSerializable], p2pVersion: Option[Long]): Unit = {
225-
val out = messageCodec.encodeMessage(messageToSend, p2pVersion)
218+
seqNumber: Int, remainingMsgsToSend: Queue[MessageSerializable]): Unit = {
219+
val out = messageCodec.encodeMessage(messageToSend)
226220
connection ! Write(out, Ack)
227221
log.debug(s"Sent message: $messageToSend from $peerId")
228222

@@ -231,15 +225,14 @@ class RLPxConnectionHandler(
231225
messageCodec = messageCodec,
232226
messagesNotSent = remainingMsgsToSend,
233227
cancellableAckTimeout = Some(CancellableAckTimeout(seqNumber, timeout)),
234-
seqNumber = increaseSeqNumber(seqNumber),
235-
p2pVersion = p2pVersion
228+
seqNumber = increaseSeqNumber(seqNumber)
236229
)
237230
}
238231

239232
/**
240233
* Given a sequence number for the AckTimeouts, the next seq number is returned
241234
*
242-
* @param seqNumber, the current sequence number
235+
* @param seqNumber , the current sequence number
243236
* @return the sequence number for the next message sent
244237
*/
245238
private def increaseSeqNumber(seqNumber: Int): Int = seqNumber match {
@@ -258,13 +251,14 @@ class RLPxConnectionHandler(
258251
if (msg.isPeerClosed) {
259252
log.debug(s"[Stopping Connection] Connection with $peerId closed by peer")
260253
}
261-
if(msg.isErrorClosed){
254+
if (msg.isErrorClosed) {
262255
log.debug(s"[Stopping Connection] Connection with $peerId closed because of error ${msg.getErrorCause}")
263256
}
264257

265258
context stop self
266259
}
267260
}
261+
268262
}
269263

270264
object RLPxConnectionHandler {

0 commit comments

Comments
 (0)