Skip to content

Reimplement Source::indexOf(ByteString) without Source::peek calls #242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/api/kotlinx-io-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ public final class kotlinx/io/BuffersKt {
}

public final class kotlinx/io/ByteStringsKt {
public static final fun indexOf (Lkotlinx/io/Buffer;Lkotlinx/io/bytestring/ByteString;J)J
public static final fun indexOf (Lkotlinx/io/Source;Lkotlinx/io/bytestring/ByteString;J)J
public static synthetic fun indexOf$default (Lkotlinx/io/Buffer;Lkotlinx/io/bytestring/ByteString;JILjava/lang/Object;)J
public static synthetic fun indexOf$default (Lkotlinx/io/Source;Lkotlinx/io/bytestring/ByteString;JILjava/lang/Object;)J
public static final fun readByteString (Lkotlinx/io/Source;)Lkotlinx/io/bytestring/ByteString;
public static final fun readByteString (Lkotlinx/io/Source;I)Lkotlinx/io/bytestring/ByteString;
Expand Down
83 changes: 54 additions & 29 deletions core/common/src/ByteStrings.kt
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,63 @@ public fun Source.indexOf(byteString: ByteString, startIndex: Long = 0): Long {
}

var offset = startIndex
val peek = peek()
if (!request(startIndex)) {
return -1L
while (request(offset + byteString.size)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just write an idea for the future: if a request appears, it may be worth adding a modification of the function to which fetchSize can be passed. For example, it can help if byteString.size == 2, so that there is no frequent reading of two bytes (which will most likely be slower than reading a large chunk)

val idx = buffer.indexOf(byteString, offset)
if (idx < 0) {
// The buffer does not contain the pattern, let's try fetching at least one extra byte
// and start a new search attempt so that the pattern would fit in the suffix of
// the current buffer + 1 extra byte.
offset = buffer.size - byteString.size + 1
} else {
return idx
}
}
peek.skip(offset)
var resultingIndex = -1L
UnsafeByteStringOperations.withByteArrayUnsafe(byteString) { data ->
while (!peek.exhausted()) {
val index = peek.indexOf(data[0])
if (index == -1L) {
return@withByteArrayUnsafe
}
offset += index
peek.skip(index)
if (!peek.request(byteString.size.toLong())) {
return@withByteArrayUnsafe
}
return -1
}

var matches = true
for (idx in data.indices) {
if (data[idx] != peek.buffer[idx.toLong()]) {
matches = false
offset++
peek.skip(1)
break
}
}
if (matches) {
resultingIndex = offset
return@withByteArrayUnsafe
@OptIn(UnsafeByteStringApi::class)
public fun Buffer.indexOf(byteString: ByteString, startIndex: Long = 0): Long {
require(startIndex <= size) {
"startIndex ($startIndex) should not exceed size ($size)"
}
if (byteString.isEmpty()) return 0
if (startIndex > size - byteString.size) return -1L

UnsafeByteStringOperations.withByteArrayUnsafe(byteString) { byteStringData ->
seek(startIndex) { seg, o ->
if (o == -1L) {
return -1L
}
var segment = seg!!
var offset = o
do {
// If start index within this segment, the diff will be positive and
// we'll scan the segment starting from the corresponding offset.
// Otherwise, the diff will be negative and we'll scan the segment from the beginning.
val startOffset = maxOf((startIndex - offset).toInt(), 0)
// Try to search the pattern within the current segment.
val idx = segment.indexOfBytesInbound(byteStringData, startOffset)
if (idx != -1) {
// The offset corresponds to the segment's start, idx - to offset within the segment.
return offset + idx.toLong()
}
// firstOutboundOffset corresponds to a first byte starting reading the pattern from which
// will result in running out of the current segment bounds.
val firstOutboundOffset = maxOf(startOffset, segment.size - byteStringData.size + 1)
// Try to find a pattern in all suffixes shorter than the pattern. These suffixes start
// in the current segment, but ends in the following segments; thus we're using outbound function.
val idx1 = segment.indexOfBytesOutbound(byteStringData, firstOutboundOffset, head)
if (idx1 != -1) {
// Offset corresponds to the segment's start, idx - to offset within the segment.
return offset + idx1.toLong()
}

// We scanned the whole segment, so let's go to the next one
offset += segment.size
segment = segment.next!!
} while (segment !== head && offset + byteString.size <= size)
return -1L
}
}
return resultingIndex
return -1
}
76 changes: 75 additions & 1 deletion core/common/src/Segment.kt
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ internal fun Segment.indexOf(byte: Byte, startOffset: Int, endOffset: Int): Int
require(startOffset in 0 until size) {
"$startOffset"
}
require(endOffset in startOffset..size) { "$endOffset" }
require(endOffset in startOffset..size) {
"$endOffset"
}
val p = pos
for (idx in startOffset until endOffset) {
if (data[p + idx] == byte) {
Expand All @@ -210,3 +212,75 @@ internal fun Segment.indexOf(byte: Byte, startOffset: Int, endOffset: Int): Int
}
return -1
}

/**
* Searches for a `bytes` pattern within this segment starting at the offset `startOffset`.
* `startOffset` is relative and should be within `[0, size)`.
*/
internal fun Segment.indexOfBytesInbound(bytes: ByteArray, startOffset: Int): Int {
// require(startOffset in 0 until size)
var offset = startOffset
val limit = size - bytes.size + 1
val firstByte = bytes[0]
while (offset < limit) {
val idx = indexOf(firstByte, offset, limit)
if (idx < 0) {
return -1
}
var found = true
for (innerIdx in 1 until bytes.size) {
if (data[pos + idx + innerIdx] != bytes[innerIdx]) {
found = false
break
}
}
if (found) {
return idx
} else {
offset++
}
}
return -1
}

/**
* Searches for a `bytes` pattern starting in between offset `startOffset` and `size` within this segment
* and continued in the following segments.
* `startOffset` is relative and should be within `[0, size)`.
*/
internal fun Segment.indexOfBytesOutbound(bytes: ByteArray, startOffset: Int, head: Segment?): Int {
var offset = startOffset
val firstByte = bytes[0]

while (offset in 0 until size) {
val idx = indexOf(firstByte, offset, size)
if (idx < 0) {
return -1
}
// The pattern should start in this segment
var seg = this
var scanOffset = offset

var found = true
for (element in bytes) {
// We ran out of bytes in this segment,
// so let's take the next one and continue the scan there.
if (scanOffset == seg.size) {
val next = seg.next
if (next === head) return -1
seg = next!!
scanOffset = 0 // we're scanning the next segment right from the beginning
}
if (element != seg.data[seg.pos + scanOffset]) {
found = false
break
}
scanOffset++
}
if (found) {
return offset
}
offset++
}
return -1
}
13 changes: 13 additions & 0 deletions core/common/test/AbstractSourceTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -1769,4 +1769,17 @@ abstract class AbstractBufferedSourceTest internal constructor(
assertEquals((Segment.SIZE * 2 + 1).toLong(), source.indexOf("fg".encodeToByteString()))
assertEquals((Segment.SIZE * 2 + 2).toLong(), source.indexOf("g".encodeToByteString()))
}

@Test
fun indexOfByteStringSpanningAcrossMultipleSegments() {
sink.writeString("a".repeat(SEGMENT_SIZE))
sink.emit()
sink.writeString("bbbb")
sink.emit()
sink.write(Buffer().also { it.writeString("c".repeat(SEGMENT_SIZE)) }, SEGMENT_SIZE.toLong())
sink.emit()

source.skip(SEGMENT_SIZE - 10L)
assertEquals(9, source.indexOf("abbbbc".encodeToByteString()))
}
}