Skip to content

[ETCM-355] Refactor messages decoding #1046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 13, 2021
31 changes: 18 additions & 13 deletions src/main/scala/io/iohk/ethereum/network/p2p/Message.scala
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
package io.iohk.ethereum.network.p2p

import akka.util.ByteString
import cats.implicits._

import scala.util.Try
import io.iohk.ethereum.utils.Logger

trait Message {
def code: Int
def toShortString: String
}

trait MessageSerializable extends Message {

//DummyImplicit parameter only used to differentiate from the other toBytes method
def toBytes(implicit di: DummyImplicit): ByteString

def toBytes: Array[Byte]

def underlyingMsg: Message
}

trait MessageDecoder { self =>
def fromBytes(`type`: Int, payload: Array[Byte]): Message
@FunctionalInterface
trait MessageDecoder extends Logger { self =>
import MessageDecoder._

def fromBytes(`type`: Int, payload: Array[Byte]): Either[DecodingError, Message]

def orElse(otherMessageDecoder: MessageDecoder): MessageDecoder = new MessageDecoder {
override def fromBytes(`type`: Int, payload: Array[Byte]): Either[DecodingError, Message] =
self.fromBytes(`type`, payload).leftFlatMap { err =>
log.debug(err.getLocalizedMessage())
otherMessageDecoder.fromBytes(`type`, payload)
}
}
}

def orElse(otherMessageDecoder: MessageDecoder): MessageDecoder = (`type`: Int, payload: Array[Byte]) =>
Try {
self.fromBytes(`type`, payload)
}.getOrElse(otherMessageDecoder.fromBytes(`type`, payload))
object MessageDecoder {
type DecodingError = Throwable // TODO: Replace Throwable with an ADT when feasible
}
113 changes: 58 additions & 55 deletions src/main/scala/io/iohk/ethereum/network/p2p/MessageDecoders.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.iohk.ethereum.network.p2p

import scala.util.Try

import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.SignedTransactions._
import io.iohk.ethereum.network.p2p.messages.Capability
import io.iohk.ethereum.network.p2p.messages.Codes
Expand All @@ -19,15 +21,17 @@ import io.iohk.ethereum.network.p2p.messages.WireProtocol.Ping._
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Pong._
import io.iohk.ethereum.network.p2p.messages.WireProtocol._

import MessageDecoder._

object NetworkMessageDecoder extends MessageDecoder {

override def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
override def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
msgCode match {
case Disconnect.code => payload.toDisconnect
case Ping.code => payload.toPing
case Pong.code => payload.toPong
case Hello.code => payload.toHello
case _ => throw new RuntimeException(s"Unknown network message type: $msgCode")
case Disconnect.code => Try(payload.toDisconnect).toEither
case Ping.code => Try(payload.toPing).toEither
case Pong.code => Try(payload.toPong).toEither
case Hello.code => Try(payload.toHello).toEither
case _ => Left(new RuntimeException(s"Unknown network message type: $msgCode"))
}

}
Expand All @@ -36,79 +40,78 @@ object ETC64MessageDecoder extends MessageDecoder {
import io.iohk.ethereum.network.p2p.messages.ETC64.Status._
import io.iohk.ethereum.network.p2p.messages.ETC64.NewBlock._

def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
msgCode match {
case Codes.StatusCode => payload.toStatus
case Codes.NewBlockCode => payload.toNewBlock
case Codes.GetNodeDataCode => payload.toGetNodeData
case Codes.NodeDataCode => payload.toNodeData
case Codes.GetReceiptsCode => payload.toGetReceipts
case Codes.ReceiptsCode => payload.toReceipts
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
case Codes.BlockHeadersCode => payload.toBlockHeaders
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
case Codes.BlockBodiesCode => payload.toBlockBodies
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
case Codes.SignedTransactionsCode => payload.toSignedTransactions
case _ => throw new RuntimeException(s"Unknown etc/64 message type: $msgCode")
case Codes.StatusCode => Try(payload.toStatus).toEither
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
case _ => Left(new RuntimeException(s"Unknown etc/64 message type: $msgCode"))
}
}

object ETH64MessageDecoder extends MessageDecoder {
import io.iohk.ethereum.network.p2p.messages.ETH64.Status._
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.NewBlock._

def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
msgCode match {
case Codes.GetNodeDataCode => payload.toGetNodeData
case Codes.NodeDataCode => payload.toNodeData
case Codes.GetReceiptsCode => payload.toGetReceipts
case Codes.ReceiptsCode => payload.toReceipts
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
case Codes.BlockHeadersCode => payload.toBlockHeaders
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
case Codes.BlockBodiesCode => payload.toBlockBodies
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
case Codes.StatusCode => payload.toStatus
case Codes.NewBlockCode => payload.toNewBlock
case Codes.SignedTransactionsCode => payload.toSignedTransactions
case _ => throw new RuntimeException(s"Unknown eth/64 message type: $msgCode")
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
case Codes.StatusCode => Try(payload.toStatus).toEither
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
case _ => Left(new RuntimeException(s"Unknown eth/64 message type: $msgCode"))
}
}

object ETH63MessageDecoder extends MessageDecoder {
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.Status._
import io.iohk.ethereum.network.p2p.messages.BaseETH6XMessages.NewBlock._

def fromBytes(msgCode: Int, payload: Array[Byte]): Message =
def fromBytes(msgCode: Int, payload: Array[Byte]): Either[DecodingError, Message] =
msgCode match {
case Codes.GetNodeDataCode => payload.toGetNodeData
case Codes.NodeDataCode => payload.toNodeData
case Codes.GetReceiptsCode => payload.toGetReceipts
case Codes.ReceiptsCode => payload.toReceipts
case Codes.NewBlockHashesCode => payload.toNewBlockHashes
case Codes.GetBlockHeadersCode => payload.toGetBlockHeaders
case Codes.BlockHeadersCode => payload.toBlockHeaders
case Codes.GetBlockBodiesCode => payload.toGetBlockBodies
case Codes.BlockBodiesCode => payload.toBlockBodies
case Codes.BlockHashesFromNumberCode => payload.toBlockHashesFromNumber
case Codes.StatusCode => payload.toStatus
case Codes.NewBlockCode => payload.toNewBlock
case Codes.SignedTransactionsCode => payload.toSignedTransactions
case _ => throw new RuntimeException(s"Unknown eth/63 message type: $msgCode")
case Codes.GetNodeDataCode => Try(payload.toGetNodeData).toEither
case Codes.NodeDataCode => Try(payload.toNodeData).toEither
case Codes.GetReceiptsCode => Try(payload.toGetReceipts).toEither
case Codes.ReceiptsCode => Try(payload.toReceipts).toEither
case Codes.NewBlockHashesCode => Try(payload.toNewBlockHashes).toEither
case Codes.GetBlockHeadersCode => Try(payload.toGetBlockHeaders).toEither
case Codes.BlockHeadersCode => Try(payload.toBlockHeaders).toEither
case Codes.GetBlockBodiesCode => Try(payload.toGetBlockBodies).toEither
case Codes.BlockBodiesCode => Try(payload.toBlockBodies).toEither
case Codes.BlockHashesFromNumberCode => Try(payload.toBlockHashesFromNumber).toEither
case Codes.StatusCode => Try(payload.toStatus).toEither
case Codes.NewBlockCode => Try(payload.toNewBlock).toEither
case Codes.SignedTransactionsCode => Try(payload.toSignedTransactions).toEither
case _ => Left(new RuntimeException(s"Unknown eth/63 message type: $msgCode"))
}
}

// scalastyle:off
object EthereumMessageDecoder {
type Decoder = (Int, Array[Byte]) => Message
def ethMessageDecoder(protocolVersion: Capability): MessageDecoder =
protocolVersion match {
case Capability.Capabilities.Etc64Capability => ETC64MessageDecoder.fromBytes
case Capability.Capabilities.Eth63Capability => ETH63MessageDecoder.fromBytes
case Capability.Capabilities.Eth64Capability => ETH64MessageDecoder.fromBytes
case Capability.Capabilities.Etc64Capability => ETC64MessageDecoder.orElse(NetworkMessageDecoder)
case Capability.Capabilities.Eth63Capability => ETH63MessageDecoder.orElse(NetworkMessageDecoder)
case Capability.Capabilities.Eth64Capability => ETH64MessageDecoder.orElse(NetworkMessageDecoder)
case _ => throw new RuntimeException(s"Unsupported Protocol Version $protocolVersion")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.xerial.snappy.Snappy
import io.iohk.ethereum.network.handshaker.EtcHelloExchangeState
import io.iohk.ethereum.network.p2p.Message
import io.iohk.ethereum.network.p2p.MessageDecoder
import io.iohk.ethereum.network.p2p.MessageDecoder.DecodingError
import io.iohk.ethereum.network.p2p.MessageSerializable
import io.iohk.ethereum.network.p2p.messages.WireProtocol.Hello

Expand All @@ -32,12 +33,12 @@ class MessageCodec(
val contextIdCounter = new AtomicInteger

// TODO: ETCM-402 - messageDecoder should use negotiated protocol version
def readMessages(data: ByteString): Seq[Try[Message]] = {
def readMessages(data: ByteString): Seq[Either[DecodingError, Message]] = {
val frames = frameCodec.readFrames(data)
readFrames(frames)
}

def readFrames(frames: Seq[Frame]): Seq[Try[Message]] =
def readFrames(frames: Seq[Frame]): Seq[Either[DecodingError, Message]] =
frames.map { frame =>
val frameData = frame.payload.toArray
val payloadTry =
Expand All @@ -47,7 +48,7 @@ class MessageCodec(
Success(frameData)
}

payloadTry.map { payload =>
payloadTry.toEither.flatMap { payload =>
messageDecoder.fromBytes(frame.`type`, payload)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.bouncycastle.util.encoders.Hex

import io.iohk.ethereum.network.p2p.EthereumMessageDecoder
import io.iohk.ethereum.network.p2p.Message
import io.iohk.ethereum.network.p2p.MessageDecoder._
import io.iohk.ethereum.network.p2p.MessageSerializable
import io.iohk.ethereum.network.p2p.NetworkMessageDecoder
import io.iohk.ethereum.network.p2p.messages.Capability
Expand Down Expand Up @@ -255,11 +256,11 @@ class RLPxConnectionHandler(
messagesSoFar.foreach(processMessage)
}

def processMessage(messageTry: Try[Message]): Unit = messageTry match {
case Success(message) =>
def processMessage(messageTry: Either[DecodingError, Message]): Unit = messageTry match {
case Right(message) =>
context.parent ! MessageReceived(message)

case Failure(ex) =>
case Left(ex) =>
log.info("Cannot decode message from {}, because of {}", peerId, ex.getMessage)
// break connection in case of failed decoding, to avoid attack which would send us garbage
context.stop(self)
Expand Down Expand Up @@ -395,7 +396,7 @@ object RLPxConnectionHandler {
negotiated: Capability,
p2pVersion: Long
): MessageCodec = {
val md = EthereumMessageDecoder.ethMessageDecoder(negotiated).orElse(NetworkMessageDecoder)
val md = EthereumMessageDecoder.ethMessageDecoder(negotiated)
new MessageCodec(frameCodec, md, p2pVersion)
}

Expand Down Expand Up @@ -450,8 +451,10 @@ object RLPxConnectionHandler {
private def extractHello(frame: Frame): Option[Hello] = {
val frameData = frame.payload.toArray
if (frame.`type` == Hello.code) {
val m = NetworkMessageDecoder.fromBytes(frame.`type`, frameData)
Some(m.asInstanceOf[Hello])
NetworkMessageDecoder.fromBytes(frame.`type`, frameData) match {
case Left(err) => throw err // TODO: rethink throwing here
case Right(msg) => Some(msg.asInstanceOf[Hello])
}
} else {
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {

// remote peer did not receive local status so it treats all remote messages as uncompressed
assert(remoteReadNotCompressedStatus.size == 1)
assert(remoteReadNotCompressedStatus.head.get == status)
assert(remoteReadNotCompressedStatus.head == Right(status))
}

it should "compress messages when remote side advertises p2p version larger or equal 5" in new TestSetup {
Expand All @@ -41,7 +41,7 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {
// remote peer did not receive local status so it treats all remote messages as uncompressed,
// but local peer compress messages after V5 Hello message
assert(remoteReadNotCompressedStatus.size == 1)
assert(remoteReadNotCompressedStatus.head.isFailure)
assert(remoteReadNotCompressedStatus.head.isLeft)
}

it should "compress messages when both sides advertises p2p version larger or equal 5" in new TestSetup {
Expand All @@ -56,7 +56,7 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {

// both peers exchanged v5 hellos, so they should send compressed messages
assert(remoteReadNextMessageAfterHello.size == 1)
assert(remoteReadNextMessageAfterHello.head.get == status)
assert(remoteReadNextMessageAfterHello.head == Right(status))
}

it should "compress and decompress first message after hello when receiving 2 frames" in new TestSetup {
Expand All @@ -72,8 +72,8 @@ class MessageCodecSpec extends AnyFlatSpec with Matchers {

// both peers exchanged v5 hellos, so they should send compressed messages
assert(remoteReadBothMessages.size == 2)
assert(remoteReadBothMessages.head.get == helloV5)
assert(remoteReadBothMessages.last.get == status)
assert(remoteReadBothMessages.head == Right(helloV5))
assert(remoteReadBothMessages.last == Right(status))
}

trait TestSetup extends SecureChannelSetup {
Expand Down
Loading