Skip to content

Commit ae1692d

Browse files
authored
Add nullability inference support to dataframe-jdbc (#672)
* Add nullability inference support to dataframe-jdbc This update adds a new parameter `inferNullability` for various dataframe reading functions including `readSqlTable`, `readSqlQuery`, and `readResultSet`. It allows better control over how column nullability should be inferred. The `h2Test` file has been adjusted to test this new feature. * Refactor SQL query requirement message for readability The code format of the requirement message for SQL query validation in the readJdbc.kt file has been improved. The change enhances readability by wrapping the requirement message into its own block, splitting the long string into two separate lines rather than extending it across one long line. * Update inferNullability from Infer to Boolean This commit changes the inferNullability parameter from 'Infer' to a Boolean type in functions of readJdbc.kt and adjusts related function calls in h2Test.kt. Now, inferNullability takes a Boolean value with 'true' indicating Inference and 'false' meaning no inference, making it more intuitive and easier to use. * Remove unnecessary inferNulls call in readJdbc The Infer.Nulls call was redundant and has been removed from the readJdbc.kt file. This simplifies the code of reading JDBC in the DataFrame-JDBC module, without altering functionality.
1 parent 7de6022 commit ae1692d

File tree

2 files changed

+159
-24
lines changed
  • dataframe-jdbc/src
    • main/kotlin/org/jetbrains/kotlinx/dataframe/io
    • test/kotlin/org/jetbrains/kotlinx/dataframe/io

2 files changed

+159
-24
lines changed

dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging
44
import org.jetbrains.kotlinx.dataframe.AnyFrame
55
import org.jetbrains.kotlinx.dataframe.DataColumn
66
import org.jetbrains.kotlinx.dataframe.DataFrame
7+
import org.jetbrains.kotlinx.dataframe.api.Infer
78
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
89
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
910
import org.jetbrains.kotlinx.dataframe.io.db.DbType
@@ -105,15 +106,17 @@ public data class DatabaseConfiguration(val url: String, val user: String = "",
105106
* @param [dbConfig] the configuration for the database, including URL, user, and password.
106107
* @param [tableName] the name of the table to read data from.
107108
* @param [limit] the maximum number of rows to retrieve from the table.
109+
* @param [inferNullability] indicates how the column nullability should be inferred.
108110
* @return the DataFrame containing the data from the SQL table.
109111
*/
110112
public fun DataFrame.Companion.readSqlTable(
111113
dbConfig: DatabaseConfiguration,
112114
tableName: String,
113-
limit: Int = DEFAULT_LIMIT
115+
limit: Int = DEFAULT_LIMIT,
116+
inferNullability: Boolean = true,
114117
): AnyFrame {
115118
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
116-
return readSqlTable(connection, tableName, limit)
119+
return readSqlTable(connection, tableName, limit, inferNullability)
117120
}
118121
}
119122

@@ -123,14 +126,16 @@ public fun DataFrame.Companion.readSqlTable(
123126
* @param [connection] the database connection to read tables from.
124127
* @param [tableName] the name of the table to read data from.
125128
* @param [limit] the maximum number of rows to retrieve from the table.
129+
* @param [inferNullability] indicates how the column nullability should be inferred.
126130
* @return the DataFrame containing the data from the SQL table.
127131
*
128132
* @see DriverManager.getConnection
129133
*/
130134
public fun DataFrame.Companion.readSqlTable(
131135
connection: Connection,
132136
tableName: String,
133-
limit: Int = DEFAULT_LIMIT
137+
limit: Int = DEFAULT_LIMIT,
138+
inferNullability: Boolean = true,
134139
): AnyFrame {
135140
var preparedQuery = "SELECT * FROM $tableName"
136141
if (limit > 0) preparedQuery += " LIMIT $limit"
@@ -145,7 +150,7 @@ public fun DataFrame.Companion.readSqlTable(
145150
preparedQuery
146151
).use { rs ->
147152
val tableColumns = getTableColumnsMetadata(rs)
148-
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit)
153+
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
149154
}
150155
}
151156
}
@@ -159,15 +164,17 @@ public fun DataFrame.Companion.readSqlTable(
159164
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
160165
* @param [sqlQuery] the SQL query to execute.
161166
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
167+
* @param [inferNullability] indicates how the column nullability should be inferred.
162168
* @return the DataFrame containing the result of the SQL query.
163169
*/
164170
public fun DataFrame.Companion.readSqlQuery(
165171
dbConfig: DatabaseConfiguration,
166172
sqlQuery: String,
167-
limit: Int = DEFAULT_LIMIT
173+
limit: Int = DEFAULT_LIMIT,
174+
inferNullability: Boolean = true,
168175
): AnyFrame {
169176
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
170-
return readSqlQuery(connection, sqlQuery, limit)
177+
return readSqlQuery(connection, sqlQuery, limit, inferNullability)
171178
}
172179
}
173180

