Skip to content

Commit 8aebfca

Browse files
committed
ETCM-167: Generic derivation of decoder.
1 parent b5f03cd commit 8aebfca

File tree

3 files changed

+134
-10
lines changed

3 files changed

+134
-10
lines changed

src/main/scala/io/iohk/ethereum/network/discovery/codecs/RLPCodecs.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package io.iohk.ethereum.network.discovery.codecs
22

33
import io.iohk.scalanet.discovery.ethereum.Node
44
import io.iohk.scalanet.discovery.ethereum.v4.Payload
5-
import io.iohk.ethereum.rlp.{RLPList, RLPEncodeable, RLPCodec, RLPEncoder}
5+
import io.iohk.ethereum.rlp.{RLPList, RLPEncodeable, RLPCodec, RLPEncoder, RLPDecoder}
66
import io.iohk.ethereum.rlp.RLPImplicits._
77
import io.iohk.ethereum.rlp.RLPImplicitConversions._
88
import io.iohk.ethereum.rlp.RLPImplicitDerivations._
@@ -32,5 +32,8 @@ object RLPCodecs {
3232
implicit val pingRLPEncoder: RLPEncoder[Payload.Ping] =
3333
deriveLabelledGenericRLPListEncoder
3434

35+
implicit val pingRLPDecoder: RLPDecoder[Payload.Ping] =
36+
deriveLabelledGenericRLPListDecoder
37+
3538
implicit def payloadCodec: Codec[Payload] = ???
3639
}

src/main/scala/io/iohk/ethereum/rlp/RLPImplicitDerivations.scala

Lines changed: 119 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package io.iohk.ethereum.rlp
22

3-
import shapeless.{HList, HNil, Lazy, ::, LabelledGeneric, <:!<}
4-
import shapeless.labelled.FieldType
3+
import shapeless.{HList, HNil, Lazy, ::, LabelledGeneric, <:!<, Witness}
4+
import shapeless.labelled.{FieldType, field}
5+
import scala.util.control.NonFatal
56

67
/** Automatically derive RLP codecs for case classes. */
78
object RLPImplicitDerivations {
@@ -33,6 +34,23 @@ object RLPImplicitDerivations {
3334
}
3435
}
3536

37+
/** Specialized decoder for case classes that only accepts RLPList for input. */
38+
trait RLPListDecoder[T] extends RLPDecoder[T] {
39+
def decodeList(items: List[RLPEncodeable]): (T, List[FieldInfo])
40+
41+
override def decode(rlp: RLPEncodeable): T =
42+
rlp match {
43+
case list: RLPList => decodeList(list.items.toList)._1
44+
case _ => throw new RuntimeException("Expected to decode an RLPList.")
45+
}
46+
}
47+
object RLPListDecoder {
48+
def apply[T](f: List[RLPEncodeable] => (T, List[FieldInfo])): RLPListDecoder[T] =
49+
new RLPListDecoder[T] {
50+
override def decodeList(items: List[RLPEncodeable]) = f(items)
51+
}
52+
}
53+
3654
/** Encoder for the empty list of fields. */
3755
implicit val deriveHNilRLPListEncoder: RLPListEncoder[HNil] =
3856
RLPListEncoder(_ => RLPList() -> Nil)
@@ -75,28 +93,122 @@ object RLPImplicitDerivations {
7593
}
7694
}
7795

78-
/** Deriving RLP encoding for a HList of fields where the current field is non-optional. */
96+
/** Encoder for a HList of fields where the current field is non-optional. */
7997
implicit def deriveNonOptionHListRLPListEncoder[K, H, T <: HList](implicit
8098
hEncoder: Lazy[RLPEncoder[H]],
8199
tEncoder: Lazy[RLPListEncoder[T]],
82100
ev: H <:!< Option[_]
83101
): RLPListEncoder[FieldType[K, H] :: T] = {
84102
val hInfo = FieldInfo(isOptional = false)
103+
85104
RLPListEncoder { case head :: tail =>
86105
val hRLP = hEncoder.value.encode(head)
87106
val (tRLP, tInfos) = tEncoder.value.encodeList(tail)
88107
(hRLP :: tRLP, hInfo :: tInfos)
89108
}
90109
}
91110

92-
/** Derive an encoder for a case class based on its labelled generic record representation. */
111+
/** Encoder for a case class based on its labelled generic record representation. */
93112
implicit def deriveLabelledGenericRLPListEncoder[T, Rec](implicit
94113
// Auto-derived by Shapeless.
95114
generic: LabelledGeneric.Aux[T, Rec],
96115
// Derived by `deriveOptionHListRLPListEncoder` and `deriveNonOptionHListRLPListEncoder`.
97-
recEncoder: Lazy[RLPListEncoder[Rec]]
98-
): RLPListEncoder[T] = RLPListEncoder { value =>
99-
recEncoder.value.encodeList(generic.to(value))
116+
recEncoder: Lazy[RLPEncoder[Rec]]
117+
): RLPEncoder[T] = RLPEncoder { value =>
118+
recEncoder.value.encode(generic.to(value))
100119
}
101120

121+
/** Decoder for the empty list of fields.
122+
*
123+
* We can ignore extra items in the RLPList as optional fields we don't handle,
124+
* or extra random data, which we have for example in EIP8 test vectors.
125+
*/
126+
implicit val deriveHNilRLPListDecoder: RLPListDecoder[HNil] =
127+
RLPListDecoder(_ => HNil -> Nil)
128+
129+
/** Decoder for a list of fields in the generic represenation of a case class.
130+
*
131+
* This variant deals with trailing optional fields, which may be omitted from
132+
* the end of RLP lists.
133+
*/
134+
implicit def deriveOptionHListRLPListDecoder[K <: Symbol, H, V, T <: HList](implicit
135+
hDecoder: Lazy[RLPDecoder[H]],
136+
tDecoder: Lazy[RLPListDecoder[T]],
137+
// The witness provides access to the Symbols which LabelledGeneric uses
138+
// to tag the fields with their names, so we can use it to provide better
139+
// contextual error messages.
140+
witness: Witness.Aux[K],
141+
ev: Option[V] =:= H,
142+
policy: DerivationPolicy
143+
): RLPListDecoder[FieldType[K, H] :: T] = {
144+
val fieldName: String = witness.value.name
145+
val hInfo = FieldInfo(isOptional = true)
146+
147+
RLPListDecoder {
148+
case Nil if policy.omitTrailingOptionals =>
149+
val (tail, tInfos) = tDecoder.value.decodeList(Nil)
150+
val value: H = None
151+
val head: FieldType[K, H] = field[K](value)
152+
(head :: tail) -> (hInfo :: tInfos)
153+
154+
case Nil =>
155+
throw new RuntimeException(s"Cannot decode optional '${fieldName}': the RLPList is empty.")
156+
157+
case rlps =>
158+
val (tail, tInfos) = tDecoder.value.decodeList(rlps.tail)
159+
val value: H =
160+
try {
161+
if (policy.omitTrailingOptionals && tInfos.forall(_.isOptional)) {
162+
// Treat it as a value. We have a decoder for optional fields, so we have to wrap it.
163+
hDecoder.value.decode(RLPList(rlps.head))
164+
} else {
165+
hDecoder.value.decode(rlps.head)
166+
}
167+
} catch {
168+
case NonFatal(ex) =>
169+
throw new RuntimeException(s"Cannot decode optional '$fieldName' from RLP value: $ex")
170+
}
171+
172+
val head: FieldType[K, H] = field[K](value)
173+
(head :: tail) -> (hInfo :: tInfos)
174+
}
175+
}
176+
177+
/** Decoder for a non-optional field. */
178+
implicit def deriveNonOptionHListRLPListDecoder[K <: Symbol, H, T <: HList](implicit
179+
hDecoder: Lazy[RLPDecoder[H]],
180+
tDecoder: Lazy[RLPListDecoder[T]],
181+
witness: Witness.Aux[K],
182+
ev: H <:!< Option[_]
183+
): RLPListDecoder[FieldType[K, H] :: T] = {
184+
val fieldName: String = witness.value.name
185+
val hInfo = FieldInfo(isOptional = false)
186+
187+
RLPListDecoder {
188+
case Nil =>
189+
throw new RuntimeException(s"Cannot decode '${fieldName}': the RLPList is empty.")
190+
191+
case rlps =>
192+
val value: H =
193+
try {
194+
hDecoder.value.decode(rlps.head)
195+
} catch {
196+
case NonFatal(ex) =>
197+
throw new RuntimeException(s"Cannot decode '$fieldName' from RLP value: $ex")
198+
}
199+
val head: FieldType[K, H] = field[K](value)
200+
val (tail, tInfos) = tDecoder.value.decodeList(rlps.tail)
201+
(head :: tail) -> (hInfo :: tInfos)
202+
}
203+
}
204+
205+
/** Decoder for a case class based on its labelled generic record representation. */
206+
implicit def deriveLabelledGenericRLPListDecoder[T, Rec](implicit
207+
// Auto-derived by Shapeless.
208+
generic: LabelledGeneric.Aux[T, Rec],
209+
// Derived by `deriveOptionHListRLPListDecoder` and `deriveNonOptionHListRLPListDecoder`.
210+
recDecoder: Lazy[RLPDecoder[Rec]]
211+
): RLPDecoder[T] = RLPDecoder { rlp =>
212+
generic.from(recDecoder.value.decode(rlp))
213+
}
102214
}

src/test/scala/io/iohk/ethereum/network/discovery/codecs/RLPCodecsSpec.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import scodec.bits.BitVector
1010
import scodec.Codec
1111
import java.net.InetAddress
1212
import io.iohk.ethereum.rlp.RLPValue
13+
import io.iohk.ethereum.rlp.RLPDecoder
1314

1415
class RLPCodecsSpec extends AnyFlatSpec {
1516

@@ -127,13 +128,17 @@ class RLPCodecsSpec extends AnyFlatSpec {
127128
enrSeq = Some(1)
128129
)
129130

130-
RLPEncoder.encode(ping) match {
131+
val rlp = RLPEncoder.encode(ping)
132+
133+
rlp match {
131134
case list: RLPList =>
132135
list.items should have size 5
133136
list.items.last shouldBe an[RLPValue]
134137
case other =>
135138
fail(s"Expected RLPList; got $other")
136139
}
140+
141+
RLPDecoder.decode[Payload.Ping](rlp) shouldBe ping
137142
}
138143

139144
it should "encode a Ping without an ENR as 4 items" in {
@@ -145,11 +150,15 @@ class RLPCodecsSpec extends AnyFlatSpec {
145150
enrSeq = None
146151
)
147152

148-
RLPEncoder.encode(ping) match {
153+
val rlp = RLPEncoder.encode(ping)
154+
155+
rlp match {
149156
case list: RLPList =>
150157
list.items should have size 4
151158
case other =>
152159
fail(s"Expected RLPList; got $other")
153160
}
161+
162+
RLPDecoder.decode[Payload.Ping](rlp) shouldBe ping
154163
}
155164
}

0 commit comments

Comments
 (0)