Skip to content

Commit 6b6c4d8

Browse files
committed
[ETCM-533] return root nodeas part of the proof, add tests
1 parent a7830e1 commit 6b6c4d8

File tree

3 files changed

+45
-5
lines changed

3 files changed

+45
-5
lines changed

src/main/scala/io/iohk/ethereum/mpt/MerklePatriciaTrie.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class MerklePatriciaTrie[K, V] private (private[mpt] val rootNode: Option[MptNod
105105
def getProof(key: K): Option[Vector[MptNode]] = {
106106
pathTraverse[Vector[MptNode]](Vector.empty, mkKeyNibbles(key)) { case (acc, node) =>
107107
node match {
108-
case Some(nextNodeOnExt @ (_: BranchNode | _: ExtensionNode | _: LeafNode)) => acc :+ nextNodeOnExt
108+
case Some(nextNodeOnExt @ (_: BranchNode | _: ExtensionNode | _: LeafNode | _: HashNode)) => acc :+ nextNodeOnExt
109109
case _ => acc
110110
}
111111
}
@@ -155,7 +155,7 @@ class MerklePatriciaTrie[K, V] private (private[mpt] val rootNode: Option[MptNod
155155

156156
rootNode match {
157157
case Some(root) =>
158-
pathTraverse(acc, root, searchKey, op)
158+
pathTraverse(op(acc, Some(root)), root, searchKey, op)
159159
case None =>
160160
None
161161
}

src/test/scala/io/iohk/ethereum/ObjectGenerators.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ trait ObjectGenerators {
5050
} yield (aByteList.toArray, t)
5151
}
5252

53-
def keyValueListGen(): Gen[List[(Int, Int)]] = {
53+
def keyValueListGen(minValue: Int = Int.MinValue, maxValue: Int = Int.MaxValue): Gen[List[(Int, Int)]] = {
5454
for {
55-
aKeyList <- Gen.nonEmptyListOf(Arbitrary.arbitrary[Int]).map(_.distinct)
55+
values <- Gen.chooseNum(minValue, maxValue)
56+
aKeyList <- Gen.nonEmptyListOf(values).map(_.distinct)
5657
} yield aKeyList.zip(aKeyList)
5758
}
5859

src/test/scala/io/iohk/ethereum/mpt/MerklePatriciaTrieSuite.scala

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,40 @@ class MerklePatriciaTrieSuite extends AnyFunSuite with ScalaCheckPropertyChecks
554554
assert(proof.isEmpty)
555555
}
556556

557-
test("getProof returns proof result for non-existing address") {
557+
test("PatriciaTrie can get proof(at least the root node) for all inserted key-value pairs") {
558+
forAll(keyValueListGen()) { keyValueList: Seq[(Int, Int)] =>
559+
val trie = addEveryKeyValuePair(keyValueList)
560+
assertCanGetProofForEveryKeyValue(trie, keyValueList)
561+
}
562+
}
563+
564+
test("PatriciaTrie return root as proof when no common nibbles are found between MPT root hash and search key") {
565+
forAll(keyValueListGen(1, 10)) { keyValueList: Seq[(Int, Int)] =>
566+
val trie = addEveryKeyValuePair(keyValueList)
567+
val wrongKey = 22
568+
val proof = trie.getProof(wrongKey)
569+
assert(proof.getOrElse(Vector.empty).toList match {
570+
case _ @ HashNode(_) :: Nil => true
571+
case _ => false
572+
})
573+
}
574+
}
575+
576+
test("PatriciaTrie return proof when having all nibbles in common except the last one between MPT root hash and search key") {
577+
578+
val key = 1111
579+
val wrongKey = 1112
580+
val emptyTrie = MerklePatriciaTrie[Int, Int](emptyEphemNodeStorage)
581+
.put(key, 1)
582+
.put(wrongKey, 2)
583+
val proof = emptyTrie.getProof(key = wrongKey)
584+
assert(proof.getOrElse(Vector.empty).toList match {
585+
case _ @ HashNode(_) :: tail => tail.nonEmpty
586+
case _ => false
587+
})
588+
}
589+
590+
test("getProof returns proof result for non-existing key") {
558591
// given
559592
val EmptyTrie = MerklePatriciaTrie[Array[Byte], Array[Byte]](emptyEphemNodeStorage)
560593
val key1: Array[Byte] = Hex.decode("10000001")
@@ -611,6 +644,12 @@ class MerklePatriciaTrieSuite extends AnyFunSuite with ScalaCheckPropertyChecks
611644
assert(obtained.get == value)
612645
}
613646

647+
private def assertCanGetProofForEveryKeyValue[K, V](trie: MerklePatriciaTrie[K, V], kvs: Seq[(K, V)]): Unit =
648+
kvs.foreach { case (key, _) =>
649+
val obtained = trie.getProof(key)
650+
assert(obtained.getOrElse(Vector.empty).nonEmpty)
651+
}
652+
614653
private def assertCanGetEveryKeyValues[K, V](trie: MerklePatriciaTrie[K, Array[V]], kvs: List[(K, Array[V])]): Unit =
615654
kvs.foreach { case (key, value) =>
616655
val obtained = trie.get(key)

0 commit comments

Comments
 (0)