Skip to content

ETCM-199: Remove ConnectableSubject #106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ trait ScalanetModule extends ScalaModule {
trait ScalanetPublishModule extends PublishModule {
def description: String

override def publishVersion = "0.4.4-SNAPSHOT"
override def publishVersion = "0.5.0-SNAPSHOT"

override def pomSettings = PomSettings(
description = description,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.typesafe.scalalogging.LazyLogging
import io.iohk.scalanet.discovery.crypto.{PrivateKey, PublicKey, SigAlg}
import io.iohk.scalanet.discovery.ethereum.Node
import io.iohk.scalanet.discovery.hash.Hash
import io.iohk.scalanet.peergroup.implicits.NextOps
import io.iohk.scalanet.peergroup.{Addressable, Channel, PeerGroup}
import io.iohk.scalanet.peergroup.PeerGroup.ServerEvent.ChannelCreated
import io.iohk.scalanet.peergroup.Channel.{MessageReceived, DecodingError, UnexpectedError}
Expand Down Expand Up @@ -83,8 +84,7 @@ object DiscoveryNetwork {
override def startHandling(handler: DiscoveryRPC[Peer[A]]): Task[CancelableF[Task]] =
for {
cancelToken <- Deferred[Task, Unit]
_ <- peerGroup
.nextServerEvent()
_ <- peerGroup.nextServerEvent
.withCancelToken(cancelToken)
.toIterant
.mapEval {
Expand All @@ -111,8 +111,7 @@ object DiscoveryNetwork {
channel: Channel[A, Packet],
cancelToken: Deferred[Task, Unit]
): Task[Unit] = {
channel
.nextMessage()
channel.nextChannelEvent
.withCancelToken(cancelToken)
.timeout(config.messageExpiration) // Messages older than this would be ignored anyway.
.toIterant
Expand Down Expand Up @@ -315,8 +314,7 @@ object DiscoveryNetwork {
// The absolute end we are willing to wait for the correct message to arrive.
deadline: Deadline
)(pf: PartialFunction[Payload.Response, T]): Iterant[Task, T] =
channel
.nextMessage()
channel.nextChannelEvent
.timeoutL(Task(config.requestTimeout.min(deadline.timeLeft)))
.toIterant
.collect {
Expand Down Expand Up @@ -394,18 +392,6 @@ object DiscoveryNetwork {
}
}

// Functions to be applied on the `.nextMessage()` or `.nextServerEvent()` results.
private implicit class NextOps[A](next: Task[Option[A]]) {
  /** Turn the "pull the next item" task into a stream.
    * The stream ends as soon as `next` produces `None`.
    */
  def toIterant: Iterant[Task, A] =
    // `collect` instead of `.map(_.get)`: avoids the partial `Option.get`
    // even though `takeWhile(_.isDefined)` already guarantees definedness.
    Iterant.repeatEvalF(next).takeWhile(_.isDefined).collect {
      case Some(a) => a
    }

  /** Race the next item against the cancel token: once the token is
    * completed, yield `None` (end-of-stream) instead of the next item.
    */
  def withCancelToken(token: Deferred[Task, Unit]): Task[Option[A]] =
    Task.race(token.get, next).map {
      case Left(()) => None
      case Right(x) => x
    }
}

/** Estimate how many neighbors we can fit in the maximum protocol message size. */
def getMaxNeighborsPerPacket(implicit codec: Codec[Payload], sigalg: SigAlg): Int = {
val sampleNode = Node(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ object DiscoveryService {
cancelToken <- network.startHandling(service)
// Contact the bootstrap nodes.
// Setting the enrolled status here because we could potentially repeat enrollment until it succeeds.
enroll = service.enroll().guarantee(stateRef.update(_.setEnrolled))
enroll = service.enroll.guarantee(stateRef.update(_.setEnrolled))
// Periodically discover new nodes.
discover = service.lookupRandom.delayExecution(config.discoveryPeriod).loopForever
// Enrollment can be run in the background if it takes very long.
Expand Down Expand Up @@ -871,7 +871,7 @@ object DiscoveryService {
* or `false` if none of them responded with a correct ENR,
* which would mean we don't have anyone to do lookups with.
*/
protected[v4] def enroll(): Task[Boolean] =
protected[v4] def enroll: Task[Boolean] =
if (config.knownPeers.isEmpty)
Task.pure(false)
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import io.iohk.scalanet.kademlia.KMessage.{KRequest, KResponse}
import io.iohk.scalanet.kademlia.KMessage.KRequest.{FindNodes, Ping}
import io.iohk.scalanet.kademlia.KMessage.KResponse.{Nodes, Pong}
import io.iohk.scalanet.kademlia.KRouter.NodeRecord
import io.iohk.scalanet.peergroup.implicits._
import io.iohk.scalanet.peergroup.{Channel, PeerGroup}
import io.iohk.scalanet.peergroup.Channel.MessageReceived
import monix.eval.Task
Expand Down Expand Up @@ -51,7 +52,7 @@ object KNetwork {
) extends KNetwork[A] {

override lazy val kRequests: Observable[(KRequest[A], Option[KResponse[A]] => Task[Unit])] = {
peerGroup.server.refCount.collectChannelCreated
peerGroup.serverEventObservable.collectChannelCreated
.mergeMap {
case (channel: Channel[A, KMessage[A]], release: Task[Unit]) =>
// NOTE: We cannot use mapEval with a Task here, because that would hold up
Expand All @@ -60,7 +61,7 @@ object KNetwork {
// discards, `headL` would eventually time out but while we wait for
// that the next incoming channel would not be picked up.
Observable.fromTask {
channel.in.refCount
channel.channelEventObservable
.collect { case MessageReceived(req: KRequest[A]) => req }
.headL
.timeout(requestTimeout)
Expand Down Expand Up @@ -112,7 +113,8 @@ object KNetwork {
): Task[Response] = {
for {
_ <- clientChannel.sendMessage(message).timeout(requestTimeout)
response <- clientChannel.in.refCount
// This assumes that `requestTemplate` always opens a new channel.
response <- clientChannel.channelEventObservable
.collect {
case MessageReceived(m) if pf.isDefinedAt(m) => pf(m)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
)

override val test = for {
enrolled <- service.enroll()
enrolled <- service.enroll
state <- stateRef.get
} yield {
enrolled shouldBe true
Expand All @@ -840,7 +840,7 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
)

override val test = for {
enrolled <- service.enroll()
enrolled <- service.enroll
state <- stateRef.get
} yield {
enrolled shouldBe false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class MockPeerGroup[A, M](
private val serverEvents = ConcurrentQueue[Task].unsafe[ServerEvent[A, M]](BufferCapacity.Unbounded())

// Intended for the System Under Test to read incoming channels.
override def nextServerEvent(): Task[Option[PeerGroup.ServerEvent[A, M]]] =
override def nextServerEvent: Task[Option[PeerGroup.ServerEvent[A, M]]] =
  // Blocks until a test enqueues an event; always wraps the result in Some,
  // so from the System Under Test's perspective this mock server never
  // signals end-of-stream (None).
  serverEvents.poll.map(Some(_))

// Intended for the System Under Test to open outgoing channels.
Expand Down Expand Up @@ -69,7 +69,7 @@ class MockChannel[A, M](
messagesFromSUT.offer(MessageReceived(message))

// Messages consumed by the System Under Test.
override def nextMessage(): Task[Option[Channel.ChannelEvent[M]]] =
override def nextChannelEvent: Task[Option[Channel.ChannelEvent[M]]] =
  // Blocks until a test enqueues a message; always wraps the result in Some,
  // so from the System Under Test's perspective this mock channel is never
  // closed (it never yields None).
  messagesToSUT.poll.map(Some(_))

// Send a message from the test.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ object KNetworkRequestProcessing {
Some((p, h))
case (_, h) =>
ignore(h)
None
}
.collect { case Some(v) => v }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import monix.execution.Scheduler.Implicits.global
import org.scalatest.concurrent.ScalaFutures._
import io.iohk.scalanet.TaskValues._
import KNetworkSpec._
import io.iohk.scalanet.monix_subject.ConnectableSubject
import io.iohk.scalanet.peergroup.Channel.MessageReceived
import io.iohk.scalanet.peergroup.PeerGroup.ServerEvent.ChannelCreated
import io.iohk.scalanet.kademlia.KMessage.KRequest
import org.scalatest.prop.TableDrivenPropertyChecks._
import io.iohk.scalanet.peergroup.PeerGroup.ServerEvent
import io.iohk.scalanet.peergroup.Channel.ChannelEvent
import java.util.concurrent.atomic.AtomicInteger

class KNetworkSpec extends FlatSpec {
import KNetworkRequestProcessing._
Expand Down Expand Up @@ -65,9 +67,8 @@ class KNetworkSpec extends FlatSpec {

forAll(rpcs) { (label, request, response, requestExtractor, clientRpc) =>
s"Server $label" should "not close server channels while yielding requests (it is the responsibility of the response handler)" in new Fixture {
when(peerGroup.server)
.thenReturn(ConnectableSubject(Observable.eval(channelCreated)))
when(channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(request))))
mockServerEvents(peerGroup, channelCreated)
mockChannelEvents(channel, MessageReceived(request))

val actualRequest = requestExtractor(network).evaluated

Expand All @@ -76,20 +77,21 @@ class KNetworkSpec extends FlatSpec {
}

s"Server $label" should "close server channels when a request does not arrive before a timeout" in new Fixture {
when(peerGroup.server)
.thenReturn(ConnectableSubject(Observable.eval(channelCreated)))
when(channel.in).thenReturn(ConnectableSubject(Observable.never))
mockServerEvents(peerGroup, channelCreated)
mockChannelEvents(channel)

val t = requestExtractor(network).runToFuture.failed.futureValue

// The timeout on the channel doesn't cause this exception, but rather the fact
// that there's no subsequent server event and the server observable
// gets closed, so `getActualRequest` fails because it uses `.headL`.
t shouldBe a[NoSuchElementException]
channelClosed.get shouldBe true
}

s"Server $label" should "close server channel in the response task" in new Fixture {
when(peerGroup.server)
.thenReturn(ConnectableSubject(Observable.eval(channelCreated)))
when(channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(request))))
mockServerEvents(peerGroup, channelCreated)
mockChannelEvents(channel, MessageReceived(request))
when(channel.sendMessage(response)).thenReturn(Task.unit)

sendResponse(network, response).evaluated
Expand All @@ -98,9 +100,8 @@ class KNetworkSpec extends FlatSpec {
}

s"Server $label" should "close server channel in timed out response task" in new Fixture {
when(peerGroup.server)
.thenReturn(ConnectableSubject(Observable.eval(channelCreated)))
when(channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(request))))
mockServerEvents(peerGroup, channelCreated)
mockChannelEvents(channel, MessageReceived(request))
when(channel.sendMessage(response)).thenReturn(Task.never)

sendResponse(network, response).evaluatedFailure shouldBe a[TimeoutException]
Expand All @@ -111,9 +112,9 @@ class KNetworkSpec extends FlatSpec {
val channel1 = new MockChannel
val channel2 = new MockChannel

when(peerGroup.server).thenReturn(ConnectableSubject(Observable(channel1.created, channel2.created)))
when(channel1.channel.in).thenReturn(ConnectableSubject(Observable.never)) // Should be closed after a timeout.
when(channel2.channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(request))))
mockServerEvents(peerGroup, channel1.created, channel2.created)
mockChannelEvents(channel1.channel)
mockChannelEvents(channel2.channel, MessageReceived(request))

// Process incoming channels and requests. Need to wait a little to allow channel1 to time out.
val actualRequest = requestExtractor(network).delayResult(requestTimeout).evaluated
Expand All @@ -126,7 +127,7 @@ class KNetworkSpec extends FlatSpec {
s"Client $label" should "close client channels when requests are successful" in new Fixture {
when(peerGroup.client(targetRecord.routingAddress)).thenReturn(channelResource)
when(channel.sendMessage(request)).thenReturn(Task.unit)
when(channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(response))))
mockChannelEvents(channel, MessageReceived(response))

val actualResponse = clientRpc(network).evaluated

Expand Down Expand Up @@ -155,7 +156,7 @@ class KNetworkSpec extends FlatSpec {
s"Client $label" should "close client channels when response fails to arrive" in new Fixture {
when(peerGroup.client(targetRecord.routingAddress)).thenReturn(channelResource)
when(channel.sendMessage(request)).thenReturn(Task.unit)
when(channel.in).thenReturn(ConnectableSubject(Observable.fromTask(Task.never)))
mockChannelEvents(channel)

clientRpc(network).evaluatedFailure shouldBe a[TimeoutException]
channelClosed.get shouldBe true
Expand All @@ -165,11 +166,10 @@ class KNetworkSpec extends FlatSpec {
s"In consuming only PING" should "channels should be closed for unhandled FIND_NODES requests" in new Fixture {
val channel1 = new MockChannel
val channel2 = new MockChannel
when(peerGroup.server)
.thenReturn(ConnectableSubject(Observable(channel1.created, channel2.created)))
mockServerEvents(peerGroup, channel1.created, channel2.created)

when(channel1.channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(findNodes))))
when(channel2.channel.in).thenReturn(ConnectableSubject(Observable.eval(MessageReceived(ping))))
mockChannelEvents(channel1.channel, MessageReceived(findNodes))
mockChannelEvents(channel2.channel, MessageReceived(ping))

when(channel2.channel.sendMessage(pong)).thenReturn(Task.unit)

Expand Down Expand Up @@ -200,10 +200,31 @@ object KNetworkSpec {

/** Build a KNetwork backed by a mocked peer group.
  * By default the server event stream is empty: the very first pull
  * returns None, i.e. the stream completes immediately.
  */
private def createKNetwork: (KNetwork[String], PeerGroup[String, KMessage[String]]) = {
  val group = mock[PeerGroup[String, KMessage[String]]]
  when(group.nextServerEvent).thenReturn(Task.pure(None))
  val network = new KNetworkScalanetImpl(group, requestTimeout)
  (network, group)
}

/** Stub `peerGroup.nextServerEvent` so it serves the given events one by
  * one and then returns None, i.e. the server stream completes after the
  * last event.
  */
private def mockServerEvents(
    peerGroup: PeerGroup[String, KMessage[String]],
    events: ServerEvent[String, KMessage[String]]*
) = {
  val pull = nextTask(events, complete = true)
  when(peerGroup.nextServerEvent).thenReturn(pull)
}

/** Stub `channel.nextChannelEvent` so it serves the given events one by
  * one; after they run out the task never completes, i.e. the channel
  * stays open without further traffic.
  */
private def mockChannelEvents(
    channel: Channel[String, KMessage[String]],
    events: ChannelEvent[KMessage[String]]*
) = {
  val pull = nextTask(events, complete = false)
  when(channel.nextChannelEvent).thenReturn(pull)
}

/** Create a task that yields the elements of `events` one per execution,
  * in order. Once the events are exhausted:
  *  - if `complete` is true, yield None (stream completed);
  *  - otherwise never terminate (stream stays open but idle).
  *
  * The counter is shared across executions of the returned task, which is
  * what lets repeated pulls advance through the sequence.
  */
private def nextTask[T](events: Seq[T], complete: Boolean): Task[Option[T]] = {
  val count = new AtomicInteger(0)
  Task(count.getAndIncrement()).flatMap { i =>
    // `lift` instead of `events(i)`: total access, no IndexOutOfBounds risk.
    events.lift(i) match {
      case Some(event) => Task.pure(Some(event))
      case None if complete => Task.pure(None)
      case None => Task.never
    }
  }
}

private def getActualRequest[Request <: KRequest[String]](rpc: KNetwork[String] => Observable[(Request, _)])(
network: KNetwork[String]
): Task[Request] = {
Expand Down

This file was deleted.

Loading