Skip to content

Add calls-in-place contracts to unsafe operations #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bytestring/common/src/unsafe/UnsafeByteStringOperations.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package kotlinx.io.bytestring.unsafe

import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind.EXACTLY_ONCE
import kotlin.contracts.contract
import kotlinx.io.bytestring.ByteString

/**
Expand All @@ -16,6 +19,7 @@ import kotlinx.io.bytestring.ByteString
* consequences in the code using the byte string and should be avoided at all costs.
*/
@UnsafeByteStringApi
@OptIn(ExperimentalContracts::class)
public object UnsafeByteStringOperations {
/**
* Creates a new byte string by wrapping [array] without copying it.
Expand All @@ -32,6 +36,9 @@ public object UnsafeByteStringOperations {
* Consider using [ByteString.toByteArray] if it's impossible to guarantee that the array won't be modified.
*/
public inline fun withByteArrayUnsafe(byteString: ByteString, block: (ByteArray) -> Unit) {
contract {
callsInPlace(block, EXACTLY_ONCE)
}
block(byteString.getBackingArrayReference())
}
}
24 changes: 24 additions & 0 deletions bytestring/common/test/unsafe/UnsafeByteStringOperationsTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
*/

package kotlinx.io.bytestring.unsafe

import kotlin.test.Test
import kotlin.test.assertTrue
import kotlinx.io.bytestring.encodeToByteString

@OptIn(UnsafeByteStringApi::class)
class UnsafeByteStringOperationsTest {
@Test
fun callsInPlaceContract() {
val byteString = "hello byte string".encodeToByteString()

val called: Boolean
UnsafeByteStringOperations.withByteArrayUnsafe(byteString) {
called = true
}
assertTrue(called)
}
}
12 changes: 11 additions & 1 deletion core/common/src/Buffer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
*/
package kotlinx.io

import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind.EXACTLY_ONCE
import kotlin.contracts.contract
import kotlin.jvm.JvmSynthetic

/**
Expand Down Expand Up @@ -706,11 +709,18 @@ public class Buffer : Source, Sink {
*/
@PublishedApi
@JvmSynthetic
@OptIn(ExperimentalContracts::class)
internal inline fun <T> Buffer.seek(
fromIndex: Long,
lambda: (Segment?, Long) -> T
): T {
if (this.head == null) lambda(null, -1L)
contract {
callsInPlace(lambda, EXACTLY_ONCE)
}

if (this.head == null) {
return lambda(null, -1L)
}
Comment on lines +721 to +723
Copy link
Contributor Author

@JakeWharton JakeWharton Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it was a bug. Prior to this, the lambda would be called twice. Once with (null, -1) and then again in the if with (null, 0). Contract validation failed without changing the implementation to early-return here, as I expect was intended.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a bug, indeed. Thanks for catching and fixing it!


if (size - fromIndex < fromIndex) {
var s = tail
Expand Down
9 changes: 8 additions & 1 deletion core/common/src/Sinks.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

package kotlinx.io

import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind.EXACTLY_ONCE
import kotlin.contracts.contract

private val HEX_DIGIT_BYTES = ByteArray(16) {
((if (it < 10) '0'.code else ('a'.code - 10)) + it).toByte()
}
Expand Down Expand Up @@ -351,8 +355,11 @@ public fun Sink.writeDoubleLe(double: Double) {
* @throws IllegalStateException when the sink is closed.
*/
@DelicateIoApi
@OptIn(InternalIoApi::class)
@OptIn(InternalIoApi::class, ExperimentalContracts::class)
public inline fun Sink.writeToInternalBuffer(lambda: (Buffer) -> Unit) {
contract {
callsInPlace(lambda, EXACTLY_ONCE)
}
lambda(this.buffer)
this.hintEmit()
}
31 changes: 31 additions & 0 deletions core/common/src/unsafe/UnsafeBufferOperations.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

package kotlinx.io.unsafe

import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind.EXACTLY_ONCE
import kotlin.contracts.contract
import kotlinx.io.*
import kotlin.jvm.JvmSynthetic

@UnsafeIoApi
@OptIn(ExperimentalContracts::class)
public object UnsafeBufferOperations {
/**
* Maximum value that is safe to pass to [writeToTail].
Expand Down Expand Up @@ -88,6 +92,10 @@ public object UnsafeBufferOperations {
buffer: Buffer,
readAction: (bytes: ByteArray, startIndexInclusive: Int, endIndexExclusive: Int) -> Int
): Int {
contract {
callsInPlace(readAction, EXACTLY_ONCE)
}

require(!buffer.exhausted()) { "Buffer is empty" }
val head = buffer.head!!
val bytesRead = readAction(head.dataAsByteArray(true), head.pos, head.limit)
Expand Down Expand Up @@ -128,6 +136,10 @@ public object UnsafeBufferOperations {
* @sample kotlinx.io.samples.unsafe.UnsafeBufferOperationsSamples.readUleb128
*/
public inline fun readFromHead(buffer: Buffer, readAction: (SegmentReadContext, Segment) -> Int): Int {
contract {
callsInPlace(readAction, EXACTLY_ONCE)
}

require(!buffer.exhausted()) { "Buffer is empty" }
val head = buffer.head!!
val bytesRead = readAction(SegmentReadContextImpl, head)
Expand Down Expand Up @@ -176,6 +188,10 @@ public object UnsafeBufferOperations {
buffer: Buffer, minimumCapacity: Int,
writeAction: (bytes: ByteArray, startIndexInclusive: Int, endIndexExclusive: Int) -> Int
): Int {
contract {
callsInPlace(writeAction, EXACTLY_ONCE)
}

val tail = buffer.writableSegment(minimumCapacity)

val data = tail.dataAsByteArray(false)
Expand Down Expand Up @@ -240,6 +256,10 @@ public object UnsafeBufferOperations {
minimumCapacity: Int,
writeAction: (SegmentWriteContext, Segment) -> Int
): Int {
contract {
callsInPlace(writeAction, EXACTLY_ONCE)
}

val tail = buffer.writableSegment(minimumCapacity)
val bytesWritten = writeAction(SegmentWriteContextImpl, tail)

Expand Down Expand Up @@ -285,6 +305,9 @@ public object UnsafeBufferOperations {
* @sample kotlinx.io.samples.unsafe.UnsafeBufferOperationsSamples.crc32Unsafe
*/
public inline fun iterate(buffer: Buffer, iterationAction: (BufferIterationContext, Segment?) -> Unit) {
contract {
callsInPlace(iterationAction, EXACTLY_ONCE)
}
iterationAction(BufferIterationContextImpl, buffer.head)
}

Expand Down Expand Up @@ -314,6 +337,10 @@ public object UnsafeBufferOperations {
buffer: Buffer, offset: Long,
iterationAction: (BufferIterationContext, Segment?, Long) -> Unit
) {
contract {
callsInPlace(iterationAction, EXACTLY_ONCE)
}

require(offset >= 0) { "Offset must be non-negative: $offset" }
if (offset >= buffer.size) {
throw IndexOutOfBoundsException("Offset should be less than buffer's size (${buffer.size}): $offset")
Expand Down Expand Up @@ -365,7 +392,11 @@ public interface SegmentReadContext {
*/
@UnsafeIoApi
@JvmSynthetic
@OptIn(ExperimentalContracts::class)
public inline fun SegmentReadContext.withData(segment: Segment, readAction: (ByteArray, Int, Int) -> Unit) {
contract {
callsInPlace(readAction, EXACTLY_ONCE)
}
readAction(segment.dataAsByteArray(true), segment.pos, segment.limit)
}

Expand Down
12 changes: 12 additions & 0 deletions core/common/test/DelicateApiTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,21 @@ package kotlinx.io

import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue

@OptIn(DelicateIoApi::class)
class DelicateApiTest {
@Test
fun callsInPlaceContract() {
val sink: Sink = Buffer()

val called: Boolean
sink.writeToInternalBuffer {
called = true
}
assertTrue(called)
}

@Test
@OptIn(InternalIoApi::class)
fun testWriteIntoBuffer() {
Expand Down
29 changes: 29 additions & 0 deletions core/common/test/unsafe/UnsafeBufferOperationsIterationTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@ import kotlin.test.*
@OptIn(UnsafeIoApi::class)

class UnsafeBufferOperationsIterationTest {
@Test
fun callsInPlaceContract() {
val buffer = Buffer().also { it.writeString("hello buffer") }

val called: Boolean
UnsafeBufferOperations.iterate(buffer) { ctx, segment ->
called = true

val withDataCalled: Boolean
ctx.withData(segment!!) { _, _, _ ->
withDataCalled = true
}
assertTrue(withDataCalled)
}
assertTrue(called)

val offsetCalled: Boolean
UnsafeBufferOperations.iterate(buffer, 1) { ctx, segment, _ ->
offsetCalled = true

val withDataCalled: Boolean
ctx.withData(segment!!) { _, _, _ ->
withDataCalled = true
}
assertTrue(withDataCalled)
}
assertTrue(offsetCalled)
}

@Test
fun emptyBuffer() {
UnsafeBufferOperations.iterate(Buffer()) { _, head ->
Expand Down
19 changes: 19 additions & 0 deletions core/common/test/unsafe/UnsafeBufferOperationsReadTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,25 @@ import kotlin.test.fail
class UnsafeBufferOperationsReadTest {
private class TestException : RuntimeException()

@Test
fun callsInPlaceContract() {
val buffer = Buffer().apply { writeString("hello world") }

val bytesCalled: Boolean
UnsafeBufferOperations.readFromHead(buffer) { _, _, _ ->
bytesCalled = true
0
}
assertTrue(bytesCalled)

val segmentsCalled: Boolean
UnsafeBufferOperations.readFromHead(buffer) { _, _ ->
segmentsCalled = true
0
}
assertTrue(segmentsCalled)
}

@Test
fun bufferCapacity() {
val buffer = Buffer().apply { writeString("hello world") }
Expand Down
17 changes: 17 additions & 0 deletions core/common/test/unsafe/UnsafeBufferOperationsWriteTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,23 @@ import kotlin.test.*
class UnsafeBufferOperationsWriteTest {
private class TestException : RuntimeException()

@Test
fun callsInPlaceContract() {
val bytesCalled: Boolean
UnsafeBufferOperations.writeToTail(Buffer(), 1) { _, _, _ ->
bytesCalled = true
0
}
assertTrue(bytesCalled)

val segmentsCalled: Boolean
UnsafeBufferOperations.writeToTail(Buffer(), 1) { _, _ ->
segmentsCalled = true
0
}
assertTrue(segmentsCalled)
}

@Test
fun bufferCapacity() {
val buffer = Buffer()
Expand Down
16 changes: 16 additions & 0 deletions core/jvm/src/unsafe/UnsafeBufferOperationsJvm.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package kotlinx.io.unsafe

import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind.EXACTLY_ONCE
import kotlin.contracts.contract
import kotlinx.io.Buffer
import kotlinx.io.Segment
import kotlinx.io.UnsafeIoApi
Expand Down Expand Up @@ -39,7 +42,11 @@ import java.nio.ByteBuffer
* @sample kotlinx.io.samples.unsafe.UnsafeReadWriteSamplesJvm.writeToByteChannel
*/
@UnsafeIoApi
@OptIn(ExperimentalContracts::class)
public inline fun UnsafeBufferOperations.readFromHead(buffer: Buffer, readAction: (ByteBuffer) -> Unit): Int {
contract {
callsInPlace(readAction, EXACTLY_ONCE)
}
return readFromHead(buffer) { rawData, pos, limit ->
val bb = ByteBuffer.wrap(rawData, pos, limit - pos).slice().asReadOnlyBuffer()
readAction(bb)
Expand Down Expand Up @@ -81,11 +88,15 @@ public inline fun UnsafeBufferOperations.readFromHead(buffer: Buffer, readAction
* @sample kotlinx.io.samples.unsafe.UnsafeReadWriteSamplesJvm.readFromByteChannel
*/
@UnsafeIoApi
@OptIn(ExperimentalContracts::class)
public inline fun UnsafeBufferOperations.writeToTail(
buffer: Buffer,
minimumCapacity: Int,
writeAction: (ByteBuffer) -> Unit
): Int {
contract {
callsInPlace(writeAction, EXACTLY_ONCE)
}
return writeToTail(buffer, minimumCapacity) { rawData, pos, limit ->
val bb = ByteBuffer.wrap(rawData, pos, limit - pos).slice()
writeAction(bb)
Expand Down Expand Up @@ -134,11 +145,16 @@ public inline fun UnsafeBufferOperations.writeToTail(
*
*/
@UnsafeIoApi
@OptIn(ExperimentalContracts::class)
public inline fun UnsafeBufferOperations.readBulk(
buffer: Buffer,
iovec: Array<ByteBuffer?>,
readAction: (iovec: Array<ByteBuffer?>, iovecSize: Int) -> Long
): Long {
contract {
callsInPlace(readAction, EXACTLY_ONCE)
}

val head = buffer.head ?: throw IllegalArgumentException("buffer is empty.")
if (iovec.isEmpty()) throw IllegalArgumentException("iovec is empty.")

Expand Down
13 changes: 13 additions & 0 deletions core/jvm/test/unsafe/UnsafeBufferOperationsJvmReadBulkTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ import kotlin.test.*
class UnsafeBufferOperationsJvmReadBulkTest {
private class TestException : RuntimeException()

@Test
fun callsInPlaceContract() {
val buffer = Buffer().apply { writeString("hello world") }
val array = Array<ByteBuffer?>(16) { null }

val called: Boolean
UnsafeBufferOperations.readBulk(buffer, array) { _, _ ->
called = true
0
}
assertTrue(called)
}

@Test
fun readAllFromEmptyBuffer() {
assertFailsWith<IllegalArgumentException> {
Expand Down
11 changes: 11 additions & 0 deletions core/jvm/test/unsafe/UnsafeBufferOperationsJvmReadFromHeadTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ import kotlin.test.fail
class UnsafeBufferOperationsJvmReadFromHeadTest {
private class TestException : RuntimeException()

@Test
fun callsInPlaceContract() {
val buffer = Buffer().apply { writeString("hello world") }

val called: Boolean
UnsafeBufferOperations.readFromHead(buffer) { _ ->
called = true
}
assertTrue(called)
}

@Test
fun bufferCapacity() {
val buffer = Buffer().apply { writeString("hello world") }
Expand Down
Loading