Skip to content

Commit 69e7d4b

Browse files
[ETCM-355] Refactor messages decoding (#1046)
* [ETCM-355] Make fromBytes in MessageDecoder safe Rename old implementation to fromBytesUnsafe * [ETCM-355] RLPxConnectionHandler.processMessage doesn't use Try * [ETCM-355] Remove MessageDecoders.fromBytesUnsafe * [ETCM-355] Remove unused method * [ETCM-355] Fix formatting * [ETCM-355] Fix style * [ETCM-355] Move fallback on NetworkMessageDecoder inside .ehtMessageDecoder * [ETCM-355] Remove obsolete type alias * fixup! [ETCM-355] Fix formatting
1 parent 67db780 commit 69e7d4b

File tree

9 files changed

+152
-127
lines changed

9 files changed

+152
-127
lines changed
Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,34 @@
11
package io.iohk.ethereum.network.p2p
22

3-
import akka.util.ByteString
3+
import cats.implicits._
44

5-
import scala.util.Try
5+
import io.iohk.ethereum.utils.Logger
66

77
trait Message {
88
def code: Int
99
def toShortString: String
1010
}
1111

1212
trait MessageSerializable extends Message {
13-
14-
//DummyImplicit parameter only used to differentiate from the other toBytes method
15-
def toBytes(implicit di: DummyImplicit): ByteString
16-
1713
def toBytes: Array[Byte]
18-
1914
def underlyingMsg: Message
2015
}
2116

22-
trait MessageDecoder { self =>
23-
def fromBytes(`type`: Int, payload: Array[Byte]): Message
17+
@FunctionalInterface
18+
trait MessageDecoder extends Logger { self =>
19+
import MessageDecoder._
20+
21+
def fromBytes(`type`: Int, payload: Array[Byte]): Either[DecodingError, Message]
22+
23+
def orElse(otherMessageDecoder: MessageDecoder): MessageDecoder = new MessageDecoder {
24+
override def fromBytes(`type`: Int, payload: Array[Byte]): Either[DecodingError, Message] =
25+
self.fromBytes(`type`, payload).leftFlatMap { err =>
26+
log.debug(err.getLocalizedMessage())
27+
otherMessageDecoder.fromBytes(`type`, payload)
28+
}
29+
}
30+
}
2431

25-
def orElse(otherMessageDecoder: MessageDecoder): MessageDecoder = (`type`: Int, payload: Array[Byte]) =>
26-
Try {
27-
self.fromBytes(`type`, payload)
28-
}.getOrElse(otherMessageDecoder.fromBytes(`type`, payload))
32+
object MessageDecoder {
33+
type DecodingError = Throwable // TODO: Replace Throwable with an ADT when feasible
2934
}
Lines changed: 58 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.iohk.ethereum.network.p2p
22

3+
import scala.util.Try
4+
35
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.SignedTransactions._
46
import io.iohk.ethereum.network.p2p.messages.Capability
57
import io.iohk.ethereum.network.p2p.messages.Codes
@@ -19,15 +21,17 @@ import io.iohk.ethereum.network.p2p.messages.WireProtocol.Ping._
1921
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Pong._
2022
import io.iohk.ethereum.network.p2p.messages.WireProtocol._
2123

24+
import MessageDecoder._
25+
2226
object NetworkMessageDecoder extends MessageDecoder {
2327

24-
override def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
28+
override def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
2529
msgCode match {
26-
case Disconnect.code => payload.toDisconnect
27-
case Ping.code => payload.toPing
28-
case Pong.code => payload.toPong
29-
case Hello.code => payload.toHello
30-
case _ => throw new RuntimeException(s"Unknown network message type: $msgCode")
30+
case Disconnect.code => Try(payload.toDisconnect).toEither
31+
case Ping.code => Try(payload.toPing).toEither
32+
case Pong.code => Try(payload.toPong).toEither
33+
case Hello.code => Try(payload.toHello).toEither
34+
case _ => Left(new RuntimeException(s"Unknown network message type: $msgCode"))
3135
}
3236

3337
}
@@ -36,79 +40,78 @@ object ETC64MessageDecoder extends MessageDecoder {
3640
import io.iohk.ethereum.network.p2p.messages.ETC64.Status._
3741
import io.iohk.ethereum.network.p2p.messages.ETC64.NewBlock._
3842

39-
def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
43+
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
4044
msgCode match {
41-
case Codes.StatusCode => payload.toStatus
42-
case Codes.NewBlockCode => payload.toNewBlock
43-
case Codes.GetNodeDataCode => payload.toGetNodeData
44-
case Codes.NodeDataCode => payload.toNodeData
45-
case Codes.GetReceiptsCode => payload.toGetReceipts
46-
case Codes.ReceiptsCode => payload.toReceipts
47-
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
48-
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
49-
case Codes.BlockHeadersCode => payload.toBlockHeaders
50-
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
51-
case Codes.BlockBodiesCode => payload.toBlockBodies
52-
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
53-
case Codes.SignedTransactionsCode => payload.toSignedTransactions
54-
case _ => throw new RuntimeException(s"Unknown etc/64 message type: $msgCode")
45+
case Codes.StatusCode => Try(payload.toStatus).toEither
46+
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
47+
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
48+
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
49+
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
50+
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
51+
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
52+
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
53+
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
54+
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
55+
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
56+
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
57+
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
58+
case _ => Left(new RuntimeException(s"Unknown etc/64 message type: $msgCode"))
5559
}
5660
}
5761

5862
object ETH64MessageDecoder extends MessageDecoder {
5963
import io.iohk.ethereum.network.p2p.messages.ETH64.Status._
6064
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.NewBlock._
6165

62-
def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
66+
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
6367
msgCode match {
64-
case Codes.GetNodeDataCode => payload.toGetNodeData
65-
case Codes.NodeDataCode => payload.toNodeData
66-
case Codes.GetReceiptsCode => payload.toGetReceipts
67-
case Codes.ReceiptsCode => payload.toReceipts
68-
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
69-
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
70-
case Codes.BlockHeadersCode => payload.toBlockHeaders
71-
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
72-
case Codes.BlockBodiesCode => payload.toBlockBodies
73-
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
74-
case Codes.StatusCode => payload.toStatus
75-
case Codes.NewBlockCode => payload.toNewBlock
76-
case Codes.SignedTransactionsCode => payload.toSignedTransactions
77-
case _ => throw new RuntimeException(s"Unknown eth/64 message type: $msgCode")
68+
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
69+
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
70+
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
71+
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
72+
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
73+
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
74+
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
75+
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
76+
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
77+
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
78+
case Codes.StatusCode => Try(payload.toStatus).toEither
79+
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
80+
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
81+
case _ => Left(new RuntimeException(s"Unknown eth/64 message type: $msgCode"))
7882
}
7983
}
8084

8185
object ETH63MessageDecoder extends MessageDecoder {
8286
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.Status._
8387
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.NewBlock._
8488

85-
def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
89+
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
8690
msgCode match {
87-
case Codes.GetNodeDataCode => payload.toGetNodeData
88-
case Codes.NodeDataCode => payload.toNodeData
89-
case Codes.GetReceiptsCode => payload.toGetReceipts
90-
case Codes.ReceiptsCode => payload.toReceipts
91-
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
92-
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
93-
case Codes.BlockHeadersCode => payload.toBlockHeaders
94-
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
95-
case Codes.BlockBodiesCode => payload.toBlockBodies
96-
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
97-
case Codes.StatusCode => payload.toStatus
98-
case Codes.NewBlockCode => payload.toNewBlock
99-
case Codes.SignedTransactionsCode => payload.toSignedTransactions
100-
case _ => throw new RuntimeException(s"Unknown eth/63 message type: $msgCode")
91+
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
92+
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
93+
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
94+
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
95+
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
96+
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
97+
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
98+
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
99+
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
100+
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
101+
case Codes.StatusCode => Try(payload.toStatus).toEither
102+
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
103+
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
104+
case _ => Left(new RuntimeException(s"Unknown eth/63 message type: $msgCode"))
101105
}
102106
}
103107

104108
// scalastyle:off
105109
object EthereumMessageDecoder {
106-
type Decoder = (Int, Array[Byte]) => Message
107110
def ethMessageDecoder(protocolVersion: Capability): MessageDecoder =
108111
protocolVersion match {
109-
case Capability.Capabilities.Etc64Capability => ETC64MessageDecoder.fromBytes
110-
case Capability.Capabilities.Eth63Capability => ETH63MessageDecoder.fromBytes
111-
case Capability.Capabilities.Eth64Capability => ETH64MessageDecoder.fromBytes
112+
case Capability.Capabilities.Etc64Capability => ETC64MessageDecoder.orElse(NetworkMessageDecoder)
113+
case Capability.Capabilities.Eth63Capability => ETH63MessageDecoder.orElse(NetworkMessageDecoder)
114+
case Capability.Capabilities.Eth64Capability => ETH64MessageDecoder.orElse(NetworkMessageDecoder)
112115
case _ => throw new RuntimeException(s"Unsupported Protocol Version $protocolVersion")
113116
}
114117
}

src/main/scala/io/iohk/ethereum/network/rlpx/MessageCodec.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import org.xerial.snappy.Snappy
1313
import io.iohk.ethereum.network.handshaker.EtcHelloExchangeState
1414
import io.iohk.ethereum.network.p2p.Message
1515
import io.iohk.ethereum.network.p2p.MessageDecoder
16+
import io.iohk.ethereum.network.p2p.MessageDecoder.DecodingError
1617
import io.iohk.ethereum.network.p2p.MessageSerializable
1718
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Hello
1819

@@ -32,12 +33,12 @@ class MessageCodec(
3233
val contextIdCounter = new AtomicInteger
3334

3435
// TODO: ETCM-402 - messageDecoder should use negotiated protocol version
35-
def readMessages(data: ByteString): Seq[Try[Message]] = {
36+
def readMessages(data: ByteString): Seq[Either[DecodingError, Message]] = {
3637
val frames = frameCodec.readFrames(data)
3738
readFrames(frames)
3839
}
3940

40-
def readFrames(frames: Seq[Frame]): Seq[Try[Message]] =
41+
def readFrames(frames: Seq[Frame]): Seq[Either[DecodingError, Message]] =
4142
frames.map { frame =>
4243
val frameData = frame.payload.toArray
4344
val payloadTry =
@@ -47,7 +48,7 @@ class MessageCodec(
4748
Success(frameData)
4849
}
4950

50-
payloadTry.map { payload =>
51+
payloadTry.toEither.flatMap { payload =>
5152
messageDecoder.fromBytes(frame.`type`, payload)
5253
}
5354
}

src/main/scala/io/iohk/ethereum/network/rlpx/RLPxConnectionHandler.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import org.bouncycastle.util.encoders.Hex
1919

2020
import io.iohk.ethereum.network.p2p.EthereumMessageDecoder
2121
import io.iohk.ethereum.network.p2p.Message
22+
import io.iohk.ethereum.network.p2p.MessageDecoder._
2223
import io.iohk.ethereum.network.p2p.MessageSerializable
2324
import io.iohk.ethereum.network.p2p.NetworkMessageDecoder
2425
import io.iohk.ethereum.network.p2p.messages.Capability
@@ -255,11 +256,11 @@ class RLPxConnectionHandler(
255256
messagesSoFar.foreach(processMessage)
256257
}
257258

258-
def processMessage(messageTry: Try[Message]): Unit = messageTry match {
259-
case Success(message) =>
259+
def processMessage(messageTry: Either[DecodingError, Message]): Unit = messageTry match {
260+
case Right(message) =>
260261
context.parent ! MessageReceived(message)
261262

262-
case Failure(ex) =>
263+
case Left(ex) =>
263264
log.info("Cannot decode message from {}, because of {}", peerId, ex.getMessage)
264265
// break connection in case of failed decoding, to avoid attack which would send us garbage
265266
context.stop(self)
@@ -395,7 +396,7 @@ object RLPxConnectionHandler {
395396
negotiated: Capability,
396397
p2pVersion: Long
397398
): MessageCodec = {
398-
val md = EthereumMessageDecoder.ethMessageDecoder(negotiated).orElse(NetworkMessageDecoder)
399+
val md = EthereumMessageDecoder.ethMessageDecoder(negotiated)
399400
new MessageCodec(frameCodec, md, p2pVersion)
400401
}
401402

@@ -450,8 +451,10 @@ object RLPxConnectionHandler {
450451
private def extractHello(frame: Frame): Option[Hello] = {
451452
val frameData = frame.payload.toArray
452453
if (frame.`type` == Hello.code) {
453-
val m = NetworkMessageDecoder.fromBytes(frame.`type`, frameData)
454-
Some(m.asInstanceOf[Hello])
454+
NetworkMessageDecoder.fromBytes(frame.`type`, frameData) match {
455+
case Left(err) => throw err // TODO: rethink throwing here
456+
case Right(msg) => Some(msg.asInstanceOf[Hello])
457+
}
455458
} else {
456459
None
457460
}

src/test/scala/io/iohk/ethereum/network/p2p/MessageCodecSpec.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {
2525

2626
// remote peer did not receive local status so it treats all remote messages as uncompressed
2727
assert(remoteReadNotCompressedStatus.size == 1)
28-
assert(remoteReadNotCompressedStatus.head.get == status)
28+
assert(remoteReadNotCompressedStatus.head == Right(status))
2929
}
3030

3131
it should "compress messages when remote side advertises p2p version larger or equal 5" in new TestSetup {
@@ -41,7 +41,7 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {
4141
// remote peer did not receive local status so it treats all remote messages as uncompressed,
4242
// but local peer compress messages after V5 Hello message
4343
assert(remoteReadNotCompressedStatus.size == 1)
44-
assert(remoteReadNotCompressedStatus.head.isFailure)
44+
assert(remoteReadNotCompressedStatus.head.isLeft)
4545
}
4646

4747
it should "compress messages when both sides advertises p2p version larger or equal 5" in new TestSetup {
@@ -56,7 +56,7 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {
5656

5757
// both peers exchanged v5 hellos, so they should send compressed messages
5858
assert(remoteReadNextMessageAfterHello.size == 1)
59-
assert(remoteReadNextMessageAfterHello.head.get == status)
59+
assert(remoteReadNextMessageAfterHello.head == Right(status))
6060
}
6161

6262
it should "compress and decompress first message after hello when receiving 2 frames" in new TestSetup {
@@ -72,8 +72,8 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {
7272

7373
// both peers exchanged v5 hellos, so they should send compressed messages
7474
assert(remoteReadBothMessages.size == 2)
75-
assert(remoteReadBothMessages.head.get == helloV5)
76-
assert(remoteReadBothMessages.last.get == status)
75+
assert(remoteReadBothMessages.head == Right(helloV5))
76+
assert(remoteReadBothMessages.last == Right(status))
7777
}
7878

7979
trait TestSetup extends SecureChannelSetup {

0 commit comments

Comments
 (0)