@@ -180,16 +187,21 @@ public fun DataFrame.Companion.readSqlQuery(
180187
* @param [connection] the database connection to execute the SQL query.
181188
* @param [sqlQuery] the SQL query to execute.
182189
* @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution.
190+
* @param [inferNullability] indicates how the column nullability should be inferred.
183191
* @return the DataFrame containing the result of the SQL query.
184192
*
185193
* @see DriverManager.getConnection
186194
*/
187195
public fun DataFrame.Companion.readSqlQuery(
188196
connection: Connection,
189197
sqlQuery: String,
190-
limit: Int = DEFAULT_LIMIT
198+
limit: Int = DEFAULT_LIMIT,
199+
inferNullability: Boolean = true,
191200
): AnyFrame {
192-
require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " }
201+
require(isValid(sqlQuery)) {
202+
"SQL query should start from SELECT and contain one query for reading data without any manipulation. " +
203+
"Also it should not contain any separators like `;`."
204+
}
193205

194206
val url = connection.metaData.url
195207
val dbType = extractDBTypeFromUrl(url)
@@ -202,12 +214,12 @@ public fun DataFrame.Companion.readSqlQuery(
202214
connection.createStatement().use { st ->
203215
st.executeQuery(internalSqlQuery).use { rs ->
204216
val tableColumns = getTableColumnsMetadata(rs)
205-
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT)
217+
return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability)
206218
}
207219
}
208220
}
209221

210-
/** SQL-query is accepted only if it starts from SELECT */
222+
/** SQL query is accepted only if it starts from SELECT */
211223
private fun isValid(sqlQuery: String): Boolean {
212224
val normalizedSqlQuery = sqlQuery.trim().uppercase()
213225

@@ -221,15 +233,17 @@ private fun isValid(sqlQuery: String): Boolean {
221233
* @param [resultSet] the [ResultSet] containing the data to read.
222234
* @param [dbType] the type of database that the [ResultSet] belongs to.
223235
* @param [limit] the maximum number of rows to read from the [ResultSet].
236+
* @param [inferNullability] indicates how the column nullability should be inferred.
224237
* @return the DataFrame generated from the [ResultSet] data.
225238
*/
226239
public fun DataFrame.Companion.readResultSet(
227240
resultSet: ResultSet,
228241
dbType: DbType,
229-
limit: Int = DEFAULT_LIMIT
242+
limit: Int = DEFAULT_LIMIT,
243+
inferNullability: Boolean = true,
230244
): AnyFrame {
231245
val tableColumns = getTableColumnsMetadata(resultSet)
232-
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit)
246+
return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability)
233247
}
234248

235249
/**
@@ -238,33 +252,38 @@ public fun DataFrame.Companion.readResultSet(
238252
* @param [resultSet] the [ResultSet] containing the data to read.
239253
* @param [connection] the connection to the database (it's required to extract the database type).
240254
* @param [limit] the maximum number of rows to read from the [ResultSet].
255+
* @param [inferNullability] indicates how the column nullability should be inferred.
241256
* @return the DataFrame generated from the [ResultSet] data.
242257
*/
243258
public fun DataFrame.Companion.readResultSet(
244259
resultSet: ResultSet,
245260
connection: Connection,
246-
limit: Int = DEFAULT_LIMIT
261+
limit: Int = DEFAULT_LIMIT,
262+
inferNullability: Boolean = true,
247263
): AnyFrame {
248264
val url = connection.metaData.url
249265
val dbType = extractDBTypeFromUrl(url)
250266

251-
return readResultSet(resultSet, dbType, limit)
267+
return readResultSet(resultSet, dbType, limit, inferNullability)
252268
}
253269

254270
/**
255271
* Reads all tables from the given database using the provided database configuration and limit.
256272
*
257273
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
258274
* @param [limit] the maximum number of rows to read from each table.
275+
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
276+
* @param [inferNullability] indicates how the column nullability should be inferred.
259277
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
260278
*/
261279
public fun DataFrame.Companion.readAllSqlTables(
262280
dbConfig: DatabaseConfiguration,
263281
catalogue: String? = null,
264-
limit: Int = DEFAULT_LIMIT
282+
limit: Int = DEFAULT_LIMIT,
283+
inferNullability: Boolean = true,
265284
): List<AnyFrame> {
266285
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
267-
return readAllSqlTables(connection, catalogue, limit)
286+
return readAllSqlTables(connection, catalogue, limit, inferNullability)
268287
}
269288
}
270289

@@ -273,14 +292,17 @@ public fun DataFrame.Companion.readAllSqlTables(
273292
*
274293
* @param [connection] the database connection to read tables from.
275294
* @param [limit] the maximum number of rows to read from each table.
295+
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
296+
* @param [inferNullability] indicates how the column nullability should be inferred.
276297
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
277298
*
278299
* @see DriverManager.getConnection
279300
*/
280301
public fun DataFrame.Companion.readAllSqlTables(
281302
connection: Connection,
282303
catalogue: String? = null,
283-
limit: Int = DEFAULT_LIMIT
304+
limit: Int = DEFAULT_LIMIT,
305+
inferNullability: Boolean = true,
284306
): List<AnyFrame> {
285307
val metaData = connection.metaData
286308
val url = connection.metaData.url
@@ -304,7 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables(
304326
// could be Dialect/Database specific
305327
logger.debug { "Reading table: $tableName" }
306328

307-
val dataFrame = readSqlTable(connection, tableName, limit)
329+
val dataFrame = readSqlTable(connection, tableName, limit, inferNullability)
308330
dataFrames += dataFrame
309331
logger.debug { "Finished reading table: $tableName" }
310332
}
@@ -450,7 +472,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection):
450472
val dbType = extractDBTypeFromUrl(url)
451473

452474
val tableTypes = arrayOf("TABLE")
453-
// exclude system and other tables without data
475+
// exclude a system and other tables without data
454476
val tables = metaData.getTables(null, null, null, tableTypes)
455477

456478
val dataFrameSchemas = mutableListOf<DataFrameSchema>()
@@ -561,13 +583,15 @@ private fun manageColumnNameDuplication(columnNameCounter: MutableMap<String, In
561583
* @param [rs] the ResultSet object containing the data to be fetched and converted.
562584
* @param [dbType] the type of the database.
563585
* @param [limit] the maximum number of rows to fetch and convert.
586+
* @param [inferNullability] indicates how the column nullability should be inferred.
564587
* @return A mutable map containing the fetched and converted data.
565588
*/
566589
private fun fetchAndConvertDataFromResultSet(
567590
tableColumns: MutableList<TableColumnMetadata>,
568591
rs: ResultSet,
569592
dbType: DbType,
570-
limit: Int
593+
limit: Int,
594+
inferNullability: Boolean,
571595
): AnyFrame {
572596
val data = List(tableColumns.size) { mutableListOf<Any?>() }
573597

@@ -596,6 +620,7 @@ private fun fetchAndConvertDataFromResultSet(
596620
DataColumn.createValueColumn(
597621
name = tableColumns[index].name,
598622
values = values,
623+
infer = convertNullabilityInference(inferNullability),
599624
type = kotlinTypesForSqlColumns[index]!!
600625
)
601626
}.toDataFrame()
@@ -605,6 +630,8 @@ private fun fetchAndConvertDataFromResultSet(
605630
return dataFrame
606631
}
607632

633+
private fun convertNullabilityInference(inferNullability: Boolean) = if (inferNullability) Infer.Nulls else Infer.None
634+
608635
private fun extractNewRowFromResultSetAndAddToData(
609636
tableColumns: MutableList<TableColumnMetadata>,
610637
data: List<MutableList<Any?>>,

dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@ import org.h2.jdbc.JdbcSQLSyntaxErrorException
66
import org.intellij.lang.annotations.Language
77
import org.jetbrains.kotlinx.dataframe.DataFrame
88
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
9-
import org.jetbrains.kotlinx.dataframe.api.add
10-
import org.jetbrains.kotlinx.dataframe.api.cast
11-
import org.jetbrains.kotlinx.dataframe.api.filter
12-
import org.jetbrains.kotlinx.dataframe.api.select
9+
import org.jetbrains.kotlinx.dataframe.api.*
1310
import org.jetbrains.kotlinx.dataframe.io.db.H2
1411
import org.junit.AfterClass
1512
import org.junit.BeforeClass
@@ -677,4 +674,115 @@ class JdbcTest {
677674
saleDataSchema1.columns.size shouldBe 3
678675
saleDataSchema1.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()
679676
}
677+
678+
@Test
679+
fun `infer nullability`() {
680+
// prepare tables and data
681+
@Language("SQL")
682+
val createTestTable1Query = """
683+
CREATE TABLE TestTable1 (
684+
id INT PRIMARY KEY,
685+
name VARCHAR(50),
686+
surname VARCHAR(50),
687+
age INT NOT NULL
688+
)
689+
"""
690+
691+
connection.createStatement().execute(createTestTable1Query)
692+
693+
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)")
694+
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)")
695+
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)")
696+
connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)")
697+
698+
// start testing `readSqlTable` method
699+
700+
// with default inferNullability: Boolean = true
701+
val tableName = "TestTable1"
702+
val df = DataFrame.readSqlTable(connection, tableName)
703+
df.schema().columns["id"]!!.type shouldBe typeOf<Int>()
704+
df.schema().columns["name"]!!.type shouldBe typeOf<String>()
705+
df.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
706+
df.schema().columns["age"]!!.type shouldBe typeOf<Int>()
707+
708+
val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName)
709+
dataSchema.columns.size shouldBe 4
710+
dataSchema.columns["id"]!!.type shouldBe typeOf<Int>()
711+
dataSchema.columns["name"]!!.type shouldBe typeOf<String?>()
712+
dataSchema.columns["surname"]!!.type shouldBe typeOf<String?>()
713+
dataSchema.columns["age"]!!.type shouldBe typeOf<Int>()
714+
715+
// with inferNullability: Boolean = false
716+
val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false)
717+
df1.schema().columns["id"]!!.type shouldBe typeOf<Int>()
718+
df1.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== this column changed a type because it doesn't contain nulls
719+
df1.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
720+
df1.schema().columns["age"]!!.type shouldBe typeOf<Int>()
721+
722+
// end testing `readSqlTable` method
723+
724+
// start testing `readSQLQuery` method
725+
726+
// ith default inferNullability: Boolean = true
727+
@Language("SQL")
728+
val sqlQuery = """
729+
SELECT name, surname, age FROM TestTable1
730+
""".trimIndent()
731+
732+
val df2 = DataFrame.readSqlQuery(connection, sqlQuery)
733+
df2.schema().columns["name"]!!.type shouldBe typeOf<String>()
734+
df2.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
735+
df2.schema().columns["age"]!!.type shouldBe typeOf<Int>()
736+
737+
val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery)
738+
dataSchema2.columns.size shouldBe 3
739+
dataSchema2.columns["name"]!!.type shouldBe typeOf<String?>()
740+
dataSchema2.columns["surname"]!!.type shouldBe typeOf<String?>()
741+
dataSchema2.columns["age"]!!.type shouldBe typeOf<Int>()
742+
743+
// with inferNullability: Boolean = false
744+
val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false)
745+
df3.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== this column changed a type because it doesn't contain nulls
746+
df3.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
747+
df3.schema().columns["age"]!!.type shouldBe typeOf<Int>()
748+
749+
// end testing `readSQLQuery` method
750+
751+
// start testing `readResultSet` method
752+
753+
connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st ->
754+
@Language("SQL")
755+
val selectStatement = "SELECT * FROM TestTable1"
756+
757+
st.executeQuery(selectStatement).use { rs ->
758+
// ith default inferNullability: Boolean = true
759+
val df4 = DataFrame.readResultSet(rs, H2)
760+
df4.schema().columns["id"]!!.type shouldBe typeOf<Int>()
761+
df4.schema().columns["name"]!!.type shouldBe typeOf<String>()
762+
df4.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
763+
df4.schema().columns["age"]!!.type shouldBe typeOf<Int>()
764+
765+
rs.beforeFirst()
766+
767+
val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2)
768+
dataSchema3.columns.size shouldBe 4
769+
dataSchema3.columns["id"]!!.type shouldBe typeOf<Int>()
770+
dataSchema3.columns["name"]!!.type shouldBe typeOf<String?>()
771+
dataSchema3.columns["surname"]!!.type shouldBe typeOf<String?>()
772+
dataSchema3.columns["age"]!!.type shouldBe typeOf<Int>()
773+
774+
// with inferNullability: Boolean = false
775+
rs.beforeFirst()
776+
777+
val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false)
778+
df5.schema().columns["id"]!!.type shouldBe typeOf<Int>()
779+
df5.schema().columns["name"]!!.type shouldBe typeOf<String?>() // <=== this column changed a type because it doesn't contain nulls
780+
df5.schema().columns["surname"]!!.type shouldBe typeOf<String?>()
781+
df5.schema().columns["age"]!!.type shouldBe typeOf<Int>()
782+
}
783+
}
784+
// end testing `readResultSet` method
785+
786+
connection.createStatement().execute("DROP TABLE TestTable1")
787+
}
680788
}

0 commit comments

Comments
 (0)