diff --git a/formats/cbor/commonMain/src/kotlinx/serialization/cbor/internal/Streams.kt b/formats/cbor/commonMain/src/kotlinx/serialization/cbor/internal/Streams.kt index fdbfca674..848e58ad0 100644 --- a/formats/cbor/commonMain/src/kotlinx/serialization/cbor/internal/Streams.kt +++ b/formats/cbor/commonMain/src/kotlinx/serialization/cbor/internal/Streams.kt @@ -43,12 +43,18 @@ internal class ByteArrayOutput { private var position: Int = 0 private fun ensureCapacity(elementsToAppend: Int) { - if (position + elementsToAppend <= array.size) { + val requiredCapacityLong = position.toLong() + elementsToAppend.toLong() + if (requiredCapacityLong > Int.MAX_VALUE) { + throw IllegalArgumentException("Required capacity exceeds maximum array size (Int.MAX_VALUE).") + } + + val requiredCapacity = requiredCapacityLong.toInt() + if (requiredCapacity <= array.size) { return } - val newArray = ByteArray((position + elementsToAppend).takeHighestOneBit() shl 1) - array.copyInto(newArray) - array = newArray + + val newCapacity = nextPowerOfTwoCapacity(requiredCapacity) + array = array.copyOf(newCapacity) } public fun toByteArray(): ByteArray { @@ -86,4 +92,29 @@ internal class ByteArrayOutput { ensureCapacity(1) array[position++] = byteValue.toByte() } + + companion object { + /** + * Calculates the next power-of-two capacity based on the required minimum size. + * + * This function ensures the returned value is at least as large as `minCapacity`, + * and is always a power of two, unless `minCapacity` is less than or equal to zero, + * in which case it returns 0. If the calculated power of two exceeds `Integer.MAX_VALUE`, + * it returns `Integer.MAX_VALUE`. + * + * It's useful for resizing arrays with exponential growth. + * + * @param minCapacity The minimum required capacity. + * @return A capacity value that is a power of two and ≥ minCapacity, or 0 if `minCapacity` is ≤ 0. + */ + fun nextPowerOfTwoCapacity(minCapacity: Int): Int { + if (minCapacity <= 0) return 0 + + val highestOneBit = minCapacity.takeHighestOneBit() + val maxHighestOneBit = Integer.MAX_VALUE.takeHighestOneBit() + + // Check if shifting would exceed the maximum allowed value + return if (highestOneBit < maxHighestOneBit) highestOneBit shl 1 else Integer.MAX_VALUE + } + } } diff --git a/formats/cbor/commonTest/src/kotlinx/serialization/cbor/internal/StreamsTest.kt b/formats/cbor/commonTest/src/kotlinx/serialization/cbor/internal/StreamsTest.kt new file mode 100644 index 000000000..c13f09db7 --- /dev/null +++ b/formats/cbor/commonTest/src/kotlinx/serialization/cbor/internal/StreamsTest.kt @@ -0,0 +1,57 @@ +package kotlinx.serialization.cbor.internal + +import kotlinx.serialization.* +import kotlin.test.* + +class StreamsTest { + + @Test + fun powerOfTwoCapacity_negativeValue() { + assertEquals(0, ByteArrayOutput.nextPowerOfTwoCapacity(-1)) + assertEquals(0, ByteArrayOutput.nextPowerOfTwoCapacity(-17)) + } + + @Test + fun powerOfTwoCapacity_zeroValue() { + assertEquals(0, ByteArrayOutput.nextPowerOfTwoCapacity(0)) + } + + @Test + fun powerOfTwoCapacity_exactPowerOfTwo() { + assertEquals(16, ByteArrayOutput.nextPowerOfTwoCapacity(8)) + assertEquals(32, ByteArrayOutput.nextPowerOfTwoCapacity(16)) + assertEquals(64, ByteArrayOutput.nextPowerOfTwoCapacity(32)) + } + + @Test + fun powerOfTwoCapacity_nonPowerOfTwo() { + assertEquals(16, ByteArrayOutput.nextPowerOfTwoCapacity(9)) + assertEquals(64, ByteArrayOutput.nextPowerOfTwoCapacity(33)) + assertEquals(128, ByteArrayOutput.nextPowerOfTwoCapacity(65)) + } + + @Test + fun powerOfTwoCapacity_smallValues() { + assertEquals(2, ByteArrayOutput.nextPowerOfTwoCapacity(1)) + assertEquals(4, ByteArrayOutput.nextPowerOfTwoCapacity(2)) + assertEquals(4, ByteArrayOutput.nextPowerOfTwoCapacity(3)) + } + + @Test + fun powerOfTwoCapacity_boundaryValues() { + assertEquals(0, ByteArrayOutput.nextPowerOfTwoCapacity(0)) + assertEquals(2, ByteArrayOutput.nextPowerOfTwoCapacity(1)) + assertEquals(4, ByteArrayOutput.nextPowerOfTwoCapacity(3)) + assertEquals(8, ByteArrayOutput.nextPowerOfTwoCapacity(5)) + } + + @Test + fun powerOfTwoCapacity_largeValues() { + assertEquals(1073741824, ByteArrayOutput.nextPowerOfTwoCapacity(536870912)) + assertEquals(1073741824, ByteArrayOutput.nextPowerOfTwoCapacity(1073741823)) + assertEquals(Integer.MAX_VALUE, ByteArrayOutput.nextPowerOfTwoCapacity(1073741824)) + assertEquals(Integer.MAX_VALUE, ByteArrayOutput.nextPowerOfTwoCapacity(1073741825)) + assertEquals(Integer.MAX_VALUE, ByteArrayOutput.nextPowerOfTwoCapacity(Integer.MAX_VALUE-1)) + assertEquals(Integer.MAX_VALUE, ByteArrayOutput.nextPowerOfTwoCapacity(Integer.MAX_VALUE)) + } +}