Skip to content

Commit 130d371

Browse files
[ETCM-355] Make fromBytes in MessageDecoder safe
Rename old implementation to fromBytesUnsafe
1 parent 5eb69c5 commit 130d371

File tree

8 files changed

+120
-106
lines changed

8 files changed

+120
-106
lines changed

src/main/scala/io/iohk/ethereum/network/p2p/Message.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package io.iohk.ethereum.network.p2p
22

33
import akka.util.ByteString
44

5-
import scala.util.Try
5+
import cats.implicits._
6+
7+
import io.iohk.ethereum.utils.Logger
68

79
trait Message {
810
def code: Int
@@ -19,11 +21,23 @@ trait MessageSerializable extends Message {
1921
def underlyingMsg: Message
2022
}
2123

22-
trait MessageDecoder { self =>
23-
def fromBytes(`type`: Int, payload: Array[Byte]): Message
24+
@FunctionalInterface
25+
trait MessageDecoder extends Logger { self =>
26+
27+
type DecodingError = Throwable // TODO: Replace Throwable with an ADT when feasible
28+
29+
def fromBytes(`type`: Int, payload: Array[Byte]): Either[DecodingError, Message]
30+
31+
def fromBytesUnsafe(`type`: Int, payload: Array[Byte]): Message = self.fromBytes(`type`, payload) match {
32+
case Left(err) => throw err
33+
case Right(res) => res
34+
}
2435

25-
def orElse(otherMessageDecoder: MessageDecoder): MessageDecoder = (`type`: Int, payload: Array[Byte]) =>
26-
Try {
27-
self.fromBytes(`type`, payload)
28-
}.getOrElse(otherMessageDecoder.fromBytes(`type`, payload))
36+
def orElse(otherMessageDecoder: MessageDecoder): MessageDecoder = new MessageDecoder {
37+
override def fromBytes(`type`: Int, payload: Array[Byte]): Either[DecodingError, Message] =
38+
self.fromBytes(`type`, payload).leftFlatMap { err =>
39+
log.debug(err.getLocalizedMessage())
40+
otherMessageDecoder.fromBytes(`type`, payload)
41+
}
42+
}
2943
}

src/main/scala/io/iohk/ethereum/network/p2p/MessageDecoders.scala

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@ import io.iohk.ethereum.network.p2p.messages.WireProtocol.Hello._
1818
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Ping._
1919
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Pong._
2020
import io.iohk.ethereum.network.p2p.messages.WireProtocol._
21+
import scala.util.Try
2122

2223
object NetworkMessageDecoder extends MessageDecoder {
2324

24-
override def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
25+
override def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
2526
msgCode match {
26-
case Disconnect.code => payload.toDisconnect
27-
case Ping.code => payload.toPing
28-
case Pong.code => payload.toPong
29-
case Hello.code => payload.toHello
30-
case _ => throw new RuntimeException(s"Unknown network message type: $msgCode")
27+
case Disconnect.code => Try(payload.toDisconnect).toEither
28+
case Ping.code => Try(payload.toPing).toEither
29+
case Pong.code => Try(payload.toPong).toEither
30+
case Hello.code => Try(payload.toHello).toEither
31+
case _ => Left(new RuntimeException(s"Unknown network message type: $msgCode"))
3132
}
3233

3334
}
@@ -36,68 +37,68 @@ object ETC64MessageDecoder extends MessageDecoder {
3637
import io.iohk.ethereum.network.p2p.messages.ETC64.Status._
3738
import io.iohk.ethereum.network.p2p.messages.ETC64.NewBlock._
3839

39-
def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
40+
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
4041
msgCode match {
41-
case Codes.StatusCode => payload.toStatus
42-
case Codes.NewBlockCode => payload.toNewBlock
43-
case Codes.GetNodeDataCode => payload.toGetNodeData
44-
case Codes.NodeDataCode => payload.toNodeData
45-
case Codes.GetReceiptsCode => payload.toGetReceipts
46-
case Codes.ReceiptsCode => payload.toReceipts
47-
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
48-
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
49-
case Codes.BlockHeadersCode => payload.toBlockHeaders
50-
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
51-
case Codes.BlockBodiesCode => payload.toBlockBodies
52-
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
53-
case Codes.SignedTransactionsCode => payload.toSignedTransactions
54-
case _ => throw new RuntimeException(s"Unknown etc/64 message type: $msgCode")
42+
case Codes.StatusCode => Try(payload.toStatus).toEither
43+
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
44+
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
45+
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
46+
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
47+
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
48+
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
49+
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
50+
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
51+
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
52+
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
53+
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
54+
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
55+
case _ => Left(new RuntimeException(s"Unknown etc/64 message type: $msgCode"))
5556
}
5657
}
5758

5859
object ETH64MessageDecoder extends MessageDecoder {
5960
import io.iohk.ethereum.network.p2p.messages.ETH64.Status._
6061
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.NewBlock._
6162

62-
def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
63+
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
6364
msgCode match {
64-
case Codes.GetNodeDataCode => payload.toGetNodeData
65-
case Codes.NodeDataCode => payload.toNodeData
66-
case Codes.GetReceiptsCode => payload.toGetReceipts
67-
case Codes.ReceiptsCode => payload.toReceipts
68-
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
69-
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
70-
case Codes.BlockHeadersCode => payload.toBlockHeaders
71-
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
72-
case Codes.BlockBodiesCode => payload.toBlockBodies
73-
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
74-
case Codes.StatusCode => payload.toStatus
75-
case Codes.NewBlockCode => payload.toNewBlock
76-
case Codes.SignedTransactionsCode => payload.toSignedTransactions
77-
case _ => throw new RuntimeException(s"Unknown eth/64 message type: $msgCode")
65+
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
66+
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
67+
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
68+
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
69+
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
70+
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
71+
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
72+
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
73+
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
74+
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
75+
case Codes.StatusCode => Try(payload.toStatus).toEither
76+
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
77+
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
78+
case _ => Left(new RuntimeException(s"Unknown eth/64 message type: $msgCode"))
7879
}
7980
}
8081

8182
object ETH63MessageDecoder extends MessageDecoder {
8283
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.Status._
8384
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.NewBlock._
8485

85-
def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
86+
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
8687
msgCode match {
87-
case Codes.GetNodeDataCode => payload.toGetNodeData
88-
case Codes.NodeDataCode => payload.toNodeData
89-
case Codes.GetReceiptsCode => payload.toGetReceipts
90-
case Codes.ReceiptsCode => payload.toReceipts
91-
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
92-
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
93-
case Codes.BlockHeadersCode => payload.toBlockHeaders
94-
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
95-
case Codes.BlockBodiesCode => payload.toBlockBodies
96-
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
97-
case Codes.StatusCode => payload.toStatus
98-
case Codes.NewBlockCode => payload.toNewBlock
99-
case Codes.SignedTransactionsCode => payload.toSignedTransactions
100-
case _ => throw new RuntimeException(s"Unknown eth/63 message type: $msgCode")
88+
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
89+
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
90+
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
91+
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
92+
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
93+
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
94+
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
95+
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
96+
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
97+
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
98+
case Codes.StatusCode => Try(payload.toStatus).toEither
99+
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
100+
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
101+
case _ => Left(new RuntimeException(s"Unknown eth/63 message type: $msgCode"))
101102
}
102103
}
103104

src/main/scala/io/iohk/ethereum/network/rlpx/MessageCodec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class MessageCodec(
4848
}
4949

5050
payloadTry.map { payload =>
51-
messageDecoder.fromBytes(frame.`type`, payload)
51+
messageDecoder.fromBytesUnsafe(frame.`type`, payload)
5252
}
5353
}
5454

src/main/scala/io/iohk/ethereum/network/rlpx/RLPxConnectionHandler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ object RLPxConnectionHandler {
450450
private def extractHello(frame: Frame): Option[Hello] = {
451451
val frameData = frame.payload.toArray
452452
if (frame.`type` == Hello.code) {
453-
val m = NetworkMessageDecoder.fromBytes(frame.`type`, frameData)
453+
val m = NetworkMessageDecoder.fromBytesUnsafe(frame.`type`, frameData)
454454
Some(m.asInstanceOf[Hello])
455455
} else {
456456
None

0 commit comments

Comments
 (0)