Skip to content

Commit 1d5f262

Browse files
authored
Prepare zlib's inflater for Kotlin/Native (#1422)
* Prepare zlib's inflater for Kotlin/Native * Don't check the result of deflateEnd It is different for different zlib versions. In particular, it returns Z_DATA_ERROR if the stream is closed without being used.
1 parent 062048a commit 1d5f262

File tree

4 files changed

+335
-3
lines changed

4 files changed

+335
-3
lines changed

okio/src/nativeMain/kotlin/okio/Deflater.kt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import platform.zlib.deflateEnd
3737
import platform.zlib.deflateInit2
3838
import platform.zlib.z_stream_s
3939

40-
private val emptyByteArray = byteArrayOf()
40+
internal val emptyByteArray = byteArrayOf()
4141

4242
/**
4343
* Deflate using Kotlin/Native's built-in zlib bindings. This uses the raw deflate format and omits
@@ -145,8 +145,7 @@ internal class Deflater : Closeable {
145145
if (closed) return
146146
closed = true
147147

148-
val deflateEndResult = deflateEnd(zStream.ptr)
149-
check(deflateEndResult == Z_OK)
148+
deflateEnd(zStream.ptr)
150149
nativeHeap.free(zStream)
151150
}
152151
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright (C) 2024 Square, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package okio
17+
18+
import kotlinx.cinterop.CPointer
19+
import kotlinx.cinterop.UByteVar
20+
import kotlinx.cinterop.addressOf
21+
import kotlinx.cinterop.alloc
22+
import kotlinx.cinterop.free
23+
import kotlinx.cinterop.nativeHeap
24+
import kotlinx.cinterop.ptr
25+
import kotlinx.cinterop.usePinned
26+
import platform.zlib.Z_BUF_ERROR
27+
import platform.zlib.Z_DATA_ERROR
28+
import platform.zlib.Z_NO_FLUSH
29+
import platform.zlib.Z_OK
30+
import platform.zlib.Z_STREAM_END
31+
import platform.zlib.inflateEnd
32+
import platform.zlib.inflateInit2
33+
import platform.zlib.z_stream_s
34+
35+
/**
36+
* Inflate using Kotlin/Native's built-in zlib bindings.
37+
*
38+
* The API is symmetric with [Deflater].
39+
*/
40+
internal class Inflater : Closeable {
41+
private val zStream: z_stream_s = nativeHeap.alloc<z_stream_s> {
42+
zalloc = null
43+
zfree = null
44+
opaque = null
45+
check(
46+
inflateInit2(
47+
strm = ptr,
48+
windowBits = -15, // Default value for raw deflate.
49+
) == Z_OK,
50+
)
51+
}
52+
53+
var source: ByteArray = emptyByteArray
54+
var sourcePos: Int = 0
55+
var sourceLimit: Int = 0
56+
57+
var target: ByteArray = emptyByteArray
58+
var targetPos: Int = 0
59+
var targetLimit: Int = 0
60+
61+
private var closed = false
62+
63+
/**
64+
* Returns true if no further calls to [inflate] are required because the source stream is
65+
* finished. Otherwise, ensure there's input data in [source] and output space in [target] and
66+
* call this again.
67+
*/
68+
fun inflate(): Boolean {
69+
check(!closed) { "closed" }
70+
require(0 <= sourcePos && sourcePos <= sourceLimit && sourceLimit <= source.size)
71+
require(0 <= targetPos && targetPos <= targetLimit && targetLimit <= target.size)
72+
73+
source.usePinned { pinnedSource ->
74+
target.usePinned { pinnedTarget ->
75+
val sourceByteCount = sourceLimit - sourcePos
76+
zStream.next_in = when {
77+
sourceByteCount > 0 -> pinnedSource.addressOf(sourcePos) as CPointer<UByteVar>
78+
else -> null
79+
}
80+
zStream.avail_in = sourceByteCount.toUInt()
81+
82+
val targetByteCount = targetLimit - targetPos
83+
zStream.next_out = when {
84+
targetByteCount > 0 -> pinnedTarget.addressOf(targetPos) as CPointer<UByteVar>
85+
else -> null
86+
}
87+
zStream.avail_out = targetByteCount.toUInt()
88+
89+
val inflateResult = platform.zlib.inflate(zStream.ptr, Z_NO_FLUSH)
90+
91+
sourcePos += sourceByteCount - zStream.avail_in.toInt()
92+
targetPos += targetByteCount - zStream.avail_out.toInt()
93+
94+
return when (inflateResult) {
95+
Z_OK -> false
96+
Z_BUF_ERROR -> false // Non-fatal but the caller needs to update source and/or target.
97+
Z_STREAM_END -> true
98+
Z_DATA_ERROR -> throw ProtocolException("Z_DATA_ERROR")
99+
100+
// One of Z_NEED_DICT, Z_STREAM_ERROR, Z_MEM_ERROR.
101+
else -> throw ProtocolException("unexpected inflate result: $inflateResult")
102+
}
103+
}
104+
}
105+
}
106+
107+
override fun close() {
108+
if (closed) return
109+
closed = true
110+
111+
inflateEnd(zStream.ptr)
112+
nativeHeap.free(zStream)
113+
}
114+
}

okio/src/nativeTest/kotlin/okio/DeflaterTest.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package okio
1717

1818
import kotlin.test.Test
1919
import kotlin.test.assertEquals
20+
import kotlin.test.assertFailsWith
2021
import kotlin.test.assertFalse
2122
import kotlin.test.assertTrue
2223
import okio.ByteString.Companion.decodeBase64
@@ -174,4 +175,21 @@ class DeflaterTest {
174175

175176
deflater.close()
176177
}
178+
179+
@Test
180+
fun cannotDeflateAfterClose() {
181+
val deflater = Deflater()
182+
deflater.close()
183+
184+
assertFailsWith<IllegalStateException> {
185+
deflater.deflate()
186+
}
187+
}
188+
189+
@Test
190+
fun closeIsIdemptent() {
191+
val deflater = Deflater()
192+
deflater.close()
193+
deflater.close()
194+
}
177195
}
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* Copyright (C) 2024 Square, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package okio
17+
18+
import kotlin.test.Test
19+
import kotlin.test.assertEquals
20+
import kotlin.test.assertFailsWith
21+
import kotlin.test.assertFalse
22+
import kotlin.test.assertTrue
23+
import okio.ByteString.Companion.decodeBase64
24+
import okio.ByteString.Companion.decodeHex
25+
import okio.ByteString.Companion.toByteString
26+
27+
class InflaterTest {
28+
@Test
29+
fun happyPath() {
30+
val inflater = Inflater().apply {
31+
source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA="
32+
.decodeBase64()!!.toByteArray()
33+
sourcePos = 0
34+
sourceLimit = source.size
35+
36+
target = ByteArray(256)
37+
targetPos = 0
38+
targetLimit = target.size
39+
}
40+
41+
assertTrue(inflater.inflate())
42+
assertEquals(inflater.sourceLimit, inflater.sourcePos)
43+
44+
val inflated = inflater.target.toByteString(0, inflater.targetPos)
45+
assertEquals(
46+
"God help us, we're in the hands of engineers.",
47+
inflated.utf8(),
48+
)
49+
50+
inflater.close()
51+
}
52+
53+
@Test
54+
fun inflateInParts() {
55+
val inflater = Inflater().apply {
56+
target = ByteArray(256)
57+
targetPos = 0
58+
targetLimit = target.size
59+
}
60+
61+
inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJ".decodeBase64()!!.toByteArray()
62+
inflater.sourcePos = 0
63+
inflater.sourceLimit = inflater.source.size
64+
assertFalse(inflater.inflate())
65+
assertEquals(inflater.sourceLimit, inflater.sourcePos)
66+
67+
inflater.source = "SFXISMxLKVbIT1NIzUvPzEtNLSrWAwA=".decodeBase64()!!.toByteArray()
68+
inflater.sourcePos = 0
69+
inflater.sourceLimit = inflater.source.size
70+
assertTrue(inflater.inflate())
71+
assertEquals(inflater.sourceLimit, inflater.sourcePos)
72+
73+
val inflated = inflater.target.toByteString(0, inflater.targetPos)
74+
assertEquals(
75+
"God help us, we're in the hands of engineers.",
76+
inflated.utf8(),
77+
)
78+
79+
inflater.close()
80+
}
81+
82+
@Test
83+
fun inflateInsufficientSpaceInTarget() {
84+
val targetBuffer = Buffer()
85+
86+
val inflater = Inflater().apply {
87+
source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA="
88+
.decodeBase64()!!.toByteArray()
89+
sourcePos = 0
90+
sourceLimit = source.size
91+
}
92+
93+
inflater.target = ByteArray(31)
94+
inflater.targetPos = 0
95+
inflater.targetLimit = inflater.target.size
96+
assertFalse(inflater.inflate())
97+
assertEquals(inflater.targetLimit, inflater.targetPos)
98+
targetBuffer.write(inflater.target)
99+
100+
inflater.target = ByteArray(256)
101+
inflater.targetPos = 0
102+
inflater.targetLimit = inflater.target.size
103+
assertTrue(inflater.inflate())
104+
assertEquals(inflater.sourcePos, inflater.sourceLimit)
105+
targetBuffer.write(inflater.target, 0, inflater.targetPos)
106+
107+
assertEquals(
108+
"God help us, we're in the hands of engineers.",
109+
targetBuffer.readUtf8(),
110+
)
111+
112+
inflater.close()
113+
}
114+
115+
@Test
116+
fun inflateEmptyContent() {
117+
val inflater = Inflater().apply {
118+
source = "AwA=".decodeBase64()!!.toByteArray()
119+
sourcePos = 0
120+
sourceLimit = source.size
121+
122+
target = ByteArray(256)
123+
targetPos = 0
124+
targetLimit = target.size
125+
}
126+
127+
assertTrue(inflater.inflate())
128+
129+
val inflated = inflater.target.toByteString(0, inflater.targetPos)
130+
assertEquals(
131+
"",
132+
inflated.utf8(),
133+
)
134+
135+
inflater.close()
136+
}
137+
138+
@Test
139+
fun inflateInPartsStartingWithEmptySource() {
140+
val inflater = Inflater().apply {
141+
target = ByteArray(256)
142+
targetPos = 0
143+
targetLimit = target.size
144+
}
145+
146+
inflater.source = ByteArray(256)
147+
inflater.sourcePos = 0
148+
inflater.sourceLimit = 0
149+
assertFalse(inflater.inflate())
150+
151+
inflater.source = "c89PUchIzSlQKC3WUShPVS9KVcjMUyjJSFXISMxLKVbIT1NIzUvPzEtNLSrWAwA="
152+
.decodeBase64()!!.toByteArray()
153+
inflater.sourcePos = 0
154+
inflater.sourceLimit = inflater.source.size
155+
assertTrue(inflater.inflate())
156+
157+
val inflated = inflater.target.toByteString(0, inflater.targetPos)
158+
assertEquals(
159+
"God help us, we're in the hands of engineers.",
160+
inflated.utf8(),
161+
)
162+
163+
inflater.close()
164+
}
165+
166+
@Test
167+
fun inflateInvalidData() {
168+
val inflater = Inflater().apply {
169+
target = ByteArray(256)
170+
targetPos = 0
171+
targetLimit = target.size
172+
}
173+
174+
inflater.source = "ffffffffffffffff".decodeHex().toByteArray()
175+
inflater.sourcePos = 0
176+
inflater.sourceLimit = inflater.source.size
177+
val exception = assertFailsWith<ProtocolException> {
178+
inflater.inflate()
179+
}
180+
assertEquals("Z_DATA_ERROR", exception.message)
181+
182+
inflater.close()
183+
}
184+
185+
@Test
186+
fun cannotInflateAfterClose() {
187+
val inflater = Inflater()
188+
inflater.close()
189+
190+
assertFailsWith<IllegalStateException> {
191+
inflater.inflate()
192+
}
193+
}
194+
195+
@Test
196+
fun closeIsIdemptent() {
197+
val inflater = Inflater()
198+
inflater.close()
199+
inflater.close()
200+
}
201+
}

0 commit comments

Comments
 (0)