diff --git a/.travis.yml b/.travis.yml index 348df9ad..5ddb0fc6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,9 @@ os: linux dist: focal language: java -services: docker +services: + - docker + - mysql jdk: openjdk11 addons: hosts: @@ -10,7 +12,7 @@ addons: before_install: - git clone https://github.com/mariadb-corporation/connector-test-machine.git -env: packet=40 +env: packet=40 local=0 clear_text_plugin=0 RUN_LONG_TEST=true install: - |- case $TRAVIS_OS_NAME in @@ -20,7 +22,7 @@ install: connector-test-machine/launch.bat -t "$srv" -v "$v" -d testr2 ;; linux) - source connector-test-machine/launch.sh -t "$srv" -v "$v" -d testr2 -n 0 -l "$local" -p "$packet" + source connector-test-machine/launch.sh -t "$srv" -v "$v" -d testr2 -n 0 -l "$local" -p "$packet" -c "$clear_text_plugin" ;; esac @@ -37,16 +39,14 @@ jobs: - env: srv=mariadb v=10.3 local=1 - env: srv=mariadb v=10.4 local=1 - env: srv=mariadb v=10.5 local=1 - - env: srv=mariadb v=10.5 local=1 - jdk: openjdk8 - dist: bionic - env: srv=mariadb v=10.6 local=1 - - env: srv=mariadb v=10.5 NO_BACKSLASH_ESCAPES=true + - env: srv=mariadb v=10.7 local=1 clear_text_plugin=1 + - env: srv=mariadb v=10.6 NO_BACKSLASH_ESCAPES=true - env: srv=mariadb v=10.6 BENCH=1 - if: type = push AND fork = false env: srv=maxscale - if: type = push AND fork = false - env: srv=mariadb-es v=10.5 + env: srv=mariadb-es v=10.6 - if: type = push AND fork = false env: srv=skysql - if: type = push AND fork = false diff --git a/CHANGELOG.md b/CHANGELOG.md index cf6d0cac..3c5fc522 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,34 @@ # Change Log +## [1.1.2](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/tree/1.1.2) (12 Mai 2022) +[Full Changelog](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/compare/1.1.1...1.1.2) + +* [R2DBC-54] Support r2dbc spec 0.9.1 version +* [R2DBC-42] Specification precision on Statement::add +* [R2DBC-44] simplify client side prepared statement +* [R2DBC-45] Implement SPI TestKit to validate driver with spec tests +* [R2DBC-46] Add sql to R2DBC exception hierarchy +* [R2DBC-47] ensure driver follow spec precision about Row.get returning error. +* [R2DBC-48] after spec batch clarification trailing batch should fail +* [R2DBC-49] Support for failover and load balancing modes +* [R2DBC-50] TIME data without indication default to return Duration in place of LocalTime +* [R2DBC-56] Transaction isolation spec precision +* [R2DBC-57] varbinary data default must return byte[] +* [R2DBC-63] backpressure handling +* [R2DBC-64] Support batch cancellation +* [R2DBC-53] correct RowMetadata case-sensitivity lookup +* [R2DBC-62] Prepared statement wrong column type on prepare meta not skipped + +## [1.1.1-rc](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/tree/1.1.1) (13 Sept 2021) +[Full Changelog](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/compare/1.1.0...1.1.1) + +Changes: +* [R2DBC-37] Full java 9 JPMS module +* [R2DBC-38] Permit sharing channels with option loopResources + +Corrections: +* [R2DBC-40] netty buffer leaks when not consuming results +* [R2DBC-39] MariadbResult.getRowsUpdated() fails with ClassCastException for RETURNING command ## [1.0.3](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/tree/1.0.3) (13 Sept 2021) [Full Changelog](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/compare/1.0.2...1.0.3) @@ -86,4 +115,4 @@ Changes compared to 0.8.2.alpha1: second Alpha release ## [0.8.1](https://github.com/mariadb-corporation/mariadb-connector-r2dbc/tree/0.8.1) (23 Mar. 2020) -First Alpha release \ No newline at end of file +First Alpha release diff --git a/README.md b/README.md index 66b09991..ecae94d9 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ MariaDB and MySQL client, 100% Java, compatible with Java8+, apache 2.0 licensed - Driver permits ed25519, PAM authentication that comes with MariaDB. - use MariaDB 10.5 returning fonction to permit Statement.returnGeneratedValues -Driver follow [R2DBC 0.8.4 specifications](https://r2dbc.io/spec/0.8.4.RELEASE/spec/html/) +Driver follow [R2DBC 0.9.1 specifications](https://r2dbc.io/spec/0.9.1.RELEASE/spec/html/) ## Documentation @@ -32,7 +32,7 @@ The MariaDB Connector is available through maven using : org.mariadb r2dbc-mariadb - 1.0.3 + 1.1.2 ``` @@ -52,7 +52,7 @@ MariadbConnectionFactory factory = new MariadbConnectionFactory(conf); //OR -ConnectionFactory factory = ConnectionFactories.get("r2dbc:mariadb://user:password@host:3306/myDB?option1=value"); +ConnectionFactory factory = ConnectionFactories.get("r2dbc:mariadb://user:password@host:3306,host2:3302/myDB?option1=value"); ``` Basic example: @@ -82,7 +82,7 @@ Basic example: |---:|---|:---:|:---:| | **`username`** | User to access database. |*string* | | **`password`** | User password. |*string* | -| **`host`** | IP address or DNS of the database server. *Not used when using option `socketPath`*. |*string*| "localhost"| +| **`host`** | IP address or DNS of the database server. Multiple host can be set, separate by comma. If first host is not reachable (timeout is connectTimeout), driver use next hosts.*Not used when using option `socketPath`*. |*string*| "localhost"| | **`port`** | Database server port number. *Not used when using option `socketPath`*|*integer*| 3306| | **`database`** | Default database to use when establishing the connection. | *string* | | **`connectTimeout`** | Sets the connection timeout | *Duration* | 10s| @@ -108,13 +108,53 @@ Basic example: | **`pamOtherPwd`** | Permit to provide additional password for PAM authentication with multiple authentication step. If multiple passwords, value must be URL encoded.|*string* | | | **`autocommit`** | Set default autocommit value on connection initialization" |*boolean* | true | | **`tinyInt1isBit`** | Convert Bit(1)/TINYINT(1) default to boolean type |*boolean* | true | +| **`restrictedAuth`** | if set, restrict authentication plugin to secure list. Default provided plugins are mysql_native_password, mysql_clear_password, client_ed25519, dialog, sha256_password and caching_sha2_password |*string* | | +| **`loopResources`** | permits to share netty EventLoopGroup among multiple async libraries/framework |*LoopResources* | | -## Roadmap +## Failover -* Performance ! -* Fast batch using mariadb bulk -* GeoJSON datatype -* Pluggable types for MariaDB 10.5 (JSON, INET4, INET6, BOOLEAN, ...) +Failover occurs when a connection to a primary database server fails and the connector opens up a connection to another database server. +For example, server A has the current connection. After a failure (server crash, network down …) the connection will switch to another server (B). + +Load balancing allows load to be distributed over multiple servers : +When initializing a connection or after a failed connection, the connector will attempt to connect to a host. The connection is selected randomly among the valid hosts. Thereafter, all statements will run on that database server until the connection will be closed (or fails). +Example: when creating a pool of 60 connections, each one will use a random host. With 3 master hosts, the pool will have about 20 connections to each host. + +```java +ConnectionFactory factory = ConnectionFactories.get("r2dbc:mariadb:sequential://user:password@host:3306,host2:3302/myDB?option1=value"); +``` + + +### Failover behaviour + +Failover parameter is set (i.e. prefixing connection string with `r2dbc:mariadb:[sequential|loadbalancing]://...` or using HaMode builder). + +There can be multiple fail causes. When a failure occurs many things will be done: + +* connection recovery (re-establishing connection transparently) +* re-execute command/transaction if possible + +During failover, the fail host address will be put on a blacklist (shared by JVM) for 60 seconds. Connector will always try to connect non blacklisted host first, but can retry to connect blacklisted host before 60s if all hosts are blacklisted. + +### re-execution +The driver will try to reconnect to any valid host (not blasklisted, or if all primary host are blacklisted trying blacklisted hosts). If reconnection fail, an Exception with be thrown with SQLState "08XXX". If using a pool, this connection will be discarded. + +on successful reconnection, there will be different cases. + +If driver identify that command can be replayed without issue (for example connection.isValid(), a PREPARE/ROLLBACK command), driver will execute command without throwing any error. + +Driver cannot transparently handle all cases : imagine that the failover occurs when executing an INSERT command without a transaction: driver cannot know that command has been received and executed on server. In those case, an SQLException with be thrown with SQLState "25S03". + +#### Option `transactionReplay` : +Most of the time, queries occurs in transaction (ORM for example doesn't permit using auto-commit), so redo transaction implementation will solve most of failover cases transparently for user point of view. + +Redo transaction approach is to save commands in transaction. When a failover occurs during a transaction, the connector can automatically reconnect and replay transaction, making failover completely transparent. + +There is some limitations : + +driver will buffer up commands in a transaction until some inner limit. +huge command will temporarily disable transaction buffering for current transaction. +Commands must be idempotent only (queries can be "replayable") ## Tracker diff --git a/intellij-style.xml b/intellij-style.xml index cff4d36b..d23008ef 100644 --- a/intellij-style.xml +++ b/intellij-style.xml @@ -1,6 +1,6 @@ @@ -451,4 +451,4 @@ - \ No newline at end of file + diff --git a/pom.xml b/pom.xml index 8819c02e..95db3fb4 100644 --- a/pom.xml +++ b/pom.xml @@ -1,6 +1,6 @@ 4.0.0 org.mariadb r2dbc-mariadb - 1.0.3 + 1.1.2 jar https://github.com/mariadb-corporation/mariadb-connector-r2dbc @@ -21,14 +21,14 @@ 1.8 3.0.2 - 5.7.2 - 1.32 - 1.2.3 - 4.1.65.Final + 5.8.2 + 1.34 + 1.2.10 + 4.1.73.Final UTF-8 - 0.8.5.RELEASE - Dysprosium-SR21 - 3.0.0-alpha + 0.9.1.RELEASE + 2020.0.15 + 3.0.4 benchmarks @@ -97,12 +97,26 @@ + + + sonatype-nexus-snapshots + Sonatype Nexus Snapshots + https://oss.sonatype.org/content/repositories/snapshots + + + io.r2dbc r2dbc-spi ${r2dbc-spi.version} + + io.r2dbc + r2dbc-spi-test + ${r2dbc-spi.version} + test + io.projectreactor reactor-core @@ -134,16 +148,6 @@ reactor-test test - - - - - - - - - - org.junit.jupiter junit-jupiter-engine @@ -155,6 +159,12 @@ ${logback.version} test + + org.mariadb.jdbc + mariadb-java-client + ${mariadb-jdbc.version} + test + @@ -165,17 +175,41 @@ 3.8.1 - -Xlint:all,-options,-path,-processing + -Xlint:all,-options,-path,-processing,-requires-transitive-automatic,-requires-automatic true ${java.version} ${java.version} + + + compile-java-8 + + compile + + + + compile-java-9 + compile + + compile + + + 9 + 9 + 9 + + ${project.basedir}/src/main/java9 + + true + + + org.apache.maven.plugins maven-jar-plugin - 3.1.2 + 3.2.0 @@ -184,7 +218,7 @@ ${r2dbc-spi.version} - r2dbc.mariadb + true @@ -305,7 +339,7 @@ org.jacoco jacoco-maven-plugin - 0.8.5 + 0.8.7 **/ed25519/**/*.class @@ -342,16 +376,16 @@ jmh-generator-annprocess ${jmh.version} - - org.mariadb.jdbc - mariadb-java-client - ${mariadb-jdbc.version} - ch.qos.logback logback-classic ${logback.version} + + org.mariadb.jdbc + mariadb-java-client + ${mariadb-jdbc.version} + @@ -369,7 +403,7 @@ org.codehaus.mojo build-helper-maven-plugin - 3.0.0 + 3.3.0 add-source @@ -402,7 +436,7 @@ org.apache.maven.plugins maven-shade-plugin - 3.2.1 + 3.2.4 package diff --git a/src/benchmark/java/org/mariadb/r2dbc/Common.java b/src/benchmark/java/org/mariadb/r2dbc/Common.java index b179f50e..7cdfcd30 100644 --- a/src/benchmark/java/org/mariadb/r2dbc/Common.java +++ b/src/benchmark/java/org/mariadb/r2dbc/Common.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; @@ -36,6 +36,7 @@ public static class MyState { protected Connection jdbcPrepare; protected io.r2dbc.spi.Connection r2dbc; + protected io.r2dbc.spi.Connection r2dbcFailover; protected io.r2dbc.spi.Connection r2dbcPrepare; // protected io.r2dbc.spi.Connection r2dbcMysql; @@ -49,7 +50,15 @@ public void doSetup() throws Exception { .password(password) .database(database) .build(); - + MariadbConnectionConfiguration confFailover = + MariadbConnectionConfiguration.builder() + .host(host) + .port(port) + .username(username) + .password(password) + .database(database) + .haMode(HaMode.SEQUENTIAL.name()) + .build(); MariadbConnectionConfiguration confPrepare = MariadbConnectionConfiguration.builder() .host(host) @@ -60,25 +69,16 @@ public void doSetup() throws Exception { .useServerPrepStmts(true) .build(); -// MySqlConnectionConfiguration confMysql = -// MySqlConnectionConfiguration.builder() -// .host(host) -// .username(username) -// .database(database) -// .password(password) -// .sslMode(SslMode.DISABLED) -// .port(port) -// .build(); String jdbcUrl = String.format( - "mariadb://%s:%s/%s?user=%s&password=%s", host, port, database, username, password); + "jdbc:mariadb://%s:%s/%s?user=%s&password=%s", host, port, database, username, password); try { - jdbc = DriverManager.getConnection("jdbc:" + jdbcUrl); - jdbcPrepare = DriverManager.getConnection("jdbc:" + jdbcUrl + "&useServerPrepStmts=true"); + jdbc = DriverManager.getConnection(jdbcUrl); + jdbcPrepare = DriverManager.getConnection(jdbcUrl + "&useServerPrepStmts=true"); r2dbc = MariadbConnectionFactory.from(conf).create().block(); + r2dbcFailover = MariadbConnectionFactory.from(confFailover).create().block(); r2dbcPrepare = MariadbConnectionFactory.from(confPrepare).create().block(); -// r2dbcMysql = MySqlConnectionFactory.from(confMysql).create().block(); } catch (SQLException e) { e.printStackTrace(); @@ -91,8 +91,8 @@ public void doTearDown() throws SQLException { jdbc.close(); jdbcPrepare.close(); Mono.from(r2dbc.close()).block(); + Mono.from(r2dbcFailover.close()).block(); Mono.from(r2dbcPrepare.close()).block(); -// Mono.from(r2dbcMysql.close()).block(); } } } diff --git a/src/benchmark/java/org/mariadb/r2dbc/Select_1.java b/src/benchmark/java/org/mariadb/r2dbc/Select_1.java index 323c1474..ce57f41b 100644 --- a/src/benchmark/java/org/mariadb/r2dbc/Select_1.java +++ b/src/benchmark/java/org/mariadb/r2dbc/Select_1.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/benchmark/java/org/mariadb/r2dbc/Select_10000_Rows.java b/src/benchmark/java/org/mariadb/r2dbc/Select_10000_Rows.java index 18816429..14f1fafd 100644 --- a/src/benchmark/java/org/mariadb/r2dbc/Select_10000_Rows.java +++ b/src/benchmark/java/org/mariadb/r2dbc/Select_10000_Rows.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/benchmark/java/org/mariadb/r2dbc/Select_1000_params.java b/src/benchmark/java/org/mariadb/r2dbc/Select_1000_params.java index 41e71bd8..c6f29ee8 100644 --- a/src/benchmark/java/org/mariadb/r2dbc/Select_1000_params.java +++ b/src/benchmark/java/org/mariadb/r2dbc/Select_1000_params.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/benchmark/java/org/mariadb/r2dbc/Select_1_user.java b/src/benchmark/java/org/mariadb/r2dbc/Select_1_user.java index a1a89fda..6448c336 100644 --- a/src/benchmark/java/org/mariadb/r2dbc/Select_1_user.java +++ b/src/benchmark/java/org/mariadb/r2dbc/Select_1_user.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; @@ -34,6 +34,11 @@ public Object[] testR2dbcPrepare(MyState state) throws Throwable { return consume(state.r2dbcPrepare); } + @Benchmark + public Object[] testR2dbcFailover(MyState state) throws Throwable { + return consume(state.r2dbcFailover); + } + // @Benchmark // public Object[] testR2dbcMySql(MyState state) throws Throwable { // return consume(state.r2dbcMysql, blackhole); diff --git a/src/benchmark/resources/logback-test.xml b/src/benchmark/resources/logback-test.xml index d86430e7..843d84e2 100644 --- a/src/benchmark/resources/logback-test.xml +++ b/src/benchmark/resources/logback-test.xml @@ -1,7 +1,7 @@ diff --git a/src/main/java/org/mariadb/r2dbc/ExceptionFactory.java b/src/main/java/org/mariadb/r2dbc/ExceptionFactory.java index 6ce8f75b..ae373542 100644 --- a/src/main/java/org/mariadb/r2dbc/ExceptionFactory.java +++ b/src/main/java/org/mariadb/r2dbc/ExceptionFactory.java @@ -1,11 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; import io.r2dbc.spi.*; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.server.ErrorPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; import reactor.core.publisher.SynchronousSink; public final class ExceptionFactory { @@ -22,14 +22,31 @@ public static ExceptionFactory withSql(String sql) { } public static R2dbcException createException(ErrorPacket error, String sql) { - return createException(error.getMessage(), error.getSqlState(), error.getErrorCode(), sql); + return createException(error.message(), error.sqlState(), error.errorCode(), sql); + } + + public R2dbcException createParsingException(String message) { + return new R2dbcNonTransientResourceException(message, "H1000", 9000, this.sql); + } + + public R2dbcException createParsingException(String message, Throwable cause) { + return new R2dbcNonTransientResourceException(message, "H1000", 9000, this.sql, cause); + } + + public R2dbcException createConnectionErrorException(String message) { + return new R2dbcNonTransientResourceException(message, "08000", 9000, this.sql); + } + + public R2dbcException createConnectionErrorException(String message, Throwable cause) { + return new R2dbcNonTransientResourceException( + message + " : " + cause.getMessage(), "08000", 9000, this.sql, cause); } public static R2dbcException createException( String message, String sqlState, int errorCode, String sql) { - if ("70100".equals(sqlState)) { // ER_QUERY_INTERRUPTED - return new R2dbcTimeoutException(message, sqlState, errorCode); + if ("70100".equals(sqlState) || errorCode == 3024) { // ER_QUERY_INTERRUPTED + return new R2dbcTimeoutException(message, sqlState, errorCode, sql); } String sqlClass = sqlState.substring(0, 2); @@ -44,17 +61,18 @@ public static R2dbcException createException( return new R2dbcBadGrammarException(message, sqlState, errorCode, sql); case "25": case "28": - return new R2dbcPermissionDeniedException(message, sqlState, errorCode); + return new R2dbcPermissionDeniedException(message, sqlState, errorCode, sql); case "21": case "23": - return new R2dbcDataIntegrityViolationException(message, sqlState, errorCode); + return new R2dbcDataIntegrityViolationException(message, sqlState, errorCode, sql); + case "H1": case "08": - return new R2dbcNonTransientResourceException(message, sqlState, errorCode); + return new R2dbcNonTransientResourceException(message, sqlState, errorCode, sql); case "40": - return new R2dbcRollbackException(message, sqlState, errorCode); + return new R2dbcRollbackException(message, sqlState, errorCode, sql); } - return new R2dbcTransientResourceException(message, sqlState, errorCode); + return new R2dbcTransientResourceException(message, sqlState, errorCode, sql); } public R2dbcException createException(String message, String sqlState, int errorCode) { diff --git a/src/main/java/org/mariadb/r2dbc/HaMode.java b/src/main/java/org/mariadb/r2dbc/HaMode.java new file mode 100644 index 00000000..4e47790d --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/HaMode.java @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiFunction; +import java.util.stream.Collectors; +import org.mariadb.r2dbc.client.Client; +import org.mariadb.r2dbc.client.SimpleClient; +import org.mariadb.r2dbc.message.flow.AuthenticationFlow; +import org.mariadb.r2dbc.util.HostAddress; +import reactor.core.publisher.Mono; +import reactor.netty.resources.ConnectionProvider; + +/** Failover (High-availability) mode */ +public enum HaMode { + /** sequential: driver will always connect according to connection string order */ + SEQUENTIAL("sequential") { + public List getAvailableHost( + List hostAddresses, ConcurrentMap denyList) { + return getAvailableHostInOrder(hostAddresses, denyList); + } + + public Mono connectHost( + MariadbConnectionConfiguration conf, ReentrantLock lock, boolean failFast) { + long endingNanoTime = + CONNECTION_LOOP_DURATION.getSeconds() * 1_000_000_000 + System.nanoTime(); + return connectHost(conf, lock, failFast, this::getAvailableHost, endingNanoTime); + } + }, + + /** load-balance: driver will randomly connect to any host, permitting balancing connections */ + LOADBALANCE("load-balance") { + public List getAvailableHost( + List hostAddresses, ConcurrentMap denyList) { + // use in order not blacklisted server + List loopAddress = + new ArrayList<>(HaMode.getAvailableHostInOrder(hostAddresses, denyList)); + Collections.shuffle(loopAddress); + return loopAddress; + } + + public Mono connectHost( + MariadbConnectionConfiguration conf, ReentrantLock lock, boolean failFast) { + long endingNanoTime = + CONNECTION_LOOP_DURATION.getSeconds() * 1_000_000_000 + System.nanoTime(); + return connectHost(conf, lock, failFast, this::getAvailableHost, endingNanoTime); + } + }, + + /** no ha-mode. Connect to first host only */ + NONE("") { + public List getAvailableHost( + List hostAddresses, ConcurrentMap denyList) { + return hostAddresses; + } + + public Mono connectHost( + MariadbConnectionConfiguration conf, ReentrantLock lock, boolean failFast) { + return connectHost(conf, lock, true, this::getAvailableHost, 0L); + } + }; + + /** temporary blacklisted hosts */ + private static final ConcurrentMap denyList = new ConcurrentHashMap<>(); + + /** denied timeout */ + private static final long DENIED_LIST_TIMEOUT = + Long.parseLong(System.getProperty("deniedListTimeout", "60000000000")); + + private static final Duration CONNECTION_LOOP_DURATION = + Duration.parse(System.getProperty("connectionLoopDuration", "PT10S")); + + private final String value; + + HaMode(String value) { + this.value = value; + } + + /** + * Get HAMode from values or aliases + * + * @param value value or alias + * @return HaMode if corresponding mode is found + */ + public static HaMode from(String value) { + for (HaMode haMode : values()) { + if (haMode.value.equalsIgnoreCase(value) || haMode.name().equalsIgnoreCase(value)) { + return haMode; + } + } + throw new IllegalArgumentException( + String.format("Wrong argument value '%s' for HaMode", value)); + } + + private static Mono connect( + MariadbConnectionConfiguration conf, ReentrantLock lock, HostAddress hostAddress) { + return SimpleClient.connect( + ConnectionProvider.newConnection(), + InetSocketAddress.createUnresolved(hostAddress.getHost(), hostAddress.getPort()), + hostAddress, + conf, + lock) + .delayUntil(client -> AuthenticationFlow.exchange(client, conf, hostAddress)) + .doOnError(e -> HaMode.failHost(hostAddress)) + .cast(Client.class) + .flatMap( + client -> + MariadbConnectionFactory.setSessionVariables(conf, client).then(Mono.just(client))); + } + /** + * return hosts of without blacklisted hosts. hosts in blacklist reaching blacklist timeout will + * be present. order corresponds to connection string order. + * + * @param hostAddresses hosts + * @param denyList blacklist + * @return list without denied hosts + */ + private static List getAvailableHostInOrder( + List hostAddresses, ConcurrentMap denyList) { + // use in order not blacklisted server + List copiedList = new ArrayList<>(hostAddresses); + denyList.entrySet().stream() + .filter(e -> e.getValue() < System.nanoTime()) + .forEach(e -> denyList.remove(e.getKey())); + copiedList.removeAll(denyList.keySet()); + return copiedList; + } + + public static Mono resumeConnect( + Throwable t, + MariadbConnectionConfiguration conf, + ReentrantLock lock, + boolean failFast, + List availableHosts, + BiFunction, ConcurrentMap, List> availHost, + Iterator iterator, + long endingNanoTime) { + if (!iterator.hasNext()) { + if (failFast || System.nanoTime() > endingNanoTime) { + return Mono.error( + ExceptionFactory.INSTANCE.createParsingException( + String.format( + "Fail to establish connection to %s %s: %s", + endingNanoTime == 0L ? "" : ", reaching timeout", + availableHosts, + t.getMessage()), + t)); + } + try { + Thread.sleep(250); + } catch (InterruptedException e) { + // eat + } + return connectHost(conf, lock, failFast, availHost, endingNanoTime); + } + return HaMode.connect(conf, lock, iterator.next()) + .onErrorResume( + tt -> + resumeConnect( + tt, conf, lock, failFast, availableHosts, availHost, iterator, endingNanoTime)); + } + + public static Mono connectHost( + MariadbConnectionConfiguration conf, + ReentrantLock lock, + boolean failFast, + BiFunction, ConcurrentMap, List> availHost, + long endingNanoTime) { + List nonBlacklistHosts = availHost.apply(conf.getHostAddresses(), denyList); + List availableHosts; + if (!failFast) { + nonBlacklistHosts.addAll(denyList.keySet()); + // remove host from denyList not in initial host list + availableHosts = + nonBlacklistHosts.stream() + .filter(h -> conf.getHostAddresses().contains(h)) + .collect(Collectors.toList()); + } else { + availableHosts = nonBlacklistHosts; + } + + Iterator iterator = availableHosts.iterator(); + if (!iterator.hasNext()) + return Mono.error( + ExceptionFactory.INSTANCE.createParsingException( + "Fail to establish connection: no available host")); + return HaMode.connect(conf, lock, iterator.next()) + .onErrorResume( + t -> + resumeConnect( + t, conf, lock, failFast, availableHosts, availHost, iterator, endingNanoTime)); + } + + /** + * List of hosts without blacklist entries, ordered according to HA mode + * + * @param hostAddresses hosts + * @param denyList hosts temporary denied + * @return list without denied hosts + */ + public abstract List getAvailableHost( + List hostAddresses, ConcurrentMap denyList); + + public abstract Mono connectHost( + MariadbConnectionConfiguration conf, ReentrantLock lock, boolean failFast); + + public static void failHost(HostAddress hostAddress) { + denyList.put(hostAddress, System.nanoTime() + DENIED_LIST_TIMEOUT); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbBatch.java b/src/main/java/org/mariadb/r2dbc/MariadbBatch.java index 8c3a119d..7144c648 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbBatch.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbBatch.java @@ -1,16 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; import java.util.ArrayList; +import java.util.Iterator; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import org.mariadb.r2dbc.api.MariadbResult; import org.mariadb.r2dbc.client.Client; +import org.mariadb.r2dbc.message.Protocol; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.client.QueryPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.ClientPrepareResult; import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; /** Basic implementation for batch. //TODO implement bulk */ final class MariadbBatch implements org.mariadb.r2dbc.api.MariadbBatch { @@ -28,9 +33,12 @@ final class MariadbBatch implements org.mariadb.r2dbc.api.MariadbBatch { public MariadbBatch add(String sql) { Assert.requireNonNull(sql, "sql must not be null"); - if (!MariadbSimpleQueryStatement.supports(sql, this.client)) { - throw new IllegalArgumentException( - String.format("Statement with parameters cannot be batched (sql:'%s')", sql)); + // ensure commands doesn't have parameters + if (sql.contains("?") || sql.contains(":")) { + if (ClientPrepareResult.hasParameter(sql, client.noBackslashEscapes())) { + throw new IllegalArgumentException( + String.format("Statement with parameters cannot be batched (sql:'%s')", sql)); + } } this.statements.add(sql); @@ -40,34 +48,61 @@ public MariadbBatch add(String sql) { @Override public Flux execute() { if (configuration.allowMultiQueries()) { - return new MariadbSimpleQueryStatement(this.client, String.join(";", this.statements)) - .execute(); + Flux messages = + this.client.sendCommand(new QueryPacket(String.join(";", this.statements)), true); + return MariadbCommonStatement.toResult( + Protocol.TEXT, + this.client, + messages, + ExceptionFactory.INSTANCE, + null, + null, + configuration); + } else { + Iterator iterator = this.statements.iterator(); + Sinks.Many commandsSink = Sinks.many().unicast().onBackpressureBuffer(); + AtomicBoolean canceled = new AtomicBoolean(); + return commandsSink + .asFlux() + .map( + sql -> { + Flux messages = + this.client + .sendCommand(new QueryPacket(sql), false) + .doOnComplete(() -> tryNextCommand(iterator, commandsSink, canceled)); + + return MariadbCommonStatement.toResult( + Protocol.TEXT, + this.client, + messages, + ExceptionFactory.INSTANCE, + null, + null, + configuration); + }) + .flatMap(mariadbResultFlux -> mariadbResultFlux) + .doOnCancel(() -> canceled.set(true)) + .doOnSubscribe( + it -> commandsSink.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST)); + } + } - Flux> fluxMsg = - Flux.create( - sink -> { - for (String sql : this.statements) { - Flux in = this.client.sendCommand(new QueryPacket(sql)); - sink.next(in); - in.subscribe(); - } - sink.complete(); - }); + protected static void tryNextCommand( + Iterator iterator, Sinks.Many bindingSink, AtomicBoolean canceled) { - return fluxMsg - .flatMap(Flux::from) - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new org.mariadb.r2dbc.MariadbResult( - true, - null, - dataRow, - ExceptionFactory.INSTANCE, - null, - client.getVersion().supportReturning(), - client.getConf())); + if (canceled.get()) { + return; + } + + try { + if (iterator.hasNext()) { + bindingSink.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST); + } else { + bindingSink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST); + } + } catch (Exception e) { + bindingSink.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST); } } } diff --git a/src/main/java/org/mariadb/r2dbc/MariadbClientParameterizedQueryStatement.java b/src/main/java/org/mariadb/r2dbc/MariadbClientParameterizedQueryStatement.java index 67680b52..d5d35257 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbClientParameterizedQueryStatement.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbClientParameterizedQueryStatement.java @@ -1,116 +1,44 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.mariadb.r2dbc.api.MariadbStatement; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import org.mariadb.r2dbc.client.Client; -import org.mariadb.r2dbc.codec.Codec; -import org.mariadb.r2dbc.codec.Codecs; -import org.mariadb.r2dbc.codec.Parameter; +import org.mariadb.r2dbc.client.DecoderState; +import org.mariadb.r2dbc.message.Protocol; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.client.QueryPacket; import org.mariadb.r2dbc.message.client.QueryWithParametersPacket; -import org.mariadb.r2dbc.message.server.RowPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.Binding; import org.mariadb.r2dbc.util.ClientPrepareResult; import reactor.core.publisher.Flux; -import reactor.util.annotation.Nullable; +import reactor.core.publisher.Sinks; -final class MariadbClientParameterizedQueryStatement implements MariadbStatement { +final class MariadbClientParameterizedQueryStatement extends MariadbCommonStatement { - private final Client client; - private final String sql; private final ClientPrepareResult prepareResult; - private final MariadbConnectionConfiguration configuration; - private Parameter[] parameters; - private List[]> batchingParameters; - private String[] generatedColumns; MariadbClientParameterizedQueryStatement( Client client, String sql, MariadbConnectionConfiguration configuration) { - this.client = client; - this.configuration = configuration; - this.sql = Assert.requireNonNull(sql, "sql must not be null"); + super(client, sql, configuration, Protocol.TEXT); this.prepareResult = - ClientPrepareResult.parameterParts(this.sql, this.client.noBackslashEscapes()); - this.parameters = null; + ClientPrepareResult.parameterParts(this.initialSql, this.client.noBackslashEscapes()); + this.expectedSize = this.prepareResult.getParamCount(); } - @Override - public MariadbClientParameterizedQueryStatement add() { - // check valid parameters - if (this.parameters != null) { - for (int i = 0; i < prepareResult.getParamCount(); i++) { - if (parameters[i] == null) { - throw new IllegalArgumentException( - String.format("Parameter at position %s is not set", i)); - } - } - if (batchingParameters == null) batchingParameters = new ArrayList<>(); - batchingParameters.add(parameters); - parameters = null; - } - return this; - } - - @Override - public MariadbClientParameterizedQueryStatement bind( - @Nullable String identifier, @Nullable Object value) { - Assert.requireNonNull(identifier, "identifier cannot be null"); - return bind(getColumn(identifier), value); - } - - @SuppressWarnings({"rawtypes", "unchecked"}) - @Override - public MariadbClientParameterizedQueryStatement bind(int index, @Nullable Object value) { - if (value == null) return bindNull(index, null); - if (index >= prepareResult.getParamCount() || index < 0) { - throw new IndexOutOfBoundsException( - String.format( - "index must be in 0-%d range but value is %d", - prepareResult.getParamCount() - 1, index)); - } - - for (Codec codec : Codecs.LIST) { - if (codec.canEncode(value.getClass())) { - if (parameters == null) parameters = new Parameter[prepareResult.getParamCount()]; - parameters[index] = (Parameter) new Parameter(codec, value); - return this; - } - } - throw new IllegalArgumentException( - String.format( - "No encoder for class %s (parameter at index %s) ", value.getClass().getName(), index)); - } - - @Override - public MariadbClientParameterizedQueryStatement bindNull( - @Nullable String identifier, @Nullable Class type) { - Assert.requireNonNull(identifier, "identifier cannot be null"); - return bindNull(getColumn(identifier), type); - } - - @Override - public MariadbClientParameterizedQueryStatement bindNull(int index, @Nullable Class type) { - if (index >= prepareResult.getParamCount() || index < 0) { - throw new IndexOutOfBoundsException( - String.format( - "index must be in 0-%d range but value is " + "%d", - prepareResult.getParamCount() - 1, index)); - } - if (parameters == null) parameters = new Parameter[prepareResult.getParamCount()]; - parameters[index] = Parameter.NULL_PARAMETER; - return this; - } - - private int getColumn(String name) { + protected int getColumnIndex(String name) { + Assert.requireNonNull(name, "identifier cannot be null"); for (int i = 0; i < this.prepareResult.getParamNameList().size(); i++) { if (name.equals(this.prepareResult.getParamNameList().get(i))) return i; } - throw new IllegalArgumentException( + if (prepareResult.getParamCount() <= 0) { + throw new IndexOutOfBoundsException( + String.format("Binding parameters is not supported for the statement '%s'", initialSql)); + } + throw new NoSuchElementException( String.format( "No parameter with name '%s' found (possible values %s)", name, this.prepareResult.getParamNameList().toString())); @@ -118,64 +46,93 @@ private int getColumn(String name) { @Override public Flux execute() { - - if (batchingParameters == null) { - if (parameters == null) { - throw new IllegalArgumentException("No parameter have been set"); - } - // valid parameters - for (int i = 0; i < prepareResult.getParamCount(); i++) { - if (parameters[i] == null) { - throw new IllegalArgumentException( - String.format("Parameter at position %s is not set", i)); - } - } - return executeSingleQuery(this.sql, this.prepareResult, this.generatedColumns); + String sql; + ExceptionFactory factory; + if (this.generatedColumns == null || !client.getVersion().supportReturning()) { + sql = this.initialSql; + factory = this.factory; } else { - // add current set of parameters. see https://github.com/r2dbc/r2dbc-spi/issues/229 - add(); + sql = augment(this.initialSql, this.generatedColumns); + factory = ExceptionFactory.withSql(sql); + } - String[] generatedCols = - generatedColumns != null && client.getVersion().supportReturning() - ? generatedColumns - : null; - Flux fluxMsg = - this.client.sendCommand( - new QueryWithParametersPacket( - prepareResult, this.batchingParameters.get(0), generatedCols)); - int index = 1; - while (index < this.batchingParameters.size()) { - fluxMsg = - fluxMsg.concatWith( - this.client.sendCommand( - new QueryWithParametersPacket( - prepareResult, this.batchingParameters.get(index++), generatedCols))); + if (this.getExpectedSize() != 0) { + if (this.bindings.size() == 0) { + throw new IllegalStateException("No parameters have been set"); } - this.batchingParameters = null; - this.parameters = null; - - Flux flux = - fluxMsg - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new MariadbResult( - true, + this.bindings.forEach(b -> b.validate(this.getExpectedSize())); + return Flux.defer( + () -> { + if (this.bindings.size() == 1) { + // single query + Binding binding = this.bindings.pollFirst(); + + Flux messages = + bindingParameterResults(binding, getExpectedSize()) + .flatMapMany( + values -> + this.client.sendCommand( + new QueryWithParametersPacket( + prepareResult, + values, + client.getVersion().supportReturning() + ? generatedColumns + : null), + false)); + return toResult( + Protocol.TEXT, client, messages, factory, null, generatedColumns, configuration); + } + + // batch + Iterator iterator = this.bindings.iterator(); + Sinks.Many bindingSink = Sinks.many().unicast().onBackpressureBuffer(); + AtomicBoolean canceled = new AtomicBoolean(); + return bindingSink + .asFlux() + .map( + it -> { + Flux messages = + bindingParameterResults(it, getExpectedSize()) + .flatMapMany( + values -> + this.client.sendCommand( + new QueryWithParametersPacket( + prepareResult, + values, + client.getVersion().supportReturning() + ? generatedColumns + : null), + false)) + .doOnComplete(() -> tryNextBinding(iterator, bindingSink, canceled)); + + return toResult( + Protocol.TEXT, + this.client, + messages, + factory, null, - dataRow, - ExceptionFactory.INSTANCE, generatedColumns, - client.getVersion().supportReturning(), - client.getConf())); - return flux.doOnDiscard(RowPacket.class, RowPacket::release); + configuration); + }) + .flatMap(mariadbResultFlux -> mariadbResultFlux) + .doOnCancel(() -> clearBindings(iterator, canceled)) + .doOnError(e -> clearBindings(iterator, canceled)) + .doOnSubscribe( + it -> + bindingSink.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST)); + }); + } else { + return Flux.defer( + () -> { + Flux messages = + this.client.sendCommand( + new QueryPacket(sql), DecoderState.QUERY_RESPONSE, sql, false); + return toResult( + Protocol.TEXT, client, messages, factory, null, generatedColumns, configuration); + }); } } - @Override - public MariadbClientParameterizedQueryStatement fetchSize(int rows) { - return this; - } - @Override public MariadbClientParameterizedQueryStatement returnGeneratedValues(String... columns) { Assert.requireNonNull(columns, "columns must not be null"); @@ -189,52 +146,18 @@ public MariadbClientParameterizedQueryStatement returnGeneratedValues(String... return this; } - private Flux executeSingleQuery( - String sql, ClientPrepareResult prepareResult, String[] generatedColumns) { - ExceptionFactory factory = ExceptionFactory.withSql(sql); - - Flux response = - this.client - .sendCommand( - new QueryWithParametersPacket( - prepareResult, - parameters, - generatedColumns != null && client.getVersion().supportReturning() - ? generatedColumns - : null)) - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new MariadbResult( - true, - null, - dataRow, - factory, - generatedColumns, - client.getVersion().supportReturning(), - client.getConf())); - return response - .concatWith( - Flux.create( - sink -> { - sink.complete(); - parameters = new Parameter[prepareResult.getParamCount()]; - })) - .doOnDiscard(RowPacket.class, RowPacket::release); - } - @Override public String toString() { return "MariadbClientParameterizedQueryStatement{" + "client=" + client + ", sql='" - + sql + + initialSql + '\'' + ", prepareResult=" + prepareResult - + ", parameters=" - + Arrays.toString(parameters) + + ", bindings=" + + Arrays.toString(bindings.toArray()) + ", configuration=" + configuration + ", generatedColumns=" diff --git a/src/main/java/org/mariadb/r2dbc/MariadbColumnMetadata.java b/src/main/java/org/mariadb/r2dbc/MariadbColumnMetadata.java deleted file mode 100644 index d946364f..00000000 --- a/src/main/java/org/mariadb/r2dbc/MariadbColumnMetadata.java +++ /dev/null @@ -1,79 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc; - -import io.r2dbc.spi.ColumnMetadata; -import io.r2dbc.spi.Nullability; -import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; - -final class MariadbColumnMetadata implements ColumnMetadata { - - private ColumnDefinitionPacket columnDefinitionPacket; - - MariadbColumnMetadata(ColumnDefinitionPacket columnDefinitionPacket) { - this.columnDefinitionPacket = columnDefinitionPacket; - } - - @Override - public String getName() { - return columnDefinitionPacket.getColumnAlias(); - } - - @Override - public Class getJavaType() { - return columnDefinitionPacket.getJavaClass(); - } - - @Override - public ColumnDefinitionPacket getNativeTypeMetadata() { - return columnDefinitionPacket; - } - - @Override - public Nullability getNullability() { - return this.columnDefinitionPacket.getNullability(); - } - - @Override - public Integer getPrecision() { - switch (columnDefinitionPacket.getType()) { - case OLDDECIMAL: - case DECIMAL: - // DECIMAL and OLDDECIMAL are "exact" fixed-point number. - // so : - // - if can be signed, 1 byte is saved for sign - // - if decimal > 0, one byte more for dot - if (columnDefinitionPacket.isSigned()) { - return (int) - (columnDefinitionPacket.getLength() - - ((columnDefinitionPacket.getDecimals() > 0) ? 2 : 1)); - } else { - return (int) - (columnDefinitionPacket.getLength() - - ((columnDefinitionPacket.getDecimals() > 0) ? 1 : 0)); - } - default: - return (int) columnDefinitionPacket.getLength(); - } - } - - @Override - public Integer getScale() { - switch (columnDefinitionPacket.getType()) { - case OLDDECIMAL: - case TINYINT: - case SMALLINT: - case INTEGER: - case FLOAT: - case DOUBLE: - case BIGINT: - case MEDIUMINT: - case BIT: - case DECIMAL: - return (int) columnDefinitionPacket.getDecimals(); - default: - return 0; - } - } -} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbCommonStatement.java b/src/main/java/org/mariadb/r2dbc/MariadbCommonStatement.java new file mode 100644 index 00000000..25caac57 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbCommonStatement.java @@ -0,0 +1,194 @@ +package org.mariadb.r2dbc; + +import java.util.ArrayDeque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nonnull; +import org.mariadb.r2dbc.api.MariadbStatement; +import org.mariadb.r2dbc.client.Client; +import org.mariadb.r2dbc.client.MariadbResult; +import org.mariadb.r2dbc.codec.Codecs; +import org.mariadb.r2dbc.message.Protocol; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.server.RowPacket; +import org.mariadb.r2dbc.util.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +public abstract class MariadbCommonStatement implements MariadbStatement { + public static final int UNKNOWN_SIZE = -1; + protected final ArrayDeque bindings = new ArrayDeque<>(); + protected int expectedSize; + protected final Client client; + protected final String initialSql; + protected final MariadbConnectionConfiguration configuration; + protected ExceptionFactory factory; + protected String[] generatedColumns; + private final Protocol defaultProtocol; + + public MariadbCommonStatement( + Client client, + String sql, + MariadbConnectionConfiguration configuration, + Protocol defaultProtocol) { + this.defaultProtocol = defaultProtocol; + this.client = client; + this.configuration = configuration; + this.initialSql = Assert.requireNonNull(sql, "sql must not be null"); + this.factory = ExceptionFactory.withSql(sql); + } + + public MariadbStatement add() { + Binding binding = this.bindings.peekLast(); + if (binding != null) { + binding.validate(getExpectedSize()); + } else if (getExpectedSize() > 0) { + throw new IllegalArgumentException( + String.format("No parameter have been bind, but expect %s values", getExpectedSize())); + } + this.bindings.add(new Binding(getExpectedSize())); + return this; + } + + @Override + public MariadbStatement bind(String identifier, Object value) { + return bind(getColumnIndex(identifier), value); + } + + @Override + public MariadbStatement bindNull(String identifier, Class type) { + return bindNull(getColumnIndex(identifier), type); + } + + @Override + public MariadbStatement bindNull(int index, Class type) { + if (index < 0) { + throw new IndexOutOfBoundsException( + String.format("wrong index value %d, index must be positive", index)); + } + if (index >= expectedSize && expectedSize != UNKNOWN_SIZE) { + + throw new IndexOutOfBoundsException( + (getExpectedSize() == 0) + ? String.format( + "Binding parameters is not supported for the statement '%s'", initialSql) + : String.format( + "Cannot bind parameter %d, statement has %d parameters", index, expectedSize)); + } + getCurrentOrFirstBinding().add(index, Codecs.encodeNull(type, index)); + return this; + } + + @Override + public MariadbStatement bind(int index, Object value) { + Assert.requireNonNull(value, "value must not be null"); + if (index < 0) { + throw new IndexOutOfBoundsException( + String.format("wrong index value %d, index must be positive", index)); + } + + getCurrentOrFirstBinding() + .add(index, Codecs.encode(value, index, defaultProtocol, factory, client.getContext())); + return this; + } + + protected abstract int getColumnIndex(String name); + + @Nonnull + protected Binding getCurrentOrFirstBinding() { + Binding binding = this.bindings.peekLast(); + if (binding == null) { + Binding newBinding = new Binding(getExpectedSize()); + this.bindings.add(newBinding); + return newBinding; + } else { + return binding; + } + } + /** + * Augments an SQL statement with a {@code RETURNING} statement and column names. If the + * collection is empty, uses {@code *} for column names. + * + * @param sql the SQL to augment + * @param generatedColumns the names of the columns to augment with + * @return an augmented sql statement returning the specified columns or a wildcard + * @throws IllegalArgumentException if {@code sql} or {@code generatedColumns} is {@code null} + */ + public static String augment(String sql, String[] generatedColumns) { + Assert.requireNonNull(sql, "sql must not be null"); + Assert.requireNonNull(generatedColumns, "generatedColumns must not be null"); + return String.format( + "%s RETURNING %s", + sql, generatedColumns.length == 0 ? "*" : String.join(", ", generatedColumns)); + } + + static Mono> bindingParameterResults(Binding binding, int expectedSize) { + return Flux.fromIterable(binding.getBindResultParameters(expectedSize)) + .flatMap( + f -> { + if (f.isNull()) { + return Mono.just(new BindEncodedValue(f.getCodec(), null)); + } else { + return f.getValue().map(b -> new BindEncodedValue(f.getCodec(), b)); + } + }) + .collectList(); + } + + public static Flux toResult( + final Protocol protocol, + Client client, + Flux messages, + ExceptionFactory factory, + AtomicReference prepareResult, + String[] generatedColumns, + MariadbConnectionConfiguration configuration) { + return messages + .windowUntil(it -> it.resultSetEnd()) + .map( + dataRow -> + new MariadbResult( + protocol == Protocol.TEXT, + prepareResult, + dataRow, + factory, + generatedColumns, + client.getVersion().supportReturning(), + configuration)) + .cast(org.mariadb.r2dbc.api.MariadbResult.class) + .doOnDiscard(RowPacket.class, RowPacket::release); + } + + protected static void tryNextBinding( + Iterator iterator, Sinks.Many bindingSink, AtomicBoolean canceled) { + + if (canceled.get()) { + return; + } + + try { + if (iterator.hasNext()) { + bindingSink.emitNext(iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST); + } else { + bindingSink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST); + } + } catch (Exception e) { + bindingSink.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST); + } + } + + protected void clearBindings(Iterator iterator, AtomicBoolean canceled) { + canceled.set(true); + while (iterator.hasNext()) { + iterator.next(); + } + this.bindings.forEach(Binding::clear); + } + + protected int getExpectedSize() { + return expectedSize; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbConnection.java b/src/main/java/org/mariadb/r2dbc/MariadbConnection.java index d94cb001..daf2390f 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbConnection.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbConnection.java @@ -1,37 +1,42 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.TransactionDefinition; import io.r2dbc.spi.ValidationDepth; +import java.time.Duration; +import java.util.function.Function; import org.mariadb.r2dbc.api.MariadbStatement; import org.mariadb.r2dbc.client.Client; +import org.mariadb.r2dbc.message.client.ChangeSchemaPacket; import org.mariadb.r2dbc.message.client.PingPacket; import org.mariadb.r2dbc.message.client.QueryPacket; import org.mariadb.r2dbc.util.Assert; import org.mariadb.r2dbc.util.PrepareCache; +import org.mariadb.r2dbc.util.constants.Capabilities; +import org.mariadb.r2dbc.util.constants.ServerStatus; import reactor.core.publisher.Mono; import reactor.util.Logger; import reactor.util.Loggers; -final class MariadbConnection implements org.mariadb.r2dbc.api.MariadbConnection { +public final class MariadbConnection implements org.mariadb.r2dbc.api.MariadbConnection { private final Logger logger = Loggers.getLogger(this.getClass()); private final Client client; private final MariadbConnectionConfiguration configuration; + private volatile IsolationLevel sessionIsolationLevel; private volatile IsolationLevel isolationLevel; + private volatile String database; - MariadbConnection( + public MariadbConnection( Client client, IsolationLevel isolationLevel, MariadbConnectionConfiguration configuration) { this.client = Assert.requireNonNull(client, "client must not be null"); - this.isolationLevel = Assert.requireNonNull(isolationLevel, "isolationLevel must not be null"); + this.sessionIsolationLevel = + Assert.requireNonNull(isolationLevel, "isolationLevel must not be null"); this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - - // save Global isolation level to avoid asking each new connection with same configuration - if (configuration.getIsolationLevel() == null) { - configuration.setIsolationLevel(isolationLevel); - } + this.database = configuration.getDatabase(); } @Override @@ -39,6 +44,26 @@ public Mono beginTransaction() { return this.client.beginTransaction(); } + @Override + public Mono beginTransaction(TransactionDefinition definition) { + Mono request = Mono.empty(); + + // set isolation level for next transaction if set + IsolationLevel isoLevel = definition.getAttribute(TransactionDefinition.ISOLATION_LEVEL); + if (isoLevel != null && !isoLevel.equals(getTransactionIsolationLevel())) { + String sql = String.format("SET TRANSACTION ISOLATION LEVEL %s", isoLevel.asSql()); + ExceptionFactory exceptionFactory = ExceptionFactory.withSql(sql); + request = + client + .sendCommand(new QueryPacket(sql), true) + .handle(exceptionFactory::handleErrorResponse) + .then() + .doOnSuccess(ignore -> this.isolationLevel = isoLevel); + } + + return request.then(this.client.beginTransaction(definition)); + } + @Override public Mono close() { return this.client.close().then(Mono.empty()); @@ -46,7 +71,7 @@ public Mono close() { @Override public Mono commitTransaction() { - return this.client.commitTransaction(); + return this.client.commitTransaction().then().doOnSuccess(i -> this.isolationLevel = null); } @Override @@ -57,7 +82,13 @@ public MariadbBatch createBatch() { @Override public Mono createSavepoint(String name) { Assert.requireNonNull(name, "name must not be null"); - return this.client.createSavepoint(name); + Mono needsBegin = isAutoCommit() ? this.client.beginTransaction() : Mono.empty(); + String cmd = String.format("SAVEPOINT `%s`", name.replace("`", "``")); + return needsBegin.then( + client + .sendCommand(new QueryPacket(cmd), true) + .handle(ExceptionFactory.withSql(cmd)::handleErrorResponse) + .then()); } @Override @@ -66,14 +97,11 @@ public MariadbStatement createStatement(String sql) { if (sql.trim().isEmpty()) { throw new IllegalArgumentException("Statement cannot be empty."); } - if (MariadbSimpleQueryStatement.supports(sql, this.client)) { - return new MariadbSimpleQueryStatement(this.client, sql); - } else { - if (this.configuration.useServerPrepStmts()) { - return new MariadbServerParameterizedQueryStatement(this.client, sql, this.configuration); - } - return new MariadbClientParameterizedQueryStatement(this.client, sql, this.configuration); + + if (this.configuration.useServerPrepStmts() || sql.contains("call")) { + return new MariadbServerParameterizedQueryStatement(this.client, sql, this.configuration); } + return new MariadbClientParameterizedQueryStatement(this.client, sql, this.configuration); } @Override @@ -83,18 +111,26 @@ public MariadbConnectionMetadata getMetadata() { @Override public IsolationLevel getTransactionIsolationLevel() { - return this.isolationLevel; + if (isolationLevel != null) return isolationLevel; + if ((client.getContext().getClientCapabilities() | Capabilities.CLIENT_SESSION_TRACK) > 0 + && client.getContext().getIsolationLevel() != null) + return client.getContext().getIsolationLevel(); + return this.sessionIsolationLevel; } @Override public boolean isAutoCommit() { - return this.client.isAutoCommit(); + return this.client.isAutoCommit() && !this.client.isInTransaction(); } @Override public Mono releaseSavepoint(String name) { Assert.requireNonNull(name, "name must not be null"); - return this.client.releaseSavepoint(name); + String cmd = String.format("RELEASE SAVEPOINT `%s`", name.replace("`", "``")); + return client + .sendCommand(new QueryPacket(cmd), true) + .handle(ExceptionFactory.withSql(cmd)::handleErrorResponse) + .then(); } @Override @@ -102,9 +138,29 @@ public long getThreadId() { return this.client.getThreadId(); } + @Override + public boolean isInTransaction() { + return (this.client.getContext().getServerStatus() & ServerStatus.IN_TRANSACTION) > 0; + } + + @Override + public boolean isInReadOnlyTransaction() { + return (this.client.getContext().getServerStatus() & ServerStatus.STATUS_IN_TRANS_READONLY) > 0; + } + + @Override + public String getHost() { + return this.client.getHostAddress() != null ? this.client.getHostAddress().getHost() : null; + } + + @Override + public int getPort() { + return this.client.getHostAddress() != null ? this.client.getHostAddress().getPort() : 3306; + } + @Override public Mono rollbackTransaction() { - return this.client.rollbackTransaction(); + return this.client.rollbackTransaction().then().doOnSuccess(i -> this.isolationLevel = null); } @Override @@ -115,43 +171,103 @@ public Mono rollbackTransactionToSavepoint(String name) { @Override public Mono setAutoCommit(boolean autoCommit) { - return client.setAutoCommit(autoCommit); + return client + .setAutoCommit(autoCommit) + .then() + .doOnSuccess(i -> this.isolationLevel = autoCommit ? null : this.isolationLevel); + } + + @Override + public Mono setLockWaitTimeout(Duration timeout) { + return Mono.empty(); + } + + @Override + public Mono setStatementTimeout(Duration timeout) { + Assert.requireNonNull(timeout, "timeout must not be null"); + boolean serverSupportTimeout = + (client.getVersion().isMariaDBServer() + && client.getVersion().versionGreaterOrEqual(10, 1, 1) + || (!client.getVersion().isMariaDBServer() + && client.getVersion().versionGreaterOrEqual(5, 7, 4))); + if (!serverSupportTimeout) { + return Mono.error( + ExceptionFactory.createException( + "query timeout not supported by server. (required MariaDB 10.1.1+ | MySQL 5.7.4+)", + "HY000", + -1, + "SET max_statement_time")); + } + + long msValue = timeout.toMillis(); + + // MariaDB did implement max_statement_time in seconds, MySQL copied feature but in ms ... + + String sql; + if (client.getVersion().isMariaDBServer()) { + sql = String.format("SET max_statement_time=%s", (double) msValue / 1000); + } else { + sql = String.format("SET SESSION MAX_EXECUTION_TIME=%s", msValue); + } + + ExceptionFactory exceptionFactory = ExceptionFactory.withSql(sql); + return client + .sendCommand(new QueryPacket(sql), true) + .handle(exceptionFactory::handleErrorResponse) + .then(); } @Override public Mono setTransactionIsolationLevel(IsolationLevel isolationLevel) { Assert.requireNonNull(isolationLevel, "isolationLevel must not be null"); - final IsolationLevel newIsolation = isolationLevel; - String sql = String.format("SET TRANSACTION ISOLATION LEVEL %s", isolationLevel.asSql()); + + if ((client.getContext().getClientCapabilities() | Capabilities.CLIENT_SESSION_TRACK) > 0 + && client.getContext().getIsolationLevel() != null + && client.getContext().getIsolationLevel().equals(isolationLevel)) return Mono.empty(); + + String sql = + String.format("SET SESSION TRANSACTION ISOLATION LEVEL %s", isolationLevel.asSql()); ExceptionFactory exceptionFactory = ExceptionFactory.withSql(sql); + final IsolationLevel newIsolation = isolationLevel; return client - .sendCommand(new QueryPacket(sql)) + .sendCommand(new QueryPacket(sql), true) .handle(exceptionFactory::handleErrorResponse) .then() - .doOnSuccess(ignore -> this.isolationLevel = newIsolation); + .doOnSuccess(ignore -> this.sessionIsolationLevel = newIsolation); } @Override public String toString() { - return "MariadbConnection{client=" + client + ", isolationLevel=" + isolationLevel + '}'; + return "MariadbConnection{client=" + + client + + ", isolationLevel=" + + ((client.getContext().getClientCapabilities() | Capabilities.CLIENT_SESSION_TRACK) > 0 + ? client.getContext().getIsolationLevel() + : sessionIsolationLevel) + + '}'; } @Override public Mono validate(ValidationDepth depth) { + if (this.client.isCloseRequested()) { + return Mono.just(false); + } if (depth == ValidationDepth.LOCAL) { return Mono.just(this.client.isConnected()); } return Mono.create( sink -> { - if (!this.client.isConnected()) { + // only when using failover, connection might be recreated + if (HaMode.NONE.equals(this.configuration.getHaMode()) && !this.client.isConnected()) { sink.success(false); return; } this.client - .sendCommand(new PingPacket()) + .sendCommand(new PingPacket(), true) .windowUntil(it -> it.ending()) + .flatMap(Function.identity()) .subscribe( msg -> sink.success(true), err -> { @@ -161,6 +277,29 @@ public Mono validate(ValidationDepth depth) { }); } + @Override + public String getDatabase() { + if ((client.getContext().getClientCapabilities() | Capabilities.CLIENT_SESSION_TRACK) > 0) + return client.getContext().getDatabase(); + return this.database; + } + + public Mono setDatabase(String database) { + Assert.requireNonNull(database, "database must not be null"); + + if ((client.getContext().getClientCapabilities() | Capabilities.CLIENT_SESSION_TRACK) > 0 + && client.getContext().getDatabase() != null + && client.getContext().getDatabase().equals(database)) return Mono.empty(); + + ExceptionFactory exceptionFactory = ExceptionFactory.withSql("COM_INIT_DB"); + final String newDatabase = database; + return client + .sendCommand(new ChangeSchemaPacket(database), true) + .handle(exceptionFactory::handleErrorResponse) + .then() + .doOnSuccess(ignore -> this.database = newDatabase); + } + public PrepareCache _test_prepareCache() { return client.getPrepareCache(); } diff --git a/src/main/java/org/mariadb/r2dbc/MariadbConnectionConfiguration.java b/src/main/java/org/mariadb/r2dbc/MariadbConnectionConfiguration.java index bd06b9bd..d366543f 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbConnectionConfiguration.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbConnectionConfiguration.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; @@ -13,19 +13,22 @@ import java.time.Duration; import java.util.*; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.HostAddress; import org.mariadb.r2dbc.util.SslConfig; +import reactor.netty.resources.LoopResources; +import reactor.netty.tcp.TcpResources; import reactor.util.annotation.Nullable; public final class MariadbConnectionConfiguration { public static final int DEFAULT_PORT = 3306; private final String database; - private final String host; - + private final List hostAddresses; + private HaMode haMode; private final Duration connectTimeout; - private final Duration socketTimeout; private final boolean tcpKeepAlive; private final boolean tcpAbortiveClose; + private final boolean transactionReplay; private final CharSequence password; private final CharSequence[] pamOtherPwd; private final int port; @@ -44,18 +47,22 @@ public final class MariadbConnectionConfiguration { private final boolean useServerPrepStmts; private final boolean autocommit; private final boolean tinyInt1isBit; + private final String[] restrictedAuth; + private final LoopResources loopResources; private MariadbConnectionConfiguration( + String haMode, @Nullable Duration connectTimeout, - @Nullable Duration socketTimeout, @Nullable Boolean tcpKeepAlive, @Nullable Boolean tcpAbortiveClose, + @Nullable Boolean transactionReplay, @Nullable String database, @Nullable String host, @Nullable Map connectionAttributes, @Nullable Map sessionVariables, @Nullable CharSequence password, int port, + @Nullable List hostAddresses, @Nullable String socket, @Nullable String username, boolean allowMultiQueries, @@ -70,16 +77,26 @@ private MariadbConnectionConfiguration( @Nullable String cachingRsaPublicKey, boolean allowPublicKeyRetrieval, boolean useServerPrepStmts, + IsolationLevel isolationLevel, boolean autocommit, @Nullable Integer prepareCacheSize, @Nullable CharSequence[] pamOtherPwd, - boolean tinyInt1isBit) { + boolean tinyInt1isBit, + String restrictedAuth, + @Nullable LoopResources loopResources) { + this.haMode = haMode == null ? HaMode.NONE : HaMode.from(haMode); this.connectTimeout = connectTimeout == null ? Duration.ofSeconds(10) : connectTimeout; - this.socketTimeout = socketTimeout; this.tcpKeepAlive = tcpKeepAlive == null ? Boolean.FALSE : tcpKeepAlive; this.tcpAbortiveClose = tcpAbortiveClose == null ? Boolean.FALSE : tcpAbortiveClose; + this.transactionReplay = transactionReplay == null ? Boolean.FALSE : transactionReplay; this.database = database != null && !database.isEmpty() ? database : null; - this.host = host; + this.isolationLevel = isolationLevel; + this.restrictedAuth = restrictedAuth != null ? restrictedAuth.split(",") : null; + if (hostAddresses != null) { + this.hostAddresses = hostAddresses; + } else { + this.hostAddresses = HostAddress.parse(host, port); + } this.connectionAttributes = connectionAttributes; this.sessionVariables = sessionVariables; this.password = password != null && !password.toString().isEmpty() ? password : null; @@ -98,11 +115,12 @@ private MariadbConnectionConfiguration( this.rsaPublicKey = rsaPublicKey; this.cachingRsaPublicKey = cachingRsaPublicKey; this.allowPublicKeyRetrieval = allowPublicKeyRetrieval; - this.useServerPrepStmts = useServerPrepStmts; this.prepareCacheSize = (prepareCacheSize == null) ? 250 : prepareCacheSize.intValue(); this.pamOtherPwd = pamOtherPwd; this.autocommit = autocommit; this.tinyInt1isBit = tinyInt1isBit; + this.loopResources = loopResources != null ? loopResources : TcpResources.get(); + this.useServerPrepStmts = !this.allowMultiQueries && useServerPrepStmts; } static boolean boolValue(Object value) { @@ -128,13 +146,14 @@ static int intValue(Object value) { public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOptions) { Builder builder = new Builder(); - builder.database(connectionFactoryOptions.getValue(DATABASE)); + builder.database((String) connectionFactoryOptions.getValue(DATABASE)); if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.SOCKET)) { builder.socket( - connectionFactoryOptions.getRequiredValue(MariadbConnectionFactoryProvider.SOCKET)); + (String) + connectionFactoryOptions.getRequiredValue(MariadbConnectionFactoryProvider.SOCKET)); } else { - builder.host(connectionFactoryOptions.getRequiredValue(HOST)); + builder.host((String) connectionFactoryOptions.getRequiredValue(HOST)); } if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.ALLOW_MULTI_QUERIES)) { @@ -150,12 +169,6 @@ public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOpti connectionFactoryOptions.getValue(ConnectionFactoryOptions.CONNECT_TIMEOUT))); } - if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.SOCKET_TIMEOUT)) { - builder.socketTimeout( - durationValue( - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SOCKET_TIMEOUT))); - } - if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.TCP_KEEP_ALIVE)) { builder.tcpKeepAlive( boolValue( @@ -169,12 +182,26 @@ public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOpti MariadbConnectionFactoryProvider.TCP_ABORTIVE_CLOSE))); } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.TRANSACTION_REPLAY)) { + builder.transactionReplay( + boolValue( + connectionFactoryOptions.getValue( + MariadbConnectionFactoryProvider.TRANSACTION_REPLAY))); + } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.SESSION_VARIABLES)) { String sessionVarString = - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SESSION_VARIABLES); + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SESSION_VARIABLES); builder.sessionVariables(getMapFromString(sessionVarString)); } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.HAMODE)) { + String haMode = + (String) connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.HAMODE); + builder.haMode(haMode); + } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.ALLOW_PIPELINING)) { builder.allowPipelining( boolValue( @@ -188,6 +215,14 @@ public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOpti connectionFactoryOptions.getValue( MariadbConnectionFactoryProvider.USE_SERVER_PREPARE))); } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.ISOLATION_LEVEL)) { + String isolationLvl = + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.ISOLATION_LEVEL); + builder.isolationLevel( + isolationLvl == null ? null : IsolationLevel.valueOf(isolationLvl.replace("-", " "))); + } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.AUTO_COMMIT)) { builder.autocommit( boolValue( @@ -202,7 +237,9 @@ public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOpti if (connectionFactoryOptions.hasOption( MariadbConnectionFactoryProvider.CONNECTION_ATTRIBUTES)) { String connAttributes = - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CONNECTION_ATTRIBUTES); + (String) + connectionFactoryOptions.getValue( + MariadbConnectionFactoryProvider.CONNECTION_ATTRIBUTES); builder.connectionAttributes(getMapFromString(connAttributes)); } @@ -216,32 +253,39 @@ public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOpti if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.SSL_MODE)) { builder.sslMode( SslMode.from( - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SSL_MODE))); + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SSL_MODE))); } builder.serverSslCert( - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SERVER_SSL_CERT)); + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.SERVER_SSL_CERT)); builder.clientSslCert( - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CLIENT_SSL_CERT)); + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CLIENT_SSL_CERT)); builder.clientSslKey( - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CLIENT_SSL_KEY)); + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CLIENT_SSL_KEY)); builder.clientSslPassword( - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CLIENT_SSL_PWD)); + (String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.CLIENT_SSL_PWD)); if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.TLS_PROTOCOL)) { String[] protocols = - connectionFactoryOptions - .getValue(MariadbConnectionFactoryProvider.TLS_PROTOCOL) + ((String) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.TLS_PROTOCOL)) .split("[,;\\s]+"); builder.tlsProtocol(protocols); } - builder.password(connectionFactoryOptions.getValue(PASSWORD)); - builder.username(connectionFactoryOptions.getRequiredValue(USER)); + builder.password((CharSequence) connectionFactoryOptions.getValue(PASSWORD)); + builder.username((String) connectionFactoryOptions.getRequiredValue(USER)); if (connectionFactoryOptions.hasOption(PORT)) { builder.port(intValue(connectionFactoryOptions.getValue(PORT))); } if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.PAM_OTHER_PASSWORD)) { String s = - connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.PAM_OTHER_PASSWORD); + (String) + connectionFactoryOptions.getValue( + MariadbConnectionFactoryProvider.PAM_OTHER_PASSWORD); String[] pairs = s.split(","); try { for (int i = 0; i < pairs.length; i++) { @@ -252,6 +296,12 @@ public static Builder fromOptions(ConnectionFactoryOptions connectionFactoryOpti } builder.pamOtherPwd(pairs); } + if (connectionFactoryOptions.hasOption(MariadbConnectionFactoryProvider.LOOP_RESOURCES)) { + LoopResources loopResources = + (LoopResources) + connectionFactoryOptions.getValue(MariadbConnectionFactoryProvider.LOOP_RESOURCES); + builder.loopResources(loopResources); + } return builder; } @@ -295,9 +345,13 @@ public String getDatabase() { return this.database; } + public HaMode getHaMode() { + return this.haMode; + } + @Nullable - public String getHost() { - return this.host; + public List getHostAddresses() { + return this.hostAddresses; } @Nullable @@ -368,10 +422,6 @@ public int getPrepareCacheSize() { return prepareCacheSize; } - public Duration getSocketTimeout() { - return socketTimeout; - } - public boolean isTcpKeepAlive() { return tcpKeepAlive; } @@ -380,6 +430,18 @@ public boolean isTcpAbortiveClose() { return tcpAbortiveClose; } + public boolean isTransactionReplay() { + return transactionReplay; + } + + public String[] getRestrictedAuth() { + return restrictedAuth; + } + + public LoopResources loopResources() { + return loopResources; + } + @Override public String toString() { StringBuilder hiddenPwd = new StringBuilder(); @@ -399,21 +461,19 @@ public String toString() { + "database='" + database + '\'' - + ", host='" - + host - + '\'' + + ", hosts={" + + (hostAddresses == null ? "" : Arrays.toString(hostAddresses.toArray())) + + '}' + ", connectTimeout=" + connectTimeout - + ", socketTimeout=" - + socketTimeout + ", tcpKeepAlive=" + tcpKeepAlive + ", tcpAbortiveClose=" + tcpAbortiveClose + + ", transactionReplay=" + + transactionReplay + ", password=" + hiddenPwd - + ", port=" - + port + ", prepareCacheSize=" + prepareCacheSize + ", socket='" @@ -450,6 +510,8 @@ public String toString() { + tinyInt1isBit + ", pamOtherPwd=" + hiddenPamPwd + + ", restrictedAuth=" + + restrictedAuth + '}'; } @@ -460,15 +522,17 @@ public String toString() { */ public static final class Builder implements Cloneable { + @Nullable private String haMode; @Nullable private String rsaPublicKey; @Nullable private String cachingRsaPublicKey; private boolean allowPublicKeyRetrieval; @Nullable private String username; @Nullable private Duration connectTimeout; - @Nullable private Duration socketTimeout; @Nullable private Boolean tcpKeepAlive; @Nullable private Boolean tcpAbortiveClose; + @Nullable private Boolean transactionReplay; @Nullable private String database; + @Nullable private List hostAddresses; @Nullable private String host; @Nullable private Map sessionVariables; @Nullable private Map connectionAttributes; @@ -478,6 +542,7 @@ public static final class Builder implements Cloneable { private boolean allowMultiQueries = false; private boolean allowPipelining = true; private boolean useServerPrepStmts = false; + private IsolationLevel isolationLevel = null; private boolean autocommit = true; private boolean tinyInt1isBit = true; @Nullable Integer prepareCacheSize; @@ -488,6 +553,8 @@ public static final class Builder implements Cloneable { @Nullable private CharSequence clientSslPassword; private SslMode sslMode = SslMode.DISABLE; private CharSequence[] pamOtherPwd; + private String restrictedAuth; + @Nullable private LoopResources loopResources; private Builder() {} @@ -512,16 +579,18 @@ public MariadbConnectionConfiguration build() { } return new MariadbConnectionConfiguration( + this.haMode, this.connectTimeout, - this.socketTimeout, this.tcpKeepAlive, this.tcpAbortiveClose, + this.transactionReplay, this.database, this.host, this.connectionAttributes, this.sessionVariables, this.password, this.port, + this.hostAddresses, this.socket, this.username, this.allowMultiQueries, @@ -536,10 +605,13 @@ public MariadbConnectionConfiguration build() { this.cachingRsaPublicKey, this.allowPublicKeyRetrieval, this.useServerPrepStmts, + this.isolationLevel, this.autocommit, this.prepareCacheSize, this.pamOtherPwd, - this.tinyInt1isBit); + this.tinyInt1isBit, + this.restrictedAuth, + this.loopResources); } /** @@ -553,8 +625,18 @@ public Builder connectTimeout(@Nullable Duration connectTimeout) { return this; } - public Builder socketTimeout(@Nullable Duration socketTimeout) { - this.socketTimeout = socketTimeout; + public Builder haMode(@Nullable String haMode) { + this.haMode = haMode; + return this; + } + + public Builder hostAddresses(@Nullable List hostAddresses) { + this.hostAddresses = hostAddresses; + return this; + } + + public Builder restrictedAuth(@Nullable String restrictedAuth) { + this.restrictedAuth = restrictedAuth; return this; } @@ -568,6 +650,11 @@ public Builder tcpAbortiveClose(@Nullable Boolean tcpAbortiveClose) { return this; } + public Builder transactionReplay(@Nullable Boolean transactionReplay) { + this.transactionReplay = transactionReplay; + return this; + } + public Builder connectionAttributes(@Nullable Map connectionAttributes) { this.connectionAttributes = connectionAttributes; return this; @@ -764,6 +851,17 @@ public Builder useServerPrepStmts(boolean useServerPrepStmts) { return this; } + /** + * Permit to set default isolation level + * + * @param isolationLevel transaction isolation level + * @return this {@link Builder} + */ + public Builder isolationLevel(IsolationLevel isolationLevel) { + this.isolationLevel = isolationLevel; + return this; + } + /** * Permit to indicate default autocommit value. Default value True. * @@ -837,6 +935,11 @@ public Builder username(String username) { return this; } + public Builder loopResources(LoopResources loopResources) { + this.loopResources = Assert.requireNonNull(loopResources, "loopResources must not be null"); + return this; + } + @Override public Builder clone() throws CloneNotSupportedException { return (Builder) super.clone(); @@ -868,12 +971,12 @@ public String toString() { + username + ", connectTimeout=" + connectTimeout - + ", socketTimeout=" - + socketTimeout + ", tcpKeepAlive=" + tcpKeepAlive + ", tcpAbortiveClose=" + tcpAbortiveClose + + ", transactionReplay=" + + transactionReplay + ", database=" + database + ", host=" @@ -884,8 +987,13 @@ public String toString() { + connectionAttributes + ", password=" + hiddenPwd + + ", restrictedAuth=" + + restrictedAuth + ", port=" + port + + ", hosts={" + + (hostAddresses == null ? "" : Arrays.toString(hostAddresses.toArray())) + + '}' + ", socket=" + socket + ", allowMultiQueries=" @@ -895,6 +1003,8 @@ public String toString() { + ", useServerPrepStmts=" + useServerPrepStmts + ", prepareCacheSize=" + + isolationLevel + + ", isolationLevel=" + prepareCacheSize + ", tlsProtocol=" + tlsProtocol diff --git a/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactory.java b/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactory.java index a2c02f31..59f2e795 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactory.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactory.java @@ -1,121 +1,109 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; import io.netty.channel.unix.DomainSocketAddress; import io.r2dbc.spi.*; -import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.Iterator; import java.util.Map; +import java.util.concurrent.locks.ReentrantLock; import org.mariadb.r2dbc.client.Client; -import org.mariadb.r2dbc.client.ClientImpl; -import org.mariadb.r2dbc.client.ClientPipelineImpl; +import org.mariadb.r2dbc.client.FailoverClient; +import org.mariadb.r2dbc.client.SimpleClient; +import org.mariadb.r2dbc.message.Protocol; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.client.QueryPacket; import org.mariadb.r2dbc.message.flow.AuthenticationFlow; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.HostAddress; +import org.mariadb.r2dbc.util.constants.Capabilities; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.resources.ConnectionProvider; public final class MariadbConnectionFactory implements ConnectionFactory { private final MariadbConnectionConfiguration configuration; - private final SocketAddress endpoint; public MariadbConnectionFactory(MariadbConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); } public static MariadbConnectionFactory from(MariadbConnectionConfiguration configuration) { return new MariadbConnectionFactory(configuration); } - private static SocketAddress createSocketAddress(MariadbConnectionConfiguration configuration) { - - if (configuration.getSocket() != null) { - return new DomainSocketAddress(configuration.getSocket()); - } else { - return InetSocketAddress.createUnresolved(configuration.getHost(), configuration.getPort()); - } - } - @Override public Mono create() { - return doCreateConnection().cast(org.mariadb.r2dbc.api.MariadbConnection.class); + ReentrantLock lock = new ReentrantLock(); + return ((configuration.getSocket() != null) + ? connectToSocket( + configuration, new DomainSocketAddress(configuration.getSocket()), null, lock) + : (configuration.getHaMode().equals(HaMode.NONE) + ? configuration.getHaMode().connectHost(configuration, lock, false) + : configuration + .getHaMode() + .connectHost(configuration, lock, false) + .flatMap(c -> Mono.just(new FailoverClient(configuration, lock, c))))) + .flatMap( + client -> + Mono.just( + new MariadbConnection( + client, + configuration.getIsolationLevel() == null + ? IsolationLevel.REPEATABLE_READ + : configuration.getIsolationLevel(), + configuration)) + .onErrorResume(throwable -> closeWithError(client, throwable))) + .cast(org.mariadb.r2dbc.api.MariadbConnection.class); } - private Mono doCreateConnection() { - - Mono clientMono; - if (configuration.allowPipelining()) { - clientMono = - ClientPipelineImpl.connect( - ConnectionProvider.newConnection(), this.endpoint, configuration); - } else { - clientMono = - ClientImpl.connect(ConnectionProvider.newConnection(), this.endpoint, configuration); - } - - return clientMono - .delayUntil(client -> AuthenticationFlow.exchange(client, this.configuration)) + private static Mono connectToSocket( + final MariadbConnectionConfiguration configuration, + SocketAddress endpoint, + HostAddress hostAddress, + ReentrantLock lock) { + return SimpleClient.connect( + ConnectionProvider.newConnection(), endpoint, hostAddress, configuration, lock) + .delayUntil(client -> AuthenticationFlow.exchange(client, configuration, hostAddress)) .cast(Client.class) - .flatMap( - client -> { - Mono waiting = Mono.empty(); - // only execute SET command if needed : - // - autocommit default value differ than option - // - session variable set - if ((configuration.getSessionVariables() != null - && configuration.getSessionVariables().size() > 0) - || client.isAutoCommit() != configuration.autocommit()) { - waiting = setSessionVariables(client); - } - - if (configuration.getIsolationLevel() == null) { - Mono isolationLevelMono = waiting.then(getIsolationLevel(client)); - return isolationLevelMono - .map(it -> new MariadbConnection(client, it, configuration)) - .onErrorResume(throwable -> this.closeWithError(client, throwable)); - } else { - return waiting - .then( - Mono.just( - new MariadbConnection( - client, configuration.getIsolationLevel(), configuration))) - .onErrorResume(throwable -> this.closeWithError(client, throwable)); - } - }) - .onErrorMap(this::cannotConnect); + .flatMap(client -> setSessionVariables(configuration, client).thenReturn(client)) + .onErrorMap(e -> cannotConnect(e, endpoint)); } - private Mono closeWithError(Client client, Throwable throwable) { - return client.close().then(Mono.error(throwable)); - } + public static Mono setSessionVariables( + final MariadbConnectionConfiguration configuration, Client client) { - private Throwable cannotConnect(Throwable throwable) { + // set default autocommit value + StringBuilder sql = + new StringBuilder("SET autocommit=" + (configuration.autocommit() ? "1" : "0")); - if (throwable instanceof R2dbcException) { - return throwable; + // set default transaction isolation + String txIsolation = + (!client.getVersion().isMariaDBServer() + && (client.getVersion().versionGreaterOrEqual(8, 0, 3) + || (client.getVersion().getMajorVersion() < 8 + && client.getVersion().versionGreaterOrEqual(5, 7, 20)))) + ? "transaction_isolation" + : "tx_isolation"; + sql.append(",") + .append(txIsolation) + .append("='") + .append( + configuration.getIsolationLevel() == null + ? "REPEATABLE-READ" + : configuration.getIsolationLevel().asSql().replace(" ", "-")) + .append("'"); + + // set session tracking + if ((client.getContext().getClientCapabilities() & Capabilities.CLIENT_SESSION_TRACK) > 0) { + sql.append(",session_track_schema=1"); + sql.append(",session_track_system_variables='autocommit,").append(txIsolation).append("'"); } - return new R2dbcNonTransientResourceException( - String.format("Cannot connect to %s", this.endpoint), throwable); - } - - @Override - public ConnectionFactoryMetadata getMetadata() { - return MariadbConnectionFactoryMetadata.INSTANCE; - } - - @Override - public String toString() { - return "MariadbConnectionFactory{configuration=" + this.configuration + '}'; - } - - private Mono setSessionVariables(Client client) { - StringBuilder sql = - new StringBuilder("SET autocommit=" + (configuration.autocommit() ? "1" : "0")); + // set session variables if defined if (configuration.getSessionVariables() != null && configuration.getSessionVariables().size() > 0) { Map sessionVariable = configuration.getSessionVariables(); @@ -129,42 +117,34 @@ private Mono setSessionVariables(Client client) { sql.append(",").append(key).append("=").append(value); } } + Flux messages = client.sendCommand(new QueryPacket(sql.toString()), true); + return MariadbCommonStatement.toResult( + Protocol.TEXT, client, messages, ExceptionFactory.INSTANCE, null, null, configuration) + .last() + .then(); + } - return new MariadbSimpleQueryStatement(client, sql.toString()).execute().last().then(); + public static Mono closeWithError(Client client, Throwable throwable) { + return client.close().then(Mono.error(throwable)); } - private Mono getIsolationLevel(Client client) { - String sql = "SELECT @@tx_isolation"; - if (!client.getVersion().isMariaDBServer() - && (client.getVersion().versionGreaterOrEqual(8, 0, 3) - || (client.getVersion().getMajorVersion() < 8 - && client.getVersion().versionGreaterOrEqual(5, 7, 20)))) { - sql = "SELECT @@transaction_isolation"; + public static Throwable cannotConnect(Throwable throwable, SocketAddress endpoint) { + + if (throwable instanceof R2dbcException) { + return throwable; } - return new MariadbSimpleQueryStatement(client, sql) - .execute() - .flatMap( - it -> - it.map( - (row, rowMetadata) -> { - String level = row.get(0, String.class); - - switch (level) { - case "REPEATABLE-READ": - return IsolationLevel.REPEATABLE_READ; - - case "READ-UNCOMMITTED": - return IsolationLevel.READ_UNCOMMITTED; - - case "SERIALIZABLE": - return IsolationLevel.SERIALIZABLE; - - default: - return IsolationLevel.READ_COMMITTED; - } - })) - .defaultIfEmpty(IsolationLevel.READ_COMMITTED) - .last(); + return new R2dbcNonTransientResourceException( + String.format("Cannot connect to %s", endpoint), throwable); + } + + @Override + public ConnectionFactoryMetadata getMetadata() { + return MariadbConnectionFactoryMetadata.INSTANCE; + } + + @Override + public String toString() { + return "MariadbConnectionFactory{configuration=" + this.configuration + '}'; } } diff --git a/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryMetadata.java b/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryMetadata.java index 9793bcb0..b44db5a6 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryMetadata.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryMetadata.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryProvider.java b/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryProvider.java index 57c66514..a83db627 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryProvider.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbConnectionFactoryProvider.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; @@ -8,8 +8,8 @@ import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.ConnectionFactoryProvider; import io.r2dbc.spi.Option; -import java.time.Duration; import org.mariadb.r2dbc.util.Assert; +import reactor.netty.resources.LoopResources; public final class MariadbConnectionFactoryProvider implements ConnectionFactoryProvider { public static final String MARIADB_DRIVER = "mariadb"; @@ -22,17 +22,20 @@ public final class MariadbConnectionFactoryProvider implements ConnectionFactory public static final Option CLIENT_SSL_PWD = Option.valueOf("clientSslPassword"); public static final Option ALLOW_PIPELINING = Option.valueOf("allowPipelining"); public static final Option USE_SERVER_PREPARE = Option.valueOf("useServerPrepStmts"); + public static final Option ISOLATION_LEVEL = Option.valueOf("isolationLevel"); public static final Option AUTO_COMMIT = Option.valueOf("autoCommit"); public static final Option TINY_IS_BIT = Option.valueOf("tinyInt1isBit"); public static final Option PREPARE_CACHE_SIZE = Option.valueOf("prepareCacheSize"); public static final Option SSL_MODE = Option.valueOf("sslMode"); + public static final Option TRANSACTION_REPLAY = Option.valueOf("transactionReplay"); + public static final Option HAMODE = Option.valueOf("haMode"); public static final Option CONNECTION_ATTRIBUTES = Option.valueOf("connectionAttributes"); public static final Option PAM_OTHER_PASSWORD = Option.valueOf("pamOtherPwd"); - public static final Option SOCKET_TIMEOUT = Option.valueOf("socketTimeout"); public static final Option TCP_KEEP_ALIVE = Option.valueOf("tcpKeepAlive"); public static final Option TCP_ABORTIVE_CLOSE = Option.valueOf("tcpAbortiveClose"); public static final Option SESSION_VARIABLES = Option.valueOf("sessionVariables"); + public static final Option LOOP_RESOURCES = Option.valueOf("loopResources"); static MariadbConnectionConfiguration createConfiguration( ConnectionFactoryOptions connectionFactoryOptions) { @@ -54,7 +57,7 @@ public String getDriver() { public boolean supports(ConnectionFactoryOptions connectionFactoryOptions) { Assert.requireNonNull(connectionFactoryOptions, "connectionFactoryOptions must not be null"); - String driver = connectionFactoryOptions.getValue(DRIVER); + String driver = (String) connectionFactoryOptions.getValue(DRIVER); return MARIADB_DRIVER.equals(driver); } } diff --git a/src/main/java/org/mariadb/r2dbc/MariadbConnectionMetadata.java b/src/main/java/org/mariadb/r2dbc/MariadbConnectionMetadata.java index 13c34216..d640b6eb 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbConnectionMetadata.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbConnectionMetadata.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/main/java/org/mariadb/r2dbc/MariadbDataSegment.java b/src/main/java/org/mariadb/r2dbc/MariadbDataSegment.java new file mode 100644 index 00000000..d27aa34d --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbDataSegment.java @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.netty.buffer.ByteBuf; +import io.r2dbc.spi.Result; + +public interface MariadbDataSegment extends Result.Segment { + + void updateRaw(ByteBuf data); +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbOutParametersMetadata.java b/src/main/java/org/mariadb/r2dbc/MariadbOutParametersMetadata.java new file mode 100644 index 00000000..a4f97c42 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbOutParametersMetadata.java @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.r2dbc.spi.*; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.Assert; + +final class MariadbOutParametersMetadata implements OutParametersMetadata { + + private final List metadataList; + private volatile Collection columnNames; + + MariadbOutParametersMetadata(List metadataList) { + this.metadataList = metadataList; + } + + @Override + public ColumnDefinitionPacket getParameterMetadata(int index) { + if (index < 0 || index >= this.metadataList.size()) { + throw new IllegalArgumentException( + String.format( + "Column index %d is not in permit range[0,%s]", index, this.metadataList.size() - 1)); + } + return this.metadataList.get(index); + } + + @Override + public ColumnDefinitionPacket getParameterMetadata(String name) { + return metadataList.get(getIndex(name)); + } + + @Override + public List getParameterMetadatas() { + return Collections.unmodifiableList(this.metadataList); + } + + private int getIndex(String name) { + Assert.requireNonNull(name, "name must not be null"); + for (int i = 0; i < this.metadataList.size(); i++) { + if (this.metadataList.get(i).getName().equalsIgnoreCase(name)) { + return i; + } + } + + throw new IllegalArgumentException( + String.format( + "Column name '%s' does not exist in column names %s", + name, getColumnNames(this.metadataList))); + } + + private Collection getColumnNames(List columnMetadatas) { + List columnNames = new ArrayList<>(); + for (ColumnDefinitionPacket columnMetadata : columnMetadatas) { + columnNames.add(columnMetadata.getName()); + } + return Collections.unmodifiableCollection(columnNames); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbOutSegment.java b/src/main/java/org/mariadb/r2dbc/MariadbOutSegment.java new file mode 100644 index 00000000..d506f0d0 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbOutSegment.java @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.netty.buffer.ByteBuf; +import io.r2dbc.spi.*; +import java.util.List; +import org.mariadb.r2dbc.codec.RowDecoder; +import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; + +public class MariadbOutSegment extends MariadbReadable + implements Result.OutSegment, OutParameters, MariadbDataSegment { + + public MariadbOutSegment(RowDecoder decoder, List metadataList) { + super(decoder, metadataList); + } + + public void updateRaw(ByteBuf data) { + super.updateRaw(data); + } + + @Override + public OutParameters outParameters() { + return this; + } + + @Override + public OutParametersMetadata getMetadata() { + return new MariadbOutParametersMetadata(metadataList); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbReadable.java b/src/main/java/org/mariadb/r2dbc/MariadbReadable.java new file mode 100644 index 00000000..e40459de --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbReadable.java @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.netty.buffer.ByteBuf; +import io.r2dbc.spi.Readable; +import java.util.*; +import org.mariadb.r2dbc.codec.RowDecoder; +import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.Assert; +import reactor.util.annotation.Nullable; + +public class MariadbReadable implements Readable { + + private final RowDecoder decoder; + protected final List metadataList; + private ByteBuf raw; + + MariadbReadable(RowDecoder decoder, List metadataList) { + this.decoder = decoder; + this.metadataList = metadataList; + } + + protected void updateRaw(ByteBuf data) { + this.raw = data; + decoder.resetRow(raw); + } + + @Nullable + @Override + public T get(int index, Class type) { + return decoder.get(index, getColumnsDef(index), type); + } + + private ColumnDefinitionPacket getColumnsDef(int index) { + if (index < 0) { + throw new IndexOutOfBoundsException(String.format("Column index %d must be positive", index)); + } + if (index >= this.metadataList.size()) { + throw new IndexOutOfBoundsException( + String.format( + "Column index %d not in range [0-%s]", index, this.metadataList.size() - 1)); + } + return this.metadataList.get(index); + } + + private int getIndex(String name) { + Assert.requireNonNull(name, "name must not be null"); + for (int i = 0; i < this.metadataList.size(); i++) { + if (this.metadataList.get(i).getName().equalsIgnoreCase(name)) { + return i; + } + } + + Set columnNames = new TreeSet<>(); + for (ColumnDefinitionPacket columnDef : this.metadataList) { + columnNames.add(columnDef.getName()); + } + throw new NoSuchElementException( + String.format( + "Column name '%s' does not exist in column names %s", + name, Collections.unmodifiableCollection(columnNames))); + } + + @Nullable + @Override + public T get(String name, Class type) { + Assert.requireNonNull(name, "name must not be null"); + return get(getIndex(name), type); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbResult.java b/src/main/java/org/mariadb/r2dbc/MariadbResult.java deleted file mode 100644 index 731192b2..00000000 --- a/src/main/java/org/mariadb/r2dbc/MariadbResult.java +++ /dev/null @@ -1,193 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.r2dbc.spi.R2dbcException; -import io.r2dbc.spi.Row; -import io.r2dbc.spi.RowMetadata; -import java.nio.charset.StandardCharsets; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import org.mariadb.r2dbc.codec.BinaryRowDecoder; -import org.mariadb.r2dbc.codec.RowDecoder; -import org.mariadb.r2dbc.codec.TextRowDecoder; -import org.mariadb.r2dbc.message.server.*; -import org.mariadb.r2dbc.util.ServerPrepareResult; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -final class MariadbResult implements org.mariadb.r2dbc.api.MariadbResult { - - private final Flux dataRows; - private final ExceptionFactory factory; - private RowDecoder decoder; - private final String[] generatedColumns; - private final boolean supportReturning; - private final boolean text; - private final MariadbConnectionConfiguration conf; - private AtomicReference prepareResult; - - private volatile ColumnDefinitionPacket[] metadataList; - private volatile int metadataIndex; - private volatile int columnNumber; - private volatile MariadbRowMetadata rowMetadata; - - MariadbResult( - boolean text, - AtomicReference prepareResult, - Flux dataRows, - ExceptionFactory factory, - String[] generatedColumns, - boolean supportReturning, - MariadbConnectionConfiguration conf) { - this.text = text; - this.dataRows = dataRows; - this.factory = factory; - this.generatedColumns = generatedColumns; - this.supportReturning = supportReturning; - this.conf = conf; - this.prepareResult = prepareResult; - } - - @Override - public Mono getRowsUpdated() { - final AtomicInteger rowCount = new AtomicInteger(0); - Flux f = - this.dataRows.handle( - (serverMessage, sink) -> { - if (serverMessage instanceof ErrorPacket) { - sink.error(this.factory.from((ErrorPacket) serverMessage)); - return; - } - if (serverMessage instanceof RowPacket) { - rowCount.incrementAndGet(); - ((RowPacket) serverMessage).release(); - return; - } - if (serverMessage instanceof EofPacket) { - EofPacket eofPacket = (EofPacket) serverMessage; - if (eofPacket.resultSetEnd()) { - sink.next(rowCount.get()); - rowCount.set(0); - sink.complete(); - } - return; - } - if (serverMessage instanceof OkPacket) { - if (rowCount.get() > 0) { - // a results with returning - sink.next(rowCount.get()); - rowCount.set(0); - sink.complete(); - } else { - OkPacket okPacket = (OkPacket) serverMessage; - long affectedRows = okPacket.getAffectedRows(); - sink.next((int) affectedRows); - sink.complete(); - } - } - }); - return f.singleOrEmpty(); - } - - @Override - public Flux map(BiFunction f) { - metadataIndex = 0; - - return this.dataRows - .takeUntil(msg -> msg.resultSetEnd()) - .handle( - (serverMessage, sink) -> { - if (serverMessage instanceof ErrorPacket) { - sink.error(this.factory.from((ErrorPacket) serverMessage)); - return; - } - - if (serverMessage instanceof CompletePrepareResult) { - this.prepareResult.set(((CompletePrepareResult) serverMessage).getPrepare()); - metadataList = this.prepareResult.get().getColumns(); - return; - } - - if (serverMessage instanceof ColumnCountPacket) { - this.columnNumber = ((ColumnCountPacket) serverMessage).getColumnCount(); - if (!((ColumnCountPacket) serverMessage).isMetaFollows()) { - metadataList = this.prepareResult.get().getColumns(); - rowMetadata = MariadbRowMetadata.toRowMetadata(this.metadataList); - this.decoder = new BinaryRowDecoder(columnNumber, this.metadataList, this.conf); - } else { - metadataList = new ColumnDefinitionPacket[this.columnNumber]; - } - return; - } - - if (serverMessage instanceof ColumnDefinitionPacket) { - this.metadataList[metadataIndex++] = (ColumnDefinitionPacket) serverMessage; - if (metadataIndex == columnNumber) { - rowMetadata = MariadbRowMetadata.toRowMetadata(this.metadataList); - this.decoder = - text - ? new TextRowDecoder(columnNumber, this.metadataList, this.conf) - : new BinaryRowDecoder(columnNumber, this.metadataList, this.conf); - } - return; - } - - if (serverMessage instanceof RowPacket) { - RowPacket row = ((RowPacket) serverMessage); - try { - sink.next( - f.apply(new MariadbRow(metadataList, decoder, row.getRaw()), rowMetadata)); - } catch (IllegalArgumentException i) { - sink.error(this.factory.createException(i.getMessage(), "HY000", -1)); - } catch (R2dbcException i) { - sink.error(i); - } finally { - row.release(); - } - return; - } - - // This is for server that doesn't permit RETURNING: rely on OK_packet LastInsertId - // to retrieve the last generated ID. - if (generatedColumns != null - && !supportReturning - && serverMessage instanceof OkPacket) { - - String colName = generatedColumns.length > 0 ? generatedColumns[0] : "ID"; - metadataList = new ColumnDefinitionPacket[1]; - metadataList[0] = ColumnDefinitionPacket.fromGeneratedId(colName); - rowMetadata = MariadbRowMetadata.toRowMetadata(this.metadataList); - - OkPacket okPacket = ((OkPacket) serverMessage); - if (okPacket.getAffectedRows() > 1) { - sink.error( - this.factory.createException( - "Connector cannot get generated ID (using returnGeneratedValues) multiple rows before MariaDB 10.5.1", - "HY000", - -1)); - return; - } - ByteBuf buf = getLongTextEncoded(okPacket.getLastInsertId()); - decoder = new TextRowDecoder(1, this.metadataList, this.conf); - try { - sink.next(f.apply(new MariadbRow(metadataList, decoder, buf), rowMetadata)); - } finally { - buf.release(); - } - } - }); - } - - private ByteBuf getLongTextEncoded(long value) { - byte[] byteValue = Long.toString(value).getBytes(StandardCharsets.US_ASCII); - byte[] encodedLength; - int length = byteValue.length; - encodedLength = new byte[] {(byte) length}; - return Unpooled.copiedBuffer(encodedLength, byteValue); - } -} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbRow.java b/src/main/java/org/mariadb/r2dbc/MariadbRow.java deleted file mode 100644 index 57641ab5..00000000 --- a/src/main/java/org/mariadb/r2dbc/MariadbRow.java +++ /dev/null @@ -1,75 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc; - -import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.Row; -import java.util.Collections; -import java.util.Set; -import java.util.TreeSet; -import org.mariadb.r2dbc.codec.RowDecoder; -import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; -import org.mariadb.r2dbc.util.Assert; -import reactor.util.annotation.Nullable; - -public class MariadbRow implements Row { - - private final ColumnDefinitionPacket[] columnDefinitionPackets; - private final RowDecoder decoder; - private final ByteBuf raw; - - MariadbRow(ColumnDefinitionPacket[] columnDefinitionPackets, RowDecoder decoder, ByteBuf data) { - this.columnDefinitionPackets = columnDefinitionPackets; - this.decoder = decoder; - this.raw = data; - - decoder.resetRow(raw); - } - - @Nullable - @Override - public T get(int index, Class type) { - Assert.requireNonNull(type, "type must not be null"); - return decoder.get(index, getMeta(index), type); - } - - @Nullable - @Override - public T get(String name, Class type) { - Assert.requireNonNull(name, "name must not be null"); - Assert.requireNonNull(type, "type must not be null"); - return get(getColumn(name), type); - } - - private int getColumn(String name) { - Assert.requireNonNull(name, "name must not be null"); - for (int i = 0; i < this.columnDefinitionPackets.length; i++) { - if (this.columnDefinitionPackets[i].getColumnAlias().equalsIgnoreCase(name)) { - return i; - } - } - - Set columnNames = new TreeSet<>(); - for (ColumnDefinitionPacket columnDef : columnDefinitionPackets) { - columnNames.add(columnDef.getColumnAlias()); - } - throw new IllegalArgumentException( - String.format( - "Column name '%s' does not exist in column names %s", - name, Collections.unmodifiableCollection(columnNames))); - } - - private ColumnDefinitionPacket getMeta(int index) { - if (index < 0) { - throw new IllegalArgumentException(String.format("Column index %d must be positive", index)); - } - if (index >= this.columnDefinitionPackets.length) { - throw new IllegalArgumentException( - String.format( - "Column index %d not in range [0-%s]", - index, this.columnDefinitionPackets.length - 1)); - } - return this.columnDefinitionPackets[index]; - } -} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbRowMetadata.java b/src/main/java/org/mariadb/r2dbc/MariadbRowMetadata.java index 676653a9..3c83cd54 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbRowMetadata.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbRowMetadata.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; @@ -10,25 +10,17 @@ final class MariadbRowMetadata implements RowMetadata { - private final List metadataList; + private final List metadataList; private volatile Collection columnNames; - MariadbRowMetadata(List metadataList) { + MariadbRowMetadata(List metadataList) { this.metadataList = metadataList; } - static MariadbRowMetadata toRowMetadata(ColumnDefinitionPacket[] metadataList) { - List columnMetadata = new ArrayList<>(metadataList.length); - for (ColumnDefinitionPacket col : metadataList) { - columnMetadata.add(new MariadbColumnMetadata(col)); - } - return new MariadbRowMetadata(columnMetadata); - } - @Override - public MariadbColumnMetadata getColumnMetadata(int index) { + public ColumnDefinitionPacket getColumnMetadata(int index) { if (index < 0 || index >= this.metadataList.size()) { - throw new IllegalArgumentException( + throw new IndexOutOfBoundsException( String.format( "Column index %d is not in permit range[0,%s]", index, this.metadataList.size() - 1)); } @@ -36,7 +28,7 @@ public MariadbColumnMetadata getColumnMetadata(int index) { } @Override - public MariadbColumnMetadata getColumnMetadata(String name) { + public ColumnDefinitionPacket getColumnMetadata(String name) { return metadataList.get(getColumn(name)); } @@ -47,17 +39,18 @@ private int getColumn(String name) { return i; } } - throw new IllegalArgumentException( + throw new NoSuchElementException( String.format( "Column name '%s' does not exist in column names %s", name, getColumnNames())); } @Override - public List getColumnMetadatas() { + public List getColumnMetadatas() { return Collections.unmodifiableList(this.metadataList); } @Override + @Deprecated public Collection getColumnNames() { if (this.columnNames == null) { this.columnNames = getColumnNames(this.metadataList); @@ -65,9 +58,9 @@ public Collection getColumnNames() { return Collections.unmodifiableCollection(this.columnNames); } - private Collection getColumnNames(List columnMetadatas) { + private Collection getColumnNames(List columnMetadatas) { List columnNames = new ArrayList<>(); - for (MariadbColumnMetadata columnMetadata : columnMetadatas) { + for (ColumnDefinitionPacket columnMetadata : columnMetadatas) { columnNames.add(columnMetadata.getName()); } return Collections.unmodifiableCollection(columnNames); @@ -82,4 +75,12 @@ public String toString() { sb.append("columnNames=").append(columnNames).append("}"); return sb.toString(); } + + @Override + public boolean contains(String columnName) { + if (this.columnNames == null) { + this.columnNames = getColumnNames(this.metadataList); + } + return this.columnNames.stream().anyMatch(columnName::equalsIgnoreCase); + } } diff --git a/src/main/java/org/mariadb/r2dbc/MariadbRowSegment.java b/src/main/java/org/mariadb/r2dbc/MariadbRowSegment.java new file mode 100644 index 00000000..6c704e14 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbRowSegment.java @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.netty.buffer.ByteBuf; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import java.util.List; +import org.mariadb.r2dbc.codec.RowDecoder; +import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; + +public class MariadbRowSegment extends MariadbReadable + implements Result.RowSegment, Row, MariadbDataSegment { + private MariadbRowMetadata meta = null; + + public MariadbRowSegment(RowDecoder decoder, List metadataList) { + super(decoder, metadataList); + } + + public void updateRaw(ByteBuf data) { + super.updateRaw(data); + } + + @Override + public Row row() { + return this; + } + + @Override + public RowMetadata getMetadata() { + if (meta == null) meta = new MariadbRowMetadata(metadataList); + return meta; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbServerParameterizedQueryStatement.java b/src/main/java/org/mariadb/r2dbc/MariadbServerParameterizedQueryStatement.java index 09c9b739..ea28a379 100644 --- a/src/main/java/org/mariadb/r2dbc/MariadbServerParameterizedQueryStatement.java +++ b/src/main/java/org/mariadb/r2dbc/MariadbServerParameterizedQueryStatement.java @@ -1,234 +1,71 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; -import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import org.mariadb.r2dbc.api.MariadbStatement; import org.mariadb.r2dbc.client.Client; import org.mariadb.r2dbc.client.DecoderState; -import org.mariadb.r2dbc.codec.Codec; -import org.mariadb.r2dbc.codec.Codecs; -import org.mariadb.r2dbc.codec.DataType; -import org.mariadb.r2dbc.codec.Parameter; +import org.mariadb.r2dbc.message.Protocol; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.client.ExecutePacket; import org.mariadb.r2dbc.message.client.PreparePacket; -import org.mariadb.r2dbc.message.server.CompletePrepareResult; -import org.mariadb.r2dbc.message.server.ErrorPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; +import org.mariadb.r2dbc.message.client.QueryPacket; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.Binding; +import org.mariadb.r2dbc.util.ServerNamedParamParser; import org.mariadb.r2dbc.util.ServerPrepareResult; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.util.annotation.Nullable; +import reactor.core.publisher.Sinks; -final class MariadbServerParameterizedQueryStatement implements MariadbStatement { +final class MariadbServerParameterizedQueryStatement extends MariadbCommonStatement + implements MariadbStatement { - private final Client client; - private final String initialSql; - private final MariadbConnectionConfiguration configuration; - private Map> parameters; - private List>> batchingParameters; - private String[] generatedColumns; - private AtomicReference prepareResult; + private ServerNamedParamParser paramParser; + private final AtomicReference prepareResult; MariadbServerParameterizedQueryStatement( Client client, String sql, MariadbConnectionConfiguration configuration) { - this.client = client; - this.configuration = configuration; - this.initialSql = Assert.requireNonNull(sql, "sql must not be null"); - this.parameters = null; + super(client, sql, configuration, Protocol.BINARY); + this.expectedSize = UNKNOWN_SIZE; + this.paramParser = null; this.prepareResult = new AtomicReference<>(client.getPrepareCache().get(sql)); } @Override - public MariadbServerParameterizedQueryStatement add() { - // check valid parameters - if (prepareResult.get() != null) { - for (int i = 0; i < prepareResult.get().getNumParams(); i++) { - if (this.parameters == null || parameters.get(i) == null) { - throw new IllegalArgumentException( - String.format("Parameter at position %s is not set", i)); - } - } + protected int getExpectedSize() { + if (expectedSize == UNKNOWN_SIZE) { + expectedSize = + (prepareResult.get() != null) + ? prepareResult.get().getNumParams() + : (((paramParser != null) + ? paramParser.getParamCount() + : ServerNamedParamParser.parameterParts( + initialSql, this.client.noBackslashEscapes()) + .getParamCount())); } - if (batchingParameters == null) batchingParameters = new ArrayList<>(); - batchingParameters.add(parameters); - parameters = null; - return this; - } - - @Override - public MariadbServerParameterizedQueryStatement bind( - @Nullable String identifier, @Nullable Object value) { - Assert.requireNonNull(identifier, "identifier cannot be null"); - return bind(getColumn(identifier), value); + return expectedSize; } - @SuppressWarnings({"rawtypes", "unchecked"}) - @Override - public MariadbServerParameterizedQueryStatement bind(int index, @Nullable Object value) { - if (index < 0) { - throw new IndexOutOfBoundsException( - String.format("wrong index value %d, index must be positive", index)); + protected int getColumnIndex(String name) { + Assert.requireNonNull(name, "identifier cannot be null"); + if (paramParser == null) { + paramParser = + ServerNamedParamParser.parameterParts(initialSql, this.client.noBackslashEscapes()); } - - if (prepareResult.get() != null && index >= prepareResult.get().getNumParams()) { - throw new IndexOutOfBoundsException( - String.format( - "index must be in 0-%d range but value is %d", - prepareResult.get().getNumParams() - 1, index)); + for (int i = 0; i < this.paramParser.getParamNameList().size(); i++) { + if (name.equals(this.paramParser.getParamNameList().get(i))) return i; } - if (value == null) return bindNull(index, null); - if (parameters == null) parameters = new HashMap<>(); - for (Codec codec : Codecs.LIST) { - if (codec.canEncode(value.getClass())) { - parameters.put(index, (Parameter) new Parameter(codec, value)); - return this; - } - } - throw new IllegalArgumentException( + throw new NoSuchElementException( String.format( - "No encoder for class %s (parameter at index %s) ", value.getClass().getName(), index)); - } - - @Override - public MariadbServerParameterizedQueryStatement bindNull( - @Nullable String identifier, @Nullable Class type) { - Assert.requireNonNull(identifier, "identifier cannot be null"); - return bindNull(getColumn(identifier), type); - } - - @Override - @SuppressWarnings({"unchecked", "rawtypes"}) - public MariadbServerParameterizedQueryStatement bindNull(int index, @Nullable Class type) { - if (index < 0) { - throw new IndexOutOfBoundsException( - String.format("wrong index value %d, index must be positive", index)); - } - - if (prepareResult.get() != null && index >= prepareResult.get().getNumParams()) { - throw new IndexOutOfBoundsException( - String.format( - "index must be in 0-%d range but value is %d", - prepareResult.get().getNumParams() - 1, index)); - } - Parameter parameter = null; - if (type != null) { - for (Codec codec : Codecs.LIST) { - if (codec.canEncode(type)) { - - parameter = - new Parameter(codec, null) { - @Override - public DataType getBinaryEncodeType() { - return DataType.VARCHAR; - } - - @Override - public boolean isNull() { - return true; - } - }; - break; - } - } - } - if (parameter == null) { - parameter = Parameter.NULL_PARAMETER; - } - if (parameters == null) parameters = new HashMap<>(); - parameters.put(index, parameter); - return this; - } - - private int getColumn(String name) { - throw new IllegalArgumentException("Cannot use getColumn(name) with prepared statement"); - } - - private void validateParameters() { - if (prepareResult.get() != null) { - // valid parameters - for (int i = 0; i < prepareResult.get().getNumParams(); i++) { - if (this.parameters == null || parameters.get(i) == null) { - throw new IllegalArgumentException( - String.format("Parameter at position %s is not set", i)); - } - } - } - } - - @Override - public Flux execute() { - String sql = this.initialSql; - if (client.getVersion().supportReturning() && generatedColumns != null) { - sql += - generatedColumns.length == 0 - ? " RETURNING *" - : " RETURNING " + String.join(", ", generatedColumns); - prepareResult.set(client.getPrepareCache().get(sql)); - } - - if (batchingParameters == null) { - return executeSingleQuery(sql, this.generatedColumns); - } else { - // add current set of parameters. see https://github.com/r2dbc/r2dbc-spi/issues/229 - if (parameters != null) this.add(); - // prepare command, if not already done - if (prepareResult.get() == null) { - prepareResult.set(client.getPrepareCache().get(sql)); - if (prepareResult.get() == null) { - sendPrepare(sql, ExceptionFactory.withSql(sql)).block(); - } - } - - Flux fluxMsg = - this.client.sendCommand( - new ExecutePacket( - prepareResult.get().getStatementId(), this.batchingParameters.get(0))); - int index = 1; - while (index < this.batchingParameters.size()) { - fluxMsg = - fluxMsg.concatWith( - this.client.sendCommand( - new ExecutePacket( - prepareResult.get().getStatementId(), - this.batchingParameters.get(index++)))); - } - fluxMsg = - fluxMsg.concatWith( - Flux.create( - sink -> { - prepareResult.get().decrementUse(client); - sink.complete(); - })); - - this.batchingParameters = null; - this.parameters = null; - - return fluxMsg - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new MariadbResult( - false, - prepareResult, - dataRow, - ExceptionFactory.INSTANCE, - null, - client.getVersion().supportReturning(), - client.getConf())); - } - } - - @Override - public MariadbServerParameterizedQueryStatement fetchSize(int rows) { - return this; + "No parameter with name '%s' found (possible values %s)", + name, this.paramParser.getParamNameList().toString())); } @Override @@ -239,129 +76,182 @@ public MariadbServerParameterizedQueryStatement returnGeneratedValues(String... throw new IllegalArgumentException( "returnGeneratedValues can have only one column before MariaDB 10.5.1"); } - // prepareResult.validateAddingReturning(); this.generatedColumns = columns; return this; } - private Flux executeSingleQuery( - String sql, String[] generatedColumns) { + @Override + public Flux execute() { + String realSql = paramParser == null ? this.initialSql : paramParser.getRealSql(); + String sql; + if (this.generatedColumns == null || !client.getVersion().supportReturning()) { + sql = realSql; + } else { + sql = augment(realSql, this.generatedColumns); + } ExceptionFactory factory = ExceptionFactory.withSql(sql); if (prepareResult.get() == null && client.getPrepareCache() != null) { prepareResult.set(client.getPrepareCache().get(sql)); } - - if (prepareResult.get() != null) { - validateParameters(); - - ServerPrepareResult res; - if (this.client.getPrepareCache() != null - && (res = this.client.getPrepareCache().get(sql)) != null - && !res.equals(prepareResult.get())) { - prepareResult.get().decrementUse(client); - prepareResult.set(res); + if (this.getExpectedSize() != 0) { + if (this.bindings.size() == 0) { + throw new IllegalStateException("No parameters have been set"); } - if (prepareResult.get().incrementUse()) { - return sendExecuteCmd(factory, parameters, generatedColumns) - .concatWith( - Flux.create( - sink -> { - prepareResult.get().decrementUse(client); - sink.complete(); - parameters = null; - })); - } else { - // prepare is closing - prepareResult.set(null); - } - } - if (this.parameters == null) this.parameters = new HashMap<>(); + this.bindings.forEach(b -> b.validate(this.getExpectedSize())); + return Flux.defer( + () -> { + if (this.bindings.size() == 1) { + // single query + Binding binding = this.bindings.pollFirst(); - Flux flux; - if (configuration.allowPipelining() - && client.getVersion().isMariaDBServer() - && client.getVersion().versionGreaterOrEqual(10, 2, 0)) { - flux = sendPrepareAndExecute(sql, factory, parameters, generatedColumns); - } else { - flux = - sendPrepare(sql, factory) - .flatMapMany( - prepareResult1 -> { - prepareResult.set(prepareResult1); - return sendExecuteCmd(factory, parameters, generatedColumns); - }); - } - return flux.concatWith( - Flux.create( - sink -> { - prepareResult.set(client.getPrepareCache().get(sql)); if (prepareResult.get() != null) { - prepareResult.get().decrementUse(client); - } - sink.complete(); - parameters = null; - })); - } - - private Flux sendPrepareAndExecute( - String sql, - ExceptionFactory factory, - Map> parameters, - String[] generatedColumns) { - return this.client - .sendCommand(new PreparePacket(sql), new ExecutePacket(-1, parameters)) - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new MariadbResult( - false, - this.prepareResult, - dataRow, - factory, - generatedColumns, - client.getVersion().supportReturning(), - client.getConf())); - } + ServerPrepareResult res; + if (this.client.getPrepareCache() != null + && (res = this.client.getPrepareCache().get(sql)) != null + && !res.equals(prepareResult.get())) { + prepareResult.get().decrementUse(client); + prepareResult.set(res); + } - private Mono sendPrepare(String sql, ExceptionFactory factory) { - Flux f = - this.client - .sendCommand(new PreparePacket(sql), DecoderState.PREPARE_RESPONSE, sql) - .handle( - (it, sink) -> { - if (it instanceof ErrorPacket) { - sink.error(factory.from((ErrorPacket) it)); - return; - } - if (it instanceof CompletePrepareResult) { - prepareResult.set(((CompletePrepareResult) it).getPrepare()); - sink.next(prepareResult.get()); - } - if (it.ending()) sink.complete(); - }); - return f.singleOrEmpty(); + if (prepareResult.get().incrementUse()) { + Flux messages = + bindingParameterResults(binding, getExpectedSize()) + .flatMapMany( + values -> + this.client.sendCommand( + new ExecutePacket(sql, prepareResult.get(), values), + DecoderState.QUERY_RESPONSE, + sql, + false)) + .doFinally(s -> prepareResult.get().decrementUse(client)); + return toResult( + Protocol.BINARY, + client, + messages, + factory, + prepareResult, + generatedColumns, + configuration); + } else { + // prepare is closing + prepareResult.set(null); + } + } + Flux messages; + if (configuration.allowPipelining() + && client.getVersion().isMariaDBServer() + && client.getVersion().versionGreaterOrEqual(10, 2, 0)) { + messages = + bindingParameterResults(binding, getExpectedSize()) + .flatMapMany( + values -> + this.client.sendCommand( + new PreparePacket(sql), + new ExecutePacket(sql, null, values), + false)); + } else { + messages = + client + .sendPrepare(new PreparePacket(sql), factory, sql) + .flatMapMany( + serverPrepareResult -> { + prepareResult.set(serverPrepareResult); + return bindingParameterResults(binding, getExpectedSize()) + .flatMapMany( + values -> + this.client.sendCommand( + new ExecutePacket(sql, prepareResult.get(), values), + DecoderState.QUERY_RESPONSE, + sql, + false)); + }); + } + return toResult( + Protocol.BINARY, + client, + messages, + factory, + prepareResult, + generatedColumns, + configuration) + .doFinally( + s -> { + if (prepareResult.get() != null) { + prepareResult.get().decrementUse(client); + } + }); + } + // batch + Iterator iterator = this.bindings.iterator(); + Sinks.Many bindingSink = Sinks.many().unicast().onBackpressureBuffer(); + AtomicBoolean canceled = new AtomicBoolean(); + return prepareIfNotDone(sql, factory) + .thenMany( + bindingSink + .asFlux() + .map( + binding -> { + Flux messages = + bindingParameterResults(binding, getExpectedSize()) + .flatMapMany( + values -> + this.client.sendCommand( + new ExecutePacket( + sql, prepareResult.get(), values), + false)) + .doOnComplete( + () -> tryNextBinding(iterator, bindingSink, canceled)); + + return toResult( + Protocol.BINARY, + this.client, + messages, + factory, + prepareResult, + generatedColumns, + configuration); + }) + .doOnSubscribe( + it -> + bindingSink.emitNext( + iterator.next(), Sinks.EmitFailureHandler.FAIL_FAST)) + .doOnComplete(this.bindings::clear) + .doFinally( + s -> { + if (prepareResult.get() != null) { + prepareResult.get().decrementUse(client); + } + }) + .doOnCancel(() -> clearBindings(iterator, canceled)) + .doOnError(e -> clearBindings(iterator, canceled))) + .flatMap(mariadbResultFlux -> mariadbResultFlux); + }); + } else { + return Flux.defer( + () -> { + Flux messages = + this.client.sendCommand( + new QueryPacket(sql), DecoderState.QUERY_RESPONSE, sql, false); + return toResult( + Protocol.TEXT, client, messages, factory, null, generatedColumns, configuration); + }); + } } - private Flux sendExecuteCmd( - ExceptionFactory factory, Map> parameters, String[] generatedColumns) { - return this.client - .sendCommand( - new ExecutePacket( - prepareResult.get() != null ? prepareResult.get().getStatementId() : -1, - parameters)) - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new MariadbResult( - false, - prepareResult, - dataRow, - factory, - generatedColumns, - client.getVersion().supportReturning(), - client.getConf())); + private Mono prepareIfNotDone(String sql, ExceptionFactory factory) { + // prepare command, if not already done + if (prepareResult.get() == null) { + prepareResult.set(client.getPrepareCache().get(sql)); + if (prepareResult.get() == null) { + return client + .sendPrepare(new PreparePacket(sql), factory, sql) + .doOnSuccess(p -> prepareResult.set(p)); + } + } + prepareResult.get().incrementUse(); + return Mono.just(prepareResult.get()); } @Override @@ -374,10 +264,8 @@ public String toString() { + '\'' + ", configuration=" + configuration - + ", parameters=" - + parameters - + ", batchingParameters=" - + batchingParameters + + ", bindings=" + + bindings + ", generatedColumns=" + (generatedColumns != null ? Arrays.toString(generatedColumns) : null) + ", prepareResult=" diff --git a/src/main/java/org/mariadb/r2dbc/MariadbSimpleQueryStatement.java b/src/main/java/org/mariadb/r2dbc/MariadbSimpleQueryStatement.java deleted file mode 100644 index 2dbcf141..00000000 --- a/src/main/java/org/mariadb/r2dbc/MariadbSimpleQueryStatement.java +++ /dev/null @@ -1,129 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc; - -import org.mariadb.r2dbc.api.MariadbStatement; -import org.mariadb.r2dbc.client.Client; -import org.mariadb.r2dbc.message.client.QueryPacket; -import org.mariadb.r2dbc.message.server.RowPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; -import org.mariadb.r2dbc.util.Assert; -import org.mariadb.r2dbc.util.ClientPrepareResult; -import reactor.core.publisher.Flux; -import reactor.util.annotation.Nullable; - -final class MariadbSimpleQueryStatement implements MariadbStatement { - - private final Client client; - private final String sql; - private String[] generatedColumns; - - MariadbSimpleQueryStatement(Client client, String sql) { - this.client = client; - this.sql = Assert.requireNonNull(sql, "sql must not be null"); - } - - static boolean supports(String sql, Client client) { - Assert.requireNonNull(sql, "sql must not be null"); - if (sql.contains("?") || sql.contains(":")) { - return !ClientPrepareResult.hasParameter(sql, client.noBackslashEscapes()); - } - return true; - } - - @Override - public MariadbSimpleQueryStatement add() { - return this; - } - - @Override - public MariadbSimpleQueryStatement bind(@Nullable String identifier, @Nullable Object value) { - throw new UnsupportedOperationException( - String.format("Binding parameters is not supported for the statement '%s'", this.sql)); - } - - @Override - public MariadbSimpleQueryStatement bind(int index, @Nullable Object value) { - throw new UnsupportedOperationException( - String.format("Binding parameters is not supported for the statement '%s'", this.sql)); - } - - @Override - public MariadbSimpleQueryStatement bindNull( - @Nullable String identifier, @Nullable Class type) { - throw new UnsupportedOperationException( - String.format("Binding parameters is not supported for the statement '%s'", this.sql)); - } - - @Override - public MariadbSimpleQueryStatement bindNull(int index, @Nullable Class type) { - throw new UnsupportedOperationException( - String.format("Binding parameters is not supported for the statement '%s'", this.sql)); - } - - @Override - public Flux execute() { - return execute(this.sql, this.generatedColumns); - } - - @Override - public MariadbSimpleQueryStatement fetchSize(int rows) { - return this; - } - - @Override - public MariadbSimpleQueryStatement returnGeneratedValues(String... columns) { - Assert.requireNonNull(columns, "columns must not be null"); - - if (!client.getVersion().supportReturning() && columns.length > 1) { - throw new IllegalArgumentException( - "returnGeneratedValues can have only one column before MariaDB 10.5.1"); - } - - ClientPrepareResult prepareResult = - ClientPrepareResult.parameterParts(this.sql, this.client.noBackslashEscapes()); - prepareResult.validateAddingReturning(); - - this.generatedColumns = columns; - return this; - } - - @Override - public String toString() { - return "MariadbSimpleQueryStatement{" - + "client=" - + this.client - + ", sql='" - + this.sql - + '\'' - + '}'; - } - - private Flux execute(String sql, String[] generatedColumns) { - ExceptionFactory factory = ExceptionFactory.withSql(sql); - - if (generatedColumns != null && client.getVersion().supportReturning()) { - sql = - String.format( - "%s RETURNING %s", - sql, generatedColumns.length == 0 ? "*" : String.join(", ", generatedColumns)); - } - - Flux response = this.client.sendCommand(new QueryPacket(sql)); - Flux flux = - response - .windowUntil(it -> it.resultSetEnd()) - .map( - dataRow -> - new MariadbResult( - true, - null, - dataRow, - factory, - generatedColumns, - client.getVersion().supportReturning(), - client.getConf())); - return flux.doOnDiscard(RowPacket.class, RowPacket::release); - } -} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbTransactionDefinition.java b/src/main/java/org/mariadb/r2dbc/MariadbTransactionDefinition.java new file mode 100644 index 00000000..cf5f801d --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbTransactionDefinition.java @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.Option; +import io.r2dbc.spi.TransactionDefinition; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.mariadb.r2dbc.util.Assert; + +public final class MariadbTransactionDefinition implements TransactionDefinition { + + public static final MariadbTransactionDefinition EMPTY = + new MariadbTransactionDefinition(Collections.emptyMap()); + + public static Option WITH_CONSISTENT_SNAPSHOT = + Option.valueOf("WITH CONSISTENT SNAPSHOT"); + + private final Map, Object> options; + + private MariadbTransactionDefinition(Map, Object> options) { + this.options = options; + } + + @Override + @SuppressWarnings("unchecked") + public T getAttribute(Option option) { + return (T) this.options.get(option); + } + + public MariadbTransactionDefinition with(Option option, Object value) { + + Map, Object> options = new HashMap<>(this.options); + options.put( + Assert.requireNonNull(option, "option must not be null"), + Assert.requireNonNull(value, "value must not be null")); + + return new MariadbTransactionDefinition(options); + } + + static MariadbTransactionDefinition mutability(boolean readWrite) { + return readWrite ? EMPTY.readWrite() : EMPTY.readOnly(); + } + + static MariadbTransactionDefinition from(IsolationLevel isolationLevel) { + return MariadbTransactionDefinition.EMPTY.isolationLevel(isolationLevel); + } + + public MariadbTransactionDefinition isolationLevel(IsolationLevel isolationLevel) { + return with(TransactionDefinition.ISOLATION_LEVEL, isolationLevel); + } + + public MariadbTransactionDefinition readOnly() { + return with(TransactionDefinition.READ_ONLY, true); + } + + public MariadbTransactionDefinition readWrite() { + return with(TransactionDefinition.READ_ONLY, false); + } + + public MariadbTransactionDefinition consistent() { + return with(MariadbTransactionDefinition.WITH_CONSISTENT_SNAPSHOT, true); + } + + public MariadbTransactionDefinition notConsistent() { + return with(MariadbTransactionDefinition.WITH_CONSISTENT_SNAPSHOT, false); + } + + public static MariadbTransactionDefinition WITH_CONSISTENT_SNAPSHOT_READ_WRITE = + EMPTY.consistent().readWrite(); + public static MariadbTransactionDefinition WITH_CONSISTENT_SNAPSHOT_READ_ONLY = + EMPTY.consistent().readOnly(); + public static MariadbTransactionDefinition READ_WRITE = EMPTY.readWrite(); + public static MariadbTransactionDefinition READ_ONLY = EMPTY.readOnly(); +} diff --git a/src/main/java/org/mariadb/r2dbc/MariadbUpdateCount.java b/src/main/java/org/mariadb/r2dbc/MariadbUpdateCount.java new file mode 100644 index 00000000..bd6ccdf8 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/MariadbUpdateCount.java @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc; + +import io.r2dbc.spi.Result; + +public class MariadbUpdateCount implements Result.UpdateCount { + public MariadbUpdateCount() {} + + @Override + public long value() { + return 0; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/SslMode.java b/src/main/java/org/mariadb/r2dbc/SslMode.java index 9aba3c2e..d99f1bd2 100644 --- a/src/main/java/org/mariadb/r2dbc/SslMode.java +++ b/src/main/java/org/mariadb/r2dbc/SslMode.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/main/java/org/mariadb/r2dbc/api/MariadbBatch.java b/src/main/java/org/mariadb/r2dbc/api/MariadbBatch.java index d92a28a5..85f53a56 100644 --- a/src/main/java/org/mariadb/r2dbc/api/MariadbBatch.java +++ b/src/main/java/org/mariadb/r2dbc/api/MariadbBatch.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.api; diff --git a/src/main/java/org/mariadb/r2dbc/api/MariadbConnection.java b/src/main/java/org/mariadb/r2dbc/api/MariadbConnection.java index 86a3f305..8e97cbee 100644 --- a/src/main/java/org/mariadb/r2dbc/api/MariadbConnection.java +++ b/src/main/java/org/mariadb/r2dbc/api/MariadbConnection.java @@ -1,11 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.api; -import io.r2dbc.spi.Connection; -import io.r2dbc.spi.IsolationLevel; -import io.r2dbc.spi.ValidationDepth; +import io.r2dbc.spi.*; +import java.time.Duration; import reactor.core.publisher.Mono; public interface MariadbConnection extends Connection { @@ -13,6 +12,9 @@ public interface MariadbConnection extends Connection { @Override Mono beginTransaction(); + @Override + Mono beginTransaction(TransactionDefinition definition); + @Override Mono close(); @@ -31,12 +33,20 @@ public interface MariadbConnection extends Connection { @Override MariadbConnectionMetadata getMetadata(); + String getDatabase(); + + Mono setDatabase(String database); + @Override IsolationLevel getTransactionIsolationLevel(); @Override boolean isAutoCommit(); + boolean isInTransaction(); + + boolean isInReadOnlyTransaction(); + @Override Mono releaseSavepoint(String name); @@ -55,5 +65,15 @@ public interface MariadbConnection extends Connection { @Override Mono validate(ValidationDepth depth); + @Override + Mono setLockWaitTimeout(Duration timeout); + + @Override + Mono setStatementTimeout(Duration timeout); + long getThreadId(); + + String getHost(); + + int getPort(); } diff --git a/src/main/java/org/mariadb/r2dbc/api/MariadbConnectionMetadata.java b/src/main/java/org/mariadb/r2dbc/api/MariadbConnectionMetadata.java index 672c5815..f9091cc7 100644 --- a/src/main/java/org/mariadb/r2dbc/api/MariadbConnectionMetadata.java +++ b/src/main/java/org/mariadb/r2dbc/api/MariadbConnectionMetadata.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.api; diff --git a/src/main/java/org/mariadb/r2dbc/api/MariadbOutSegment.java b/src/main/java/org/mariadb/r2dbc/api/MariadbOutSegment.java new file mode 100644 index 00000000..d61e949a --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/api/MariadbOutSegment.java @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.api; + +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Row; + +public interface MariadbOutSegment extends Result.OutSegment { + Row row(); +} diff --git a/src/main/java/org/mariadb/r2dbc/api/MariadbResult.java b/src/main/java/org/mariadb/r2dbc/api/MariadbResult.java index 4328c1d1..f5295315 100644 --- a/src/main/java/org/mariadb/r2dbc/api/MariadbResult.java +++ b/src/main/java/org/mariadb/r2dbc/api/MariadbResult.java @@ -1,20 +1,30 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.api; -import io.r2dbc.spi.Result; -import io.r2dbc.spi.Row; -import io.r2dbc.spi.RowMetadata; +import io.r2dbc.spi.*; +import io.r2dbc.spi.Readable; import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Predicate; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public interface MariadbResult extends Result { @Override - Mono getRowsUpdated(); + Flux getRowsUpdated(); @Override Flux map(BiFunction mappingFunction); + + @Override + Flux map(Function mappingFunction); + + @Override + Result filter(Predicate filter); + + @Override + Flux flatMap(Function> mappingFunction); } diff --git a/src/main/java/org/mariadb/r2dbc/api/MariadbStatement.java b/src/main/java/org/mariadb/r2dbc/api/MariadbStatement.java index 598540b1..d1bef15d 100644 --- a/src/main/java/org/mariadb/r2dbc/api/MariadbStatement.java +++ b/src/main/java/org/mariadb/r2dbc/api/MariadbStatement.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.api; diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlowPluginLoader.java b/src/main/java/org/mariadb/r2dbc/authentication/AuthenticationFlowPluginLoader.java similarity index 82% rename from src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlowPluginLoader.java rename to src/main/java/org/mariadb/r2dbc/authentication/AuthenticationFlowPluginLoader.java index 1f79e02f..f6d26137 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlowPluginLoader.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/AuthenticationFlowPluginLoader.java @@ -1,11 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication; import java.util.ServiceLoader; import org.mariadb.r2dbc.api.MariadbConnection; -import org.mariadb.r2dbc.authentication.AuthenticationPlugin; public class AuthenticationFlowPluginLoader { @@ -20,10 +19,6 @@ public static AuthenticationPlugin get(String type) { ServiceLoader loader = ServiceLoader.load(AuthenticationPlugin.class, MariadbConnection.class.getClassLoader()); - if (type == null || type.isEmpty()) { - return null; - } - for (AuthenticationPlugin implClass : loader) { if (type.equals(implClass.type())) { return implClass.create(); diff --git a/src/main/java/org/mariadb/r2dbc/authentication/AuthenticationPlugin.java b/src/main/java/org/mariadb/r2dbc/authentication/AuthenticationPlugin.java index 4f34c326..0c768845 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/AuthenticationPlugin.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/AuthenticationPlugin.java @@ -1,13 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.authentication; import io.r2dbc.spi.R2dbcException; import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.message.client.ClientMessage; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; public interface AuthenticationPlugin { @@ -17,7 +17,7 @@ public interface AuthenticationPlugin { ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) + AuthSwitch authSwitch, + AuthMoreData authMoreData) throws R2dbcException; } diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/ClearPasswordPluginFlow.java b/src/main/java/org/mariadb/r2dbc/authentication/addon/ClearPasswordPluginFlow.java similarity index 56% rename from src/main/java/org/mariadb/r2dbc/message/flow/ClearPasswordPluginFlow.java rename to src/main/java/org/mariadb/r2dbc/authentication/addon/ClearPasswordPluginFlow.java index 3a96db8b..62090ffc 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/ClearPasswordPluginFlow.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/addon/ClearPasswordPluginFlow.java @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication.addon; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.authentication.AuthenticationPlugin; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; import org.mariadb.r2dbc.message.client.ClearPasswordPacket; -import org.mariadb.r2dbc.message.client.ClientMessage; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; public final class ClearPasswordPluginFlow implements AuthenticationPlugin { @@ -24,8 +24,8 @@ public String type() { public ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) { - return new ClearPasswordPacket(authSwitchPacket.getSequencer(), configuration.getPassword()); + AuthSwitch authSwitch, + AuthMoreData authMoreData) { + return new ClearPasswordPacket(authSwitch.getSequencer(), configuration.getPassword()); } } diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/CachingSha2PasswordFlow.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/CachingSha2PasswordFlow.java similarity index 81% rename from src/main/java/org/mariadb/r2dbc/message/flow/CachingSha2PasswordFlow.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/CachingSha2PasswordFlow.java index ce53f601..01b23443 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/CachingSha2PasswordFlow.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/CachingSha2PasswordFlow.java @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication.standard; import io.r2dbc.spi.R2dbcException; import io.r2dbc.spi.R2dbcNonTransientResourceException; @@ -11,9 +11,10 @@ import java.security.PublicKey; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.SslMode; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; import org.mariadb.r2dbc.message.client.*; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; public final class CachingSha2PasswordFlow extends Sha256PasswordPluginFlow { @@ -70,21 +71,21 @@ public String type() { public ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) + AuthSwitch authSwitch, + AuthMoreData authMoreData) throws R2dbcException { - if (authMoreDataPacket == null) state = State.INIT; + if (authMoreData == null) state = State.INIT; CharSequence password = configuration.getPassword(); switch (state) { case INIT: - byte[] fastCryptedPwd = sha256encryptPassword(password, authSwitchPacket.getSeed()); + byte[] fastCryptedPwd = sha256encryptPassword(password, authSwitch.getSeed()); state = State.FAST_AUTH_RESULT; - return new AuthMoreRawPacket(authSwitchPacket.getSequencer(), fastCryptedPwd); + return new AuthMoreRawPacket(authSwitch.getSequencer(), fastCryptedPwd); case FAST_AUTH_RESULT: - byte fastAuthResult = authMoreDataPacket.getBuf().getByte(0); + byte fastAuthResult = authMoreData.getBuf().getByte(0); switch (fastAuthResult) { case 3: // success authentication @@ -94,7 +95,7 @@ public ClientMessage next( if (configuration.getSslConfig().getSslMode() != SslMode.DISABLE) { // send clear password state = State.SEND_AUTH; - return new ClearPasswordPacket(authMoreDataPacket.getSequencer(), password); + return new ClearPasswordPacket(authMoreData.getSequencer(), password); } // retrieve public key from configuration or from server @@ -104,9 +105,9 @@ public ClientMessage next( state = State.SEND_AUTH; return new Sha256PasswordPacket( - authMoreDataPacket.getSequencer(), + authMoreData.getSequencer(), configuration.getPassword(), - authSwitchPacket.getSeed(), + authSwitch.getSeed(), publicKey); } @@ -118,7 +119,7 @@ public ClientMessage next( state = State.REQUEST_SERVER_KEY; // ask public Key Retrieval - return new Sha2PublicKeyRequestPacket(authMoreDataPacket.getSequencer()); + return new Sha2PublicKeyRequestPacket(authMoreData.getSequencer()); default: throw new R2dbcNonTransientResourceException( @@ -127,12 +128,12 @@ public ClientMessage next( } case REQUEST_SERVER_KEY: - publicKey = readPublicKey(authMoreDataPacket); + publicKey = readPublicKey(authMoreData.getBuf()); state = State.SEND_AUTH; return new Sha256PasswordPacket( - authMoreDataPacket.getSequencer(), + authMoreData.getSequencer(), configuration.getPassword(), - authSwitchPacket.getSeed(), + authSwitch.getSeed(), publicKey); default: diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/Ed25519PasswordPluginFlow.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/Ed25519PasswordPluginFlow.java similarity index 58% rename from src/main/java/org/mariadb/r2dbc/message/flow/Ed25519PasswordPluginFlow.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/Ed25519PasswordPluginFlow.java index f8503b39..593c47fc 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/Ed25519PasswordPluginFlow.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/Ed25519PasswordPluginFlow.java @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication.standard; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.authentication.AuthenticationPlugin; -import org.mariadb.r2dbc.message.client.ClientMessage; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; import org.mariadb.r2dbc.message.client.Ed25519PasswordPacket; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; public final class Ed25519PasswordPluginFlow implements AuthenticationPlugin { @@ -24,10 +24,10 @@ public String type() { public ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) { + AuthSwitch authSwitch, + AuthMoreData authMoreData) { return new Ed25519PasswordPacket( - authSwitchPacket.getSequencer(), configuration.getPassword(), authSwitchPacket.getSeed()); + authSwitch.getSequencer(), configuration.getPassword(), authSwitch.getSeed()); } } diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/NativePasswordPluginFlow.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/NativePasswordPluginFlow.java similarity index 58% rename from src/main/java/org/mariadb/r2dbc/message/flow/NativePasswordPluginFlow.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/NativePasswordPluginFlow.java index e82c7663..b180c2ff 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/NativePasswordPluginFlow.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/NativePasswordPluginFlow.java @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication.standard; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.authentication.AuthenticationPlugin; -import org.mariadb.r2dbc.message.client.ClientMessage; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; import org.mariadb.r2dbc.message.client.NativePasswordPacket; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; public final class NativePasswordPluginFlow implements AuthenticationPlugin { @@ -24,9 +24,9 @@ public String type() { public ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) { + AuthSwitch authSwitch, + AuthMoreData authMoreData) { return new NativePasswordPacket( - authSwitchPacket.getSequencer(), configuration.getPassword(), authSwitchPacket.getSeed()); + authSwitch.getSequencer(), configuration.getPassword(), authSwitch.getSeed()); } } diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/PamPluginFlow.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/PamPluginFlow.java similarity index 68% rename from src/main/java/org/mariadb/r2dbc/message/flow/PamPluginFlow.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/PamPluginFlow.java index ca6b05b6..cea9792a 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/PamPluginFlow.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/PamPluginFlow.java @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication.standard; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.authentication.AuthenticationPlugin; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; import org.mariadb.r2dbc.message.client.ClearPasswordPacket; -import org.mariadb.r2dbc.message.client.ClientMessage; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; public final class PamPluginFlow implements AuthenticationPlugin { @@ -26,13 +26,12 @@ public String type() { public ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) { + AuthSwitch authSwitch, + AuthMoreData authMoreData) { while (true) { counter++; if (counter == 0) { - return new ClearPasswordPacket( - authSwitchPacket.getSequencer(), configuration.getPassword()); + return new ClearPasswordPacket(authSwitch.getSequencer(), configuration.getPassword()); } else { if (configuration.getPamOtherPwd() == null) { throw new IllegalArgumentException( @@ -45,7 +44,7 @@ public ClientMessage next( counter, configuration.getPamOtherPwd().length)); } return new ClearPasswordPacket( - authSwitchPacket.getSequencer(), configuration.getPamOtherPwd()[counter - 1]); + authSwitch.getSequencer(), configuration.getPamOtherPwd()[counter - 1]); } } } diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/Sha256PasswordPluginFlow.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/Sha256PasswordPluginFlow.java similarity index 80% rename from src/main/java/org/mariadb/r2dbc/message/flow/Sha256PasswordPluginFlow.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/Sha256PasswordPluginFlow.java index b62157e7..c618f204 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/Sha256PasswordPluginFlow.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/Sha256PasswordPluginFlow.java @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.flow; +package org.mariadb.r2dbc.authentication.standard; import io.netty.buffer.ByteBuf; import io.r2dbc.spi.R2dbcException; @@ -16,12 +16,12 @@ import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.SslMode; import org.mariadb.r2dbc.authentication.AuthenticationPlugin; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.ClientMessage; import org.mariadb.r2dbc.message.client.ClearPasswordPacket; -import org.mariadb.r2dbc.message.client.ClientMessage; import org.mariadb.r2dbc.message.client.RsaPublicKeyRequestPacket; import org.mariadb.r2dbc.message.client.Sha256PasswordPacket; -import org.mariadb.r2dbc.message.server.AuthMoreDataPacket; -import org.mariadb.r2dbc.message.server.AuthSwitchPacket; public class Sha256PasswordPluginFlow implements AuthenticationPlugin { @@ -77,13 +77,11 @@ public static PublicKey generatePublicKey(byte[] publicKeyBytes) throws R2dbcExc /** * Read public Key * - * @param authMoreDataPacket More data packet + * @param buf more data buffer * @return public key * @throws R2dbcException if server return an Error packet or public key cannot be parsed. */ - public static PublicKey readPublicKey(AuthMoreDataPacket authMoreDataPacket) - throws R2dbcException { - ByteBuf buf = authMoreDataPacket.getBuf(); + public static PublicKey readPublicKey(ByteBuf buf) throws R2dbcException { byte[] key = new byte[buf.readableBytes()]; buf.readBytes(key); return generatePublicKey(key); @@ -99,14 +97,14 @@ public String type() { public ClientMessage next( MariadbConnectionConfiguration configuration, - AuthSwitchPacket authSwitchPacket, - AuthMoreDataPacket authMoreDataPacket) + AuthSwitch authSwitch, + AuthMoreData authMoreData) throws R2dbcException { if (state == State.INIT) { CharSequence password = configuration.getPassword(); if (password == null || configuration.getSslConfig().getSslMode() != SslMode.DISABLE) { - return new ClearPasswordPacket(authSwitchPacket.getSequencer(), password); + return new ClearPasswordPacket(authSwitch.getSequencer(), password); } else { // retrieve public key from configuration or from server if (configuration.getRsaPublicKey() != null && !configuration.getRsaPublicKey().isEmpty()) { @@ -119,20 +117,17 @@ public ClientMessage next( } state = State.REQUEST_SERVER_KEY; // ask public Key Retrieval - return new RsaPublicKeyRequestPacket(authSwitchPacket.getSequencer()); + return new RsaPublicKeyRequestPacket(authSwitch.getSequencer()); } } return new Sha256PasswordPacket( - authSwitchPacket.getSequencer(), - configuration.getPassword(), - authSwitchPacket.getSeed(), - publicKey); + authSwitch.getSequencer(), configuration.getPassword(), authSwitch.getSeed(), publicKey); } else { - publicKey = readPublicKey(authMoreDataPacket); + publicKey = readPublicKey(authMoreData.getBuf()); return new Sha256PasswordPacket( - authMoreDataPacket.getSequencer(), + authMoreData.getSequencer(), configuration.getPassword(), - authSwitchPacket.getSeed(), + authSwitch.getSeed(), publicKey); } } diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/README b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/README similarity index 100% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/README rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/README diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/Utils.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/Utils.java similarity index 95% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/Utils.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/Utils.java index ecd52ef5..98a03198 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/Utils.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/Utils.java @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519; /** * Basic utilities for EdDSA. Not for external use, not maintained as a public API. diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Constants.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Constants.java similarity index 82% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Constants.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Constants.java index 8821d0ae..cede61ea 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Constants.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Constants.java @@ -1,8 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math; -import org.mariadb.r2dbc.authentication.ed25519.Utils; +import org.mariadb.r2dbc.authentication.standard.ed25519.Utils; final class Constants { diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Curve.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Curve.java similarity index 94% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Curve.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Curve.java index 4087535a..de553569 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Curve.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Curve.java @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math; import java.io.Serializable; diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Encoding.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Encoding.java similarity index 92% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Encoding.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Encoding.java index 3e41b883..9a2be17f 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Encoding.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Encoding.java @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math; /** * Common interface for all $(b-1)$-bit encodings of elements of EdDSA finite fields. diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Field.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Field.java similarity index 94% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Field.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Field.java index 4795b11e..eea773dd 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/Field.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/Field.java @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math; import java.io.Serializable; diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/FieldElement.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/FieldElement.java similarity index 93% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/FieldElement.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/FieldElement.java index 70a66675..93b5962c 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/FieldElement.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/FieldElement.java @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math; import java.io.Serializable; diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/GroupElement.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/GroupElement.java similarity index 99% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/GroupElement.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/GroupElement.java index 9310c95f..a06e2641 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/GroupElement.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/GroupElement.java @@ -1,10 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math; import java.io.Serializable; import java.util.Arrays; -import org.mariadb.r2dbc.authentication.ed25519.Utils; +import org.mariadb.r2dbc.authentication.standard.ed25519.Utils; /** * A point $(x,y)$ on an EdDSA curve. diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/Ed25519FieldElement.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/Ed25519FieldElement.java similarity index 98% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/Ed25519FieldElement.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/Ed25519FieldElement.java index 278c50a3..71930adc 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/Ed25519FieldElement.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/Ed25519FieldElement.java @@ -1,11 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math.ed25519; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519; import java.util.Arrays; -import org.mariadb.r2dbc.authentication.ed25519.Utils; -import org.mariadb.r2dbc.authentication.ed25519.math.Field; -import org.mariadb.r2dbc.authentication.ed25519.math.FieldElement; +import org.mariadb.r2dbc.authentication.standard.ed25519.Utils; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.Field; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.FieldElement; /** * Class to represent a field element of the finite field $p = 2^{255} - 19$ elements. diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/Ed25519LittleEndianEncoding.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/Ed25519LittleEndianEncoding.java similarity index 92% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/Ed25519LittleEndianEncoding.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/Ed25519LittleEndianEncoding.java index c2dec2de..15d488b5 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/Ed25519LittleEndianEncoding.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/Ed25519LittleEndianEncoding.java @@ -1,9 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math.ed25519; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519; -import org.mariadb.r2dbc.authentication.ed25519.math.Encoding; -import org.mariadb.r2dbc.authentication.ed25519.math.FieldElement; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.Encoding; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.FieldElement; /** * Helper class for encoding/decoding from/to the 32 byte representation. @@ -35,9 +35,9 @@ static long load_4(byte[] in, int offset) { *
  • Convert the field element to the 32 byte representation. * * - *

    The idea for the modulo $p$ reduction algorithm is as follows: + * The idea for the modulo $p$ reduction algorithm is as follows: * - *

    Assumption:

    + *

    Assumption: * *

      *
    • $p = 2^{255} - 19$ @@ -49,7 +49,7 @@ static long load_4(byte[] in, int offset) { * *

      Then $q = [2^{-255} * (h + 19 * 2^{-25} * h_9 + 1/2)]$ where $[x] = floor(x)$. * - *

      Proof:

      + *

      Proof: * *

      We begin with some very raw estimation for the bounds of some expressions: * @@ -58,13 +58,9 @@ static long load_4(byte[] in, int offset) { * 2^{230} * h_9| = |h_0 + \dots + 2^{204} * h_8| \lt 2^{204} * 2^{30} = 2^{234}. \\ \Rightarrow * -1/4 \le b := 19 * 2^{-255} * (h - 2^{230} * h_9) \lt 1/4 \end{equation} $$ * - *

      Therefore $0 \lt 1/2 - a - b \lt 1$. - * - *

      Set $x := r + 19 * 2^{-255} * r + 1/2 - a - b$. Then: - * - *

      $$ 0 \le x \lt 255 - 20 + 19 + 1 = 2^{255} \\ \Rightarrow 0 \le 2^{-255} * x \lt 1. $$ - * - *

      Since $q$ is an integer we have + *

      Therefore $0 \lt 1/2 - a - b \lt 1$. Set $x := r + 19 * 2^{-255} * r + 1/2 - a - b$. Then: + * $$ 0 \le x \lt 255 - 20 + 19 + 1 = 2^{255} \\ \Rightarrow 0 \le 2^{-255} * x \lt 1. $$ Since + * $q$ is an integer we have * *

      $$ [q + 2^{-255} * x] = q \quad (1) $$ * diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/ScalarOps.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/ScalarOps.java similarity index 98% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/ScalarOps.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/ScalarOps.java index 9fc48722..f8a408ec 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/math/ed25519/ScalarOps.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/math/ed25519/ScalarOps.java @@ -1,9 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.math.ed25519; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519; -import static org.mariadb.r2dbc.authentication.ed25519.math.ed25519.Ed25519LittleEndianEncoding.load_3; -import static org.mariadb.r2dbc.authentication.ed25519.math.ed25519.Ed25519LittleEndianEncoding.load_4; +import static org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.Ed25519LittleEndianEncoding.load_3; +import static org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.Ed25519LittleEndianEncoding.load_4; /** * Class for reducing a huge integer modulo the group order q and doing a combined multiply plus add diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSANamedCurveSpec.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSANamedCurveSpec.java similarity index 60% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSANamedCurveSpec.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSANamedCurveSpec.java index cda76c5c..0e208056 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSANamedCurveSpec.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSANamedCurveSpec.java @@ -1,10 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.spec; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.spec; -import org.mariadb.r2dbc.authentication.ed25519.math.Curve; -import org.mariadb.r2dbc.authentication.ed25519.math.GroupElement; -import org.mariadb.r2dbc.authentication.ed25519.math.ed25519.ScalarOps; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.Curve; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.GroupElement; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.ScalarOps; /** * EdDSA Curve specification that can also be referred to by name. diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSANamedCurveTable.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSANamedCurveTable.java similarity index 79% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSANamedCurveTable.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSANamedCurveTable.java index e5df7fe6..dc8918f5 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSANamedCurveTable.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSANamedCurveTable.java @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.spec; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.spec; import java.util.Hashtable; import java.util.Locale; -import org.mariadb.r2dbc.authentication.ed25519.Utils; -import org.mariadb.r2dbc.authentication.ed25519.math.Curve; -import org.mariadb.r2dbc.authentication.ed25519.math.Field; -import org.mariadb.r2dbc.authentication.ed25519.math.ed25519.Ed25519LittleEndianEncoding; -import org.mariadb.r2dbc.authentication.ed25519.math.ed25519.ScalarOps; +import org.mariadb.r2dbc.authentication.standard.ed25519.Utils; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.Curve; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.Field; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.Ed25519LittleEndianEncoding; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.ScalarOps; /** * The named EdDSA curves. diff --git a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSAParameterSpec.java b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSAParameterSpec.java similarity index 86% rename from src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSAParameterSpec.java rename to src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSAParameterSpec.java index 2659f27f..a5678345 100644 --- a/src/main/java/org/mariadb/r2dbc/authentication/ed25519/spec/EdDSAParameterSpec.java +++ b/src/main/java/org/mariadb/r2dbc/authentication/standard/ed25519/spec/EdDSAParameterSpec.java @@ -1,14 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab -package org.mariadb.r2dbc.authentication.ed25519.spec; +// Copyright (c) 2020-2022 MariaDB Corporation Ab +package org.mariadb.r2dbc.authentication.standard.ed25519.spec; import java.io.Serializable; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.security.spec.AlgorithmParameterSpec; -import org.mariadb.r2dbc.authentication.ed25519.math.Curve; -import org.mariadb.r2dbc.authentication.ed25519.math.GroupElement; -import org.mariadb.r2dbc.authentication.ed25519.math.ed25519.ScalarOps; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.Curve; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.GroupElement; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.ScalarOps; /** * Parameter specification for an EdDSA algorithm. diff --git a/src/main/java/org/mariadb/r2dbc/client/Client.java b/src/main/java/org/mariadb/r2dbc/client/Client.java index 26681dff..69c03934 100644 --- a/src/main/java/org/mariadb/r2dbc/client/Client.java +++ b/src/main/java/org/mariadb/r2dbc/client/Client.java @@ -1,16 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.client; +import io.r2dbc.spi.TransactionDefinition; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.message.client.ClientMessage; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.client.ExecutePacket; import org.mariadb.r2dbc.message.client.PreparePacket; import org.mariadb.r2dbc.message.client.SslRequestPacket; import org.mariadb.r2dbc.message.server.InitialHandshakePacket; -import org.mariadb.r2dbc.message.server.ServerMessage; +import org.mariadb.r2dbc.util.HostAddress; import org.mariadb.r2dbc.util.PrepareCache; +import org.mariadb.r2dbc.util.ServerPrepareResult; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -18,39 +23,47 @@ public interface Client { Mono close(); - Flux receive(DecoderState initialState); - void sendCommandWithoutResult(ClientMessage requests); - Flux sendCommand(ClientMessage requests); + Flux sendCommand(ClientMessage requests, boolean canSafelyBeReExecuted); + + Flux sendCommand( + ClientMessage requests, DecoderState initialState, boolean canSafelyBeReExecuted); - Flux sendCommand(ClientMessage requests, DecoderState initialState); + Flux sendCommand( + ClientMessage requests, DecoderState initialState, String sql, boolean canSafelyBeReExecuted); - Flux sendCommand(ClientMessage requests, DecoderState initialState, String sql); + Flux sendCommand( + PreparePacket preparePacket, ExecutePacket executePacket, boolean canSafelyBeReExecuted); - Flux sendCommand(PreparePacket preparePacket, ExecutePacket executePacket); + Mono sendPrepare( + ClientMessage requests, ExceptionFactory factory, String sql); Mono sendSslRequest( SslRequestPacket sslRequest, MariadbConnectionConfiguration configuration); boolean isAutoCommit(); + boolean isInTransaction(); + boolean noBackslashEscapes(); ServerVersion getVersion(); boolean isConnected(); - void setContext(InitialHandshakePacket packet); + boolean isCloseRequested(); - void sendNext(); + void setContext(InitialHandshakePacket packet, long clientCapabilities); - MariadbConnectionConfiguration getConf(); + Context getContext(); PrepareCache getPrepareCache(); Mono beginTransaction(); + Mono beginTransaction(TransactionDefinition definition); + Mono commitTransaction(); Mono rollbackTransaction(); @@ -59,9 +72,7 @@ Mono sendSslRequest( Mono rollbackTransactionToSavepoint(String name); - Mono releaseSavepoint(String name); - - Mono createSavepoint(String name); - long getThreadId(); + + HostAddress getHostAddress(); } diff --git a/src/main/java/org/mariadb/r2dbc/client/ClientBase.java b/src/main/java/org/mariadb/r2dbc/client/ClientBase.java deleted file mode 100644 index ee335aaa..00000000 --- a/src/main/java/org/mariadb/r2dbc/client/ClientBase.java +++ /dev/null @@ -1,405 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.client; - -import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; -import io.netty.handler.logging.LogLevel; -import io.netty.handler.logging.LoggingHandler; -import io.netty.handler.ssl.SslHandler; -import io.netty.util.concurrent.Future; -import io.netty.util.concurrent.GenericFutureListener; -import io.r2dbc.spi.R2dbcNonTransientResourceException; -import io.r2dbc.spi.R2dbcTransientResourceException; -import java.util.Queue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Consumer; -import javax.net.ssl.SSLEngine; -import javax.net.ssl.SSLException; -import org.mariadb.r2dbc.ExceptionFactory; -import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.message.client.ClientMessage; -import org.mariadb.r2dbc.message.client.QueryPacket; -import org.mariadb.r2dbc.message.client.QuitPacket; -import org.mariadb.r2dbc.message.client.SslRequestPacket; -import org.mariadb.r2dbc.message.server.InitialHandshakePacket; -import org.mariadb.r2dbc.message.server.ServerMessage; -import org.mariadb.r2dbc.util.PrepareCache; -import org.mariadb.r2dbc.util.constants.ServerStatus; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; -import reactor.netty.Connection; -import reactor.netty.tcp.TcpClient; -import reactor.util.Logger; -import reactor.util.Loggers; -import reactor.util.concurrent.Queues; - -public abstract class ClientBase implements Client { - - private static final Logger logger = Loggers.getLogger(ClientBase.class); - protected final ReentrantLock lock = new ReentrantLock(); - private final MariadbConnectionConfiguration configuration; - protected final Connection connection; - protected final Queue responseReceivers = Queues.unbounded().get(); - private final AtomicBoolean isClosed = new AtomicBoolean(false); - private final MariadbPacketDecoder mariadbPacketDecoder; - private final MariadbPacketEncoder mariadbPacketEncoder = new MariadbPacketEncoder(); - protected volatile Context context; - private final PrepareCache prepareCache; - - protected ClientBase(Connection connection, MariadbConnectionConfiguration configuration) { - this.connection = connection; - this.configuration = configuration; - this.prepareCache = - this.configuration.useServerPrepStmts() - ? new PrepareCache(this.configuration.getPrepareCacheSize(), this) - : null; - this.mariadbPacketDecoder = new MariadbPacketDecoder(responseReceivers, this); - - connection.addHandler(mariadbPacketDecoder); - connection.addHandler(mariadbPacketEncoder); - - if (logger.isTraceEnabled()) { - connection.addHandlerFirst( - LoggingHandler.class.getSimpleName(), - new LoggingHandler(ClientBase.class, LogLevel.TRACE)); - } - - connection - .inbound() - .receive() - .doOnError(this::handleConnectionError) - .doOnComplete(this::closedServlet) - .then() - .subscribe(); - } - - public static TcpClient setSocketOption( - MariadbConnectionConfiguration configuration, TcpClient tcpClient) { - if (configuration.getConnectTimeout() != null) { - tcpClient = - tcpClient.option( - ChannelOption.CONNECT_TIMEOUT_MILLIS, - Math.toIntExact(configuration.getConnectTimeout().toMillis())); - } - - if (configuration.getSocketTimeout() != null) { - tcpClient = - tcpClient.option( - ChannelOption.SO_TIMEOUT, - Math.toIntExact(configuration.getSocketTimeout().toMillis())); - } - - if (configuration.isTcpKeepAlive()) { - tcpClient = tcpClient.option(ChannelOption.SO_KEEPALIVE, configuration.isTcpKeepAlive()); - } - - if (configuration.isTcpAbortiveClose()) { - tcpClient = tcpClient.option(ChannelOption.SO_LINGER, 0); - } - return tcpClient; - } - - private void handleConnectionError(Throwable throwable) { - R2dbcNonTransientResourceException err; - if (this.isClosed.compareAndSet(false, true)) { - err = - new R2dbcNonTransientResourceException("Connection unexpected error", "08000", throwable); - logger.error("Connection unexpected error", throwable); - Channel channel = this.connection.channel(); - if (!channel.isOpen()) { - this.connection.dispose(); - } - } else { - err = new R2dbcNonTransientResourceException("Connection error", "08000", throwable); - logger.error("Connection error", throwable); - } - clearWaitingListWithError(err); - } - - @Override - public Mono close() { - return Mono.defer( - () -> { - if (this.isClosed.compareAndSet(false, true)) { - - Channel channel = this.connection.channel(); - if (!channel.isOpen()) { - this.connection.dispose(); - return this.connection.onDispose(); - } - - return Flux.just(QuitPacket.INSTANCE) - .doOnNext(message -> connection.channel().writeAndFlush(message)) - .then() - .doOnSuccess(v -> this.connection.dispose()) - .then(this.connection.onDispose()); - } - - return Mono.empty(); - }); - } - - public Flux sendCommand(ClientMessage message) { - return sendCommand(message, DecoderState.QUERY_RESPONSE); - } - - @Override - public Mono sendSslRequest( - SslRequestPacket sslRequest, MariadbConnectionConfiguration configuration) { - CompletableFuture result = new CompletableFuture<>(); - try { - SSLEngine engine = - configuration.getSslConfig().getSslContext().newEngine(connection.channel().alloc()); - final SslHandler sslHandler = new SslHandler(engine); - - final GenericFutureListener> listener = - configuration - .getSslConfig() - .getHostNameVerifier(result, configuration.getHost(), context.getThreadId(), engine); - - sslHandler.handshakeFuture().addListener(listener); - // send SSL request in clear - connection.channel().writeAndFlush(sslRequest); - - // add SSL handler - connection.addHandlerFirst(sslHandler); - return Mono.fromFuture(result); - - } catch (SSLException | R2dbcTransientResourceException e) { - result.completeExceptionally(e); - return Mono.fromFuture(result); - } - } - - public Flux sendCommand(ClientMessage message, DecoderState initialState) { - return sendCommand(message, initialState, null); - } - - public abstract Flux sendCommand( - ClientMessage message, DecoderState initialState, String sql); - - private Flux execute(Consumer> s) { - AtomicBoolean atomicBoolean = new AtomicBoolean(); - return Flux.create( - sink -> { - if (!isConnected()) { - sink.error( - new R2dbcNonTransientResourceException( - "Connection is close. Cannot send anything")); - return; - } - if (atomicBoolean.compareAndSet(false, true)) { - try { - lock.lock(); - s.accept(sink); - } finally { - lock.unlock(); - } - } - }); - } - - abstract void begin(FluxSink sink); - - abstract void executeWhenTransaction(FluxSink sink, String cmd); - - abstract void executeAutoCommit(FluxSink sink, boolean autoCommit); - - public long getThreadId() { - return context.getThreadId(); - } - /** - * Specific implementation, to avoid executing BEGIN if already in transaction - * - * @return publisher - */ - public Mono beginTransaction() { - try { - lock.lock(); - return execute(sink -> begin(sink)) - .handle(ExceptionFactory.withSql("BEGIN")::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - /** - * Specific implementation, to avoid executing COMMIT if no transaction - * - * @return publisher - */ - public Mono commitTransaction() { - try { - lock.lock(); - return execute(sink -> executeWhenTransaction(sink, "COMMIT")) - .handle(ExceptionFactory.withSql("COMMIT")::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - /** - * Specific implementation, to avoid executing ROLLBACK if no transaction - * - * @return publisher - */ - public Mono rollbackTransaction() { - try { - lock.lock(); - return execute(sink -> executeWhenTransaction(sink, "ROLLBACK")) - .handle(ExceptionFactory.withSql("ROLLBACK")::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - /** - * Specific implementation, to avoid executing ROLLBACK TO TRANSACTION if no transaction - * - * @return publisher - */ - public Mono rollbackTransactionToSavepoint(String name) { - try { - lock.lock(); - String cmd = String.format("ROLLBACK TO SAVEPOINT `%s`", name.replace("`", "``")); - return execute(sink -> executeWhenTransaction(sink, cmd)) - .handle(ExceptionFactory.withSql(cmd)::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - public Mono releaseSavepoint(String name) { - try { - lock.lock(); - String cmd = String.format("RELEASE SAVEPOINT `%s`", name.replace("`", "``")); - return sendCommand(new QueryPacket(cmd)) - .handle(ExceptionFactory.withSql(cmd)::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - public Mono createSavepoint(String name) { - try { - lock.lock(); - String cmd = String.format("SAVEPOINT `%s`", name.replace("`", "``")); - return sendCommand(new QueryPacket(cmd)) - .handle(ExceptionFactory.withSql(cmd)::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - /** - * Specific implementation, to avoid changing autocommit mode if already in this autocommit mode - * - * @return publisher - */ - public Mono setAutoCommit(boolean autoCommit) { - try { - lock.lock(); - return execute(sink -> executeAutoCommit(sink, autoCommit)) - .handle(ExceptionFactory.withSql(null)::handleErrorResponse) - .then(); - } finally { - lock.unlock(); - } - } - - @Override - public Flux receive(DecoderState initialState) { - return Flux.create( - sink -> { - this.responseReceivers.add(new CmdElement(sink, initialState)); - }); - } - - public void setContext(InitialHandshakePacket handshake) { - this.context = - new Context( - handshake.getServerVersion(), - handshake.getThreadId(), - handshake.getCapabilities(), - handshake.getServerStatus(), - handshake.isMariaDBServer()); - mariadbPacketDecoder.setContext(context); - mariadbPacketEncoder.setContext(context); - } - - /** - * Get current server autocommit. - * - * @return autocommit current server value. - */ - @Override - public boolean isAutoCommit() { - return (this.context.getServerStatus() & ServerStatus.AUTOCOMMIT) > 0; - } - - @Override - public boolean noBackslashEscapes() { - return (this.context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0; - } - - @Override - public ServerVersion getVersion() { - return (this.context != null) ? this.context.getVersion() : ServerVersion.UNKNOWN_VERSION; - } - - @Override - public boolean isConnected() { - if (this.isClosed.get()) { - return false; - } - return this.connection.channel().isOpen(); - } - - private void closedServlet() { - if (this.isClosed.compareAndSet(false, true)) { - Channel channel = this.connection.channel(); - if (!channel.isOpen()) { - this.connection.dispose(); - } - clearWaitingListWithError( - ExceptionFactory.INSTANCE.createException("Connection unexpectedly closed", "08000", -1)); - - } else { - clearWaitingListWithError(new R2dbcNonTransientResourceException("Connection closed")); - } - } - - private void clearWaitingListWithError(Throwable exception) { - mariadbPacketDecoder.connectionError(exception); - CmdElement response; - while ((response = this.responseReceivers.poll()) != null) { - response.getSink().error(exception); - } - } - - public MariadbConnectionConfiguration getConf() { - return configuration; - } - - public abstract void sendNext(); - - public PrepareCache getPrepareCache() { - return prepareCache; - } - - @Override - public String toString() { - return "Client{isClosed=" + isClosed + ", context=" + context + '}'; - } -} diff --git a/src/main/java/org/mariadb/r2dbc/client/ClientImpl.java b/src/main/java/org/mariadb/r2dbc/client/ClientImpl.java deleted file mode 100644 index e8898a92..00000000 --- a/src/main/java/org/mariadb/r2dbc/client/ClientImpl.java +++ /dev/null @@ -1,147 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.client; - -import io.r2dbc.spi.R2dbcNonTransientResourceException; -import java.net.SocketAddress; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; -import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.message.client.ClientMessage; -import org.mariadb.r2dbc.message.client.ExecutePacket; -import org.mariadb.r2dbc.message.client.PreparePacket; -import org.mariadb.r2dbc.message.client.QueryPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; -import org.mariadb.r2dbc.util.constants.ServerStatus; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; -import reactor.netty.Connection; -import reactor.netty.resources.ConnectionProvider; -import reactor.netty.tcp.TcpClient; -import reactor.util.Logger; -import reactor.util.Loggers; -import reactor.util.concurrent.Queues; - -/** Client that only send query one by one. */ -public final class ClientImpl extends ClientBase { - private static final Logger logger = Loggers.getLogger(ClientImpl.class); - - public ClientImpl(Connection connection, MariadbConnectionConfiguration configuration) { - super(connection, configuration); - } - - protected final Queue sendingQueue = Queues.unbounded().get(); - - public static Mono connect( - ConnectionProvider connectionProvider, - SocketAddress socketAddress, - MariadbConnectionConfiguration configuration) { - - TcpClient tcpClient = TcpClient.create(connectionProvider).remoteAddress(() -> socketAddress); - tcpClient = setSocketOption(configuration, tcpClient); - return tcpClient.connect().flatMap(it -> Mono.just(new ClientImpl(it, configuration))); - } - - public void sendCommandWithoutResult(ClientMessage message) { - try { - lock.lock(); - if (this.responseReceivers.isEmpty()) { - connection.channel().writeAndFlush(message); - } else { - sendingQueue.add(message); - } - } finally { - lock.unlock(); - } - } - - public Flux sendCommand(PreparePacket preparePacket, ExecutePacket executePacket) { - return Flux.error(new R2dbcNonTransientResourceException("Cannot pipeline")); - } - - public Flux sendCommand( - ClientMessage message, DecoderState initialState, String sql) { - AtomicBoolean atomicBoolean = new AtomicBoolean(); - return Flux.create( - sink -> { - if (!isConnected()) { - sink.error( - new R2dbcNonTransientResourceException( - "Connection is close. Cannot send anything")); - return; - } - if (atomicBoolean.compareAndSet(false, true)) { - try { - lock.lock(); - if (this.responseReceivers.isEmpty()) { - this.responseReceivers.add(new CmdElement(sink, initialState, sql)); - connection.channel().writeAndFlush(message); - } else { - this.responseReceivers.add(new CmdElement(sink, initialState, sql)); - sendingQueue.add(message); - } - } finally { - lock.unlock(); - } - } - }); - } - - protected void begin(FluxSink sink) { - if (this.responseReceivers.isEmpty()) { - if ((context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0) { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, "BEGIN")); - connection.channel().writeAndFlush(new QueryPacket("BEGIN")); - } else { - logger.debug("Skipping begin transaction because already in transaction"); - sink.complete(); - } - } else { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, "BEGIN")); - sendingQueue.add(new QueryPacket("BEGIN")); - } - } - - protected void executeAutoCommit(FluxSink sink, boolean autoCommit) { - String cmd = "SET autocommit=" + (autoCommit ? '1' : '0'); - if (this.responseReceivers.isEmpty()) { - if (autoCommit != isAutoCommit()) { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, cmd)); - connection.channel().writeAndFlush(new QueryPacket(cmd)); - } else { - logger.debug("Skipping autocommit since already in that state"); - sink.complete(); - } - } else { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, cmd)); - sendingQueue.add(new QueryPacket(cmd)); - } - } - - protected void executeWhenTransaction(FluxSink sink, String cmd) { - if (this.responseReceivers.isEmpty()) { - if ((context.getServerStatus() & ServerStatus.IN_TRANSACTION) > 0) { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, cmd)); - connection.channel().writeAndFlush(new QueryPacket(cmd)); - } else { - logger.debug(String.format("Skipping '%s' because no active transaction", cmd)); - sink.complete(); - } - } else { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, cmd)); - sendingQueue.add(new QueryPacket(cmd)); - } - } - - public void sendNext() { - lock.lock(); - try { - ClientMessage next = sendingQueue.poll(); - if (next != null) connection.channel().writeAndFlush(next); - } finally { - lock.unlock(); - } - } -} diff --git a/src/main/java/org/mariadb/r2dbc/client/ClientPipelineImpl.java b/src/main/java/org/mariadb/r2dbc/client/ClientPipelineImpl.java deleted file mode 100644 index 46713d0f..00000000 --- a/src/main/java/org/mariadb/r2dbc/client/ClientPipelineImpl.java +++ /dev/null @@ -1,134 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.client; - -import io.r2dbc.spi.R2dbcNonTransientResourceException; -import java.net.SocketAddress; -import java.util.concurrent.atomic.AtomicBoolean; -import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.message.client.ClientMessage; -import org.mariadb.r2dbc.message.client.ExecutePacket; -import org.mariadb.r2dbc.message.client.PreparePacket; -import org.mariadb.r2dbc.message.client.QueryPacket; -import org.mariadb.r2dbc.message.server.ServerMessage; -import org.mariadb.r2dbc.util.constants.ServerStatus; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; -import reactor.netty.Connection; -import reactor.netty.resources.ConnectionProvider; -import reactor.netty.tcp.TcpClient; -import reactor.util.Logger; -import reactor.util.Loggers; - -/** Client that send queries pipelining (without waiting for result). */ -public final class ClientPipelineImpl extends ClientBase { - private static final Logger logger = Loggers.getLogger(ClientPipelineImpl.class); - - public ClientPipelineImpl(Connection connection, MariadbConnectionConfiguration configuration) { - super(connection, configuration); - } - - public static Mono connect( - ConnectionProvider connectionProvider, - SocketAddress socketAddress, - MariadbConnectionConfiguration configuration) { - - TcpClient tcpClient = TcpClient.create(connectionProvider).remoteAddress(() -> socketAddress); - tcpClient = setSocketOption(configuration, tcpClient); - return tcpClient.connect().flatMap(it -> Mono.just(new ClientPipelineImpl(it, configuration))); - } - - public void sendCommandWithoutResult(ClientMessage message) { - try { - lock.lock(); - connection.channel().writeAndFlush(message); - } finally { - lock.unlock(); - } - } - - public Flux sendCommand(PreparePacket preparePacket, ExecutePacket executePacket) { - AtomicBoolean atomicBoolean = new AtomicBoolean(); - return Flux.create( - sink -> { - if (!isConnected()) { - sink.error( - new R2dbcNonTransientResourceException( - "Connection is close. Cannot send anything")); - return; - } - if (atomicBoolean.compareAndSet(false, true)) { - try { - lock.lock(); - this.responseReceivers.add( - new CmdElement( - sink, DecoderState.PREPARE_AND_EXECUTE_RESPONSE, preparePacket.getSql())); - connection.channel().writeAndFlush(preparePacket); - connection.channel().writeAndFlush(executePacket); - } finally { - lock.unlock(); - } - } - }); - } - - public Flux sendCommand( - ClientMessage message, DecoderState initialState, String sql) { - AtomicBoolean atomicBoolean = new AtomicBoolean(); - return Flux.create( - sink -> { - if (!isConnected()) { - sink.error( - new R2dbcNonTransientResourceException( - "Connection is close. Cannot send anything")); - return; - } - if (atomicBoolean.compareAndSet(false, true)) { - try { - lock.lock(); - this.responseReceivers.add(new CmdElement(sink, initialState, sql)); - connection.channel().writeAndFlush(message); - } finally { - lock.unlock(); - } - } - }); - } - - protected void begin(FluxSink sink) { - if (!responseReceivers.isEmpty() - || (context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0) { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, "BEGIN")); - connection.channel().writeAndFlush(new QueryPacket("BEGIN")); - } else { - logger.debug("Skipping begin transaction because already in transaction"); - sink.complete(); - } - } - - protected void executeAutoCommit(FluxSink sink, boolean autoCommit) { - String cmd = "SET autocommit=" + (autoCommit ? '1' : '0'); - if (this.responseReceivers.isEmpty() || autoCommit != isAutoCommit()) { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, cmd)); - connection.channel().writeAndFlush(new QueryPacket(cmd)); - } else { - logger.debug("Skipping autocommit since already in that state"); - sink.complete(); - } - } - - protected void executeWhenTransaction(FluxSink sink, String cmd) { - if (!responseReceivers.isEmpty() - || (context.getServerStatus() & ServerStatus.IN_TRANSACTION) > 0) { - this.responseReceivers.add(new CmdElement(sink, DecoderState.QUERY_RESPONSE, cmd)); - connection.channel().writeAndFlush(new QueryPacket(cmd)); - } else { - logger.debug(String.format("Skipping '%s' because no active transaction", cmd)); - sink.complete(); - } - } - - public void sendNext() {} -} diff --git a/src/main/java/org/mariadb/r2dbc/client/CmdElement.java b/src/main/java/org/mariadb/r2dbc/client/CmdElement.java deleted file mode 100644 index 75d3fd9b..00000000 --- a/src/main/java/org/mariadb/r2dbc/client/CmdElement.java +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.client; - -import org.mariadb.r2dbc.message.server.ServerMessage; -import reactor.core.publisher.FluxSink; - -public class CmdElement { - - private final FluxSink sink; - private final DecoderState initialState; - private final String sql; - - public CmdElement(FluxSink sink, DecoderState initialState) { - this.sink = sink; - this.initialState = initialState; - this.sql = null; - } - - public CmdElement(FluxSink sink, DecoderState initialState, String sql) { - this.sink = sink; - this.initialState = initialState; - this.sql = sql; - } - - public FluxSink getSink() { - return sink; - } - - public DecoderState getInitialState() { - return initialState; - } - - public String getSql() { - return sql; - } -} diff --git a/src/main/java/org/mariadb/r2dbc/client/Context.java b/src/main/java/org/mariadb/r2dbc/client/Context.java deleted file mode 100644 index 420fd6aa..00000000 --- a/src/main/java/org/mariadb/r2dbc/client/Context.java +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.client; - -public class Context { - - private final long threadId; - private final long serverCapabilities; - private short serverStatus; - private ServerVersion version; - - public Context( - String serverVersion, - long threadId, - long capabilities, - short serverStatus, - boolean mariaDBServer) { - - this.threadId = threadId; - this.serverCapabilities = capabilities; - this.serverStatus = serverStatus; - this.version = new ServerVersion(serverVersion, mariaDBServer); - } - - public long getThreadId() { - return threadId; - } - - public long getServerCapabilities() { - return serverCapabilities; - } - - public short getServerStatus() { - return serverStatus; - } - - public void setServerStatus(short serverStatus) { - this.serverStatus = serverStatus; - } - - public ServerVersion getVersion() { - return version; - } - - @Override - public String toString() { - return "ConnectionContext{" + "threadId=" + threadId + ", version=" + version + '}'; - } -} diff --git a/src/main/java/org/mariadb/r2dbc/client/DecoderState.java b/src/main/java/org/mariadb/r2dbc/client/DecoderState.java index 05ae26d8..d333f35d 100644 --- a/src/main/java/org/mariadb/r2dbc/client/DecoderState.java +++ b/src/main/java/org/mariadb/r2dbc/client/DecoderState.java @@ -1,16 +1,18 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.client; import io.netty.buffer.ByteBuf; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.server.*; import org.mariadb.r2dbc.util.ServerPrepareResult; import org.mariadb.r2dbc.util.constants.Capabilities; +import org.mariadb.r2dbc.util.constants.ServerStatus; public enum DecoderState implements DecoderStateInterface { INIT_HANDSHAKE { - public DecoderState decoder(short val, int len, long serverCapabilities) { + public DecoderState decoder(short val, int len) { switch (val) { case 255: // 0xFF return ERROR; @@ -20,35 +22,32 @@ public DecoderState decoder(short val, int len, long serverCapabilities) { } @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return InitialHandshakePacket.decode(sequencer, body); } }, OK_PACKET { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return OkPacket.decode(sequencer, body, decoder.getContext()); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { return QUERY_RESPONSE; } }, AUTHENTICATION_SWITCH { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return AuthSwitchPacket.decode(sequencer, body, decoder.getContext()); } }, AUTHENTICATION_SWITCH_RESPONSE { - public DecoderState decoder(short val, int len, long serverCapabilities) { + public DecoderState decoder(short val, int len) { switch (val) { case 1: return AUTHENTICATION_MORE_DATA; @@ -64,14 +63,13 @@ public DecoderState decoder(short val, int len, long serverCapabilities) { AUTHENTICATION_MORE_DATA { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return AuthMoreDataPacket.decode(sequencer, body, decoder.getContext()); } }, QUERY_RESPONSE { - public DecoderState decoder(short val, int len, long serverCapabilities) { + public DecoderState decoder(short val, int len) { switch (val) { case 0: return OK_PACKET; @@ -87,17 +85,16 @@ public DecoderState decoder(short val, int len, long serverCapabilities) { ColumnCountPacket columnCountPacket; @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { columnCountPacket = ColumnCountPacket.decode(sequencer, body, decoder.getContext()); decoder.setStateCounter(columnCountPacket.getColumnCount()); return columnCountPacket; } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { if (columnCountPacket.isMetaFollows()) return COLUMN_DEFINITION; - if ((decoder.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) { + if ((decoder.getClientCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) { return ROW_RESPONSE; } else { return EOF_INTERMEDIATE_RESPONSE; @@ -108,61 +105,74 @@ public DecoderState next(MariadbPacketDecoder decoder) { COLUMN_DEFINITION { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { decoder.decrementStateCounter(); - return ColumnDefinitionPacket.decode(sequencer, body, decoder.getContext(), false); + return ColumnDefinitionPacket.decode( + sequencer, body, decoder.getContext(), false, decoder.getConf()); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { if (decoder.getStateCounter() <= 0) { - if ((decoder.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) { - return ROW_RESPONSE; - } else { - return EOF_INTERMEDIATE_RESPONSE; - } + return EOF_INTERMEDIATE_RESPONSE; } return this; } }, EOF_INTERMEDIATE_RESPONSE { - @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { - return EofPacket.decode(sequencer, body, decoder.getContext(), false); + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { + EofPacket eof = EofPacket.decode(sequencer, body, decoder.getContext(), false); + decoder.setStateCounter((eof.getServerStatus() & ServerStatus.PS_OUT_PARAMETERS) > 0 ? 1 : 0); + return eof; } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { + // mysql has a broken protocol for output parameter, then driver need to know state + if (decoder.getStateCounter() > 0) { + decoder.setStateCounter(0); + return ROW_RESPONSE_OUT_PARAM; + } return ROW_RESPONSE; } }, EOF_END { - @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return EofPacket.decode(sequencer, body, decoder.getContext(), true); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { + return QUERY_RESPONSE; + } + }, + + EOF_END_OUT_PARAM { + + @Override + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { + // specific for mysql that break protocol, forgetting sometime to set PS_OUT_PARAMETERS and + // more importantly MORE_RESULTS_EXISTS + // breaking protocol + return EofPacket.decodeOutputParam(sequencer, body, decoder.getContext()); + } + + @Override + public DecoderState next(ServerMsgDecoder decoder) { return QUERY_RESPONSE; } }, ROW_RESPONSE { - public DecoderState decoder(short val, int len, long serverCapabilities) { + public DecoderState decoder(short val, int len) { switch (val) { case 254: - if ((serverCapabilities & Capabilities.CLIENT_DEPRECATE_EOF) == 0 && len < 0xffffff) { + if (len < 0xffffff) { return EOF_END; - } else if (len < 0xffffff) { - return OK_PACKET; } else { // normal ROW return ROW; @@ -175,22 +185,51 @@ public DecoderState decoder(short val, int len, long serverCapabilities) { } }, + ROW_RESPONSE_OUT_PARAM { + public DecoderState decoder(short val, int len) { + switch (val) { + case 254: + if (len < 0xffffff) { + return EOF_END_OUT_PARAM; + } else { + // normal ROW + return ROW_OUTPUT_PARAM; + } + case 255: // 0xFF + return ERROR; + default: + return ROW_OUTPUT_PARAM; + } + } + }, + ROW { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return new RowPacket(body); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { return ROW_RESPONSE; } }, + ROW_OUTPUT_PARAM { + @Override + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { + return new RowPacket(body); + } + + @Override + public DecoderState next(ServerMsgDecoder decoder) { + return ROW_RESPONSE_OUT_PARAM; + } + }, + PREPARE_RESPONSE { - public DecoderState decoder(short val, int len, long serverCapabilities) { + public DecoderState decoder(short val, int len) { switch (val) { case 255: // 0xFF return ERROR; @@ -200,12 +239,9 @@ public DecoderState decoder(short val, int len, long serverCapabilities) { } @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { - PrepareResultPacket packet = - PrepareResultPacket.decode(sequencer, body, decoder.getContext(), false); - decoder.setPrepare(packet); - if (packet.getNumParams() == 0 && packet.getNumColumns() == 0) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { + decoder.setPrepare(PrepareResultPacket.decode(sequencer, body, decoder.getContext(), false)); + if (decoder.getPrepare().getNumParams() == 0 && decoder.getPrepare().getNumColumns() == 0) { ServerPrepareResult serverPrepareResult = decoder.endPrepare(); return new CompletePrepareResult(serverPrepareResult, false); } @@ -213,25 +249,24 @@ public ServerMessage decode( } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { if (decoder.getPrepare().getNumParams() == 0) { - // if next, then columns > 0 + if (decoder.getPrepare().getNumColumns() == 0) { + decoder.setPrepare(null); + return QUERY_RESPONSE; + } decoder.setStateCounter(decoder.getPrepare().getNumColumns()); return PREPARE_COLUMN; } - // skip param and EOF if needed - decoder.setStateCounter( - decoder.getPrepare().getNumParams() - + (((decoder.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) - ? 0 - : 1)); + // skip param and EOF + decoder.setStateCounter(decoder.getPrepare().getNumParams()); return PREPARE_PARAMETER; } }, PREPARE_AND_EXECUTE_RESPONSE { - public DecoderState decoder(short val, int len, long serverCapabilities) { + public DecoderState decoder(short val, int len) { switch (val) { case 255: // 0xFF return ERROR_AND_EXECUTE_RESPONSE; @@ -241,8 +276,7 @@ public DecoderState decoder(short val, int len, long serverCapabilities) { } @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { decoder.setPrepare(PrepareResultPacket.decode(sequencer, body, decoder.getContext(), true)); if (decoder.getPrepare().getNumParams() == 0 && decoder.getPrepare().getNumColumns() == 0) { ServerPrepareResult serverPrepareResult = decoder.endPrepare(); @@ -252,7 +286,7 @@ public ServerMessage decode( } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { if (decoder.getPrepare().getNumParams() == 0) { if (decoder.getPrepare().getNumColumns() == 0) { decoder.setPrepare(null); @@ -261,12 +295,8 @@ public DecoderState next(MariadbPacketDecoder decoder) { decoder.setStateCounter(decoder.getPrepare().getNumColumns()); return PREPARE_COLUMN; } - // skip param and EOF if needed - decoder.setStateCounter( - decoder.getPrepare().getNumParams() - + (((decoder.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) - ? 0 - : 1)); + // skip param and EOF + decoder.setStateCounter(decoder.getPrepare().getNumParams()); return PREPARE_PARAMETER; } }, @@ -274,61 +304,61 @@ public DecoderState next(MariadbPacketDecoder decoder) { PREPARE_PARAMETER { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { decoder.decrementStateCounter(); - if (decoder.getStateCounter() == 0 && decoder.getPrepare().getNumColumns() == 0) { - // end parameter without columns - boolean ending = !decoder.getPrepare().isContinueOnEnd(); + return SkipPacket.decode(false); + } + + @Override + public DecoderState next(ServerMsgDecoder decoder) { + if (decoder.getStateCounter() == 0) { + return PREPARE_PARAMETER_EOF; + } + return this; + } + }, + + PREPARE_PARAMETER_EOF { + + @Override + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { + if (decoder.getPrepare().getNumColumns() == 0) { ServerPrepareResult serverPrepareResult = decoder.endPrepare(); - return new CompletePrepareResult(serverPrepareResult, ending); + return new CompletePrepareResult( + serverPrepareResult, decoder.getPrepare().isContinueOnEnd()); } return SkipPacket.decode(false); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { - if (decoder.getStateCounter() <= 0) { - if (decoder.getPrepare() == null) { - return QUERY_RESPONSE; - } + public DecoderState next(ServerMsgDecoder decoder) { + if (decoder.getPrepare().getNumColumns() > 0) { decoder.setStateCounter(decoder.getPrepare().getNumColumns()); return PREPARE_COLUMN; } - return PREPARE_PARAMETER; + decoder.setPrepare(null); + return QUERY_RESPONSE; } }, PREPARE_COLUMN { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { ColumnDefinitionPacket columnDefinitionPacket = - ColumnDefinitionPacket.decode(sequencer, body, decoder.getContext(), false); + ColumnDefinitionPacket.decode( + sequencer, body, decoder.getContext(), false, decoder.getConf()); decoder .getPrepareColumns()[ decoder.getPrepare().getNumColumns() - decoder.getStateCounter()] = columnDefinitionPacket; decoder.decrementStateCounter(); - - if (decoder.getStateCounter() <= 0) { - if ((decoder.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) { - boolean ending = !decoder.getPrepare().isContinueOnEnd(); - ServerPrepareResult prepareResult = decoder.endPrepare(); - return new CompletePrepareResult(prepareResult, ending); - } - } - return SkipPacket.decode(false); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { if (decoder.getStateCounter() <= 0) { - if ((decoder.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0) { - return QUERY_RESPONSE; - } return PREPARE_COLUMN_EOF; } return this; @@ -338,15 +368,14 @@ public DecoderState next(MariadbPacketDecoder decoder) { PREPARE_COLUMN_EOF { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { - boolean ending = !decoder.getPrepare().isContinueOnEnd(); + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { + boolean continueOnEnd = decoder.getPrepare().isContinueOnEnd(); ServerPrepareResult prepareResult = decoder.endPrepare(); - return new CompletePrepareResult(prepareResult, ending); + return new CompletePrepareResult(prepareResult, continueOnEnd); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { return QUERY_RESPONSE; } }, @@ -354,13 +383,12 @@ public DecoderState next(MariadbPacketDecoder decoder) { ERROR { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return ErrorPacket.decode(sequencer, body, true); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { throw new IllegalArgumentException("unexpected state"); } }, @@ -368,13 +396,12 @@ public DecoderState next(MariadbPacketDecoder decoder) { ERROR_AND_EXECUTE_RESPONSE { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { return ErrorPacket.decode(sequencer, body, false); } @Override - public DecoderState next(MariadbPacketDecoder decoder) { + public DecoderState next(ServerMsgDecoder decoder) { return SKIP_EXECUTE; } }, @@ -382,8 +409,7 @@ public DecoderState next(MariadbPacketDecoder decoder) { SKIP_EXECUTE { @Override - public ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + public ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { decoder.decrementStateCounter(); return SkipPacket.decode(true); } diff --git a/src/main/java/org/mariadb/r2dbc/client/DecoderStateInterface.java b/src/main/java/org/mariadb/r2dbc/client/DecoderStateInterface.java index c226e624..b1de4c0b 100644 --- a/src/main/java/org/mariadb/r2dbc/client/DecoderStateInterface.java +++ b/src/main/java/org/mariadb/r2dbc/client/DecoderStateInterface.java @@ -1,28 +1,23 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.client; import io.netty.buffer.ByteBuf; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.server.Sequencer; -import org.mariadb.r2dbc.message.server.ServerMessage; public interface DecoderStateInterface { - // default DecoderState decoder(short val, int len, long serverCapabilities) { - // throw new IllegalArgumentException("unexpected state"); - // } - - default DecoderState decoder(short val, int len, long serverCapabilities) { + default DecoderState decoder(short val, int len) { return (DecoderState) this; } - default ServerMessage decode( - ByteBuf body, Sequencer sequencer, MariadbPacketDecoder decoder, CmdElement element) { + default ServerMessage decode(ByteBuf body, Sequencer sequencer, ServerMsgDecoder decoder) { throw new IllegalArgumentException("unexpected state"); } - default DecoderState next(MariadbPacketDecoder decoder) { - throw new IllegalArgumentException("unexpected state"); + default DecoderState next(ServerMsgDecoder decoder) { + return null; } } diff --git a/src/main/java/org/mariadb/r2dbc/client/Exchange.java b/src/main/java/org/mariadb/r2dbc/client/Exchange.java new file mode 100644 index 00000000..f47920bb --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/Exchange.java @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.client; + +import java.util.concurrent.atomic.AtomicLong; +import org.mariadb.r2dbc.message.ServerMessage; +import reactor.core.publisher.FluxSink; + +public class Exchange { + + private final FluxSink sink; + private final DecoderState initialState; + private final String sql; + private final AtomicLong demand = new AtomicLong(); + + public Exchange(FluxSink sink, DecoderState initialState) { + this.sink = sink; + this.initialState = initialState; + this.sql = null; + } + + public Exchange(FluxSink sink, DecoderState initialState, String sql) { + this.sink = sink; + this.initialState = initialState; + this.sql = sql; + } + + public FluxSink getSink() { + return sink; + } + + public DecoderState getInitialState() { + return initialState; + } + + public String getSql() { + return sql; + } + + public boolean hasDemand() { + return demand.get() > 0; + } + + public void emit(ServerMessage srvMsg) { + demand.decrementAndGet(); + if (this.sink.isCancelled()) { + return; + } + this.sink.next(srvMsg); + if (srvMsg.ending()) { + this.sink.complete(); + } + } + + public void incrementDemand(long n) { + demand.addAndGet(n); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/FailoverClient.java b/src/main/java/org/mariadb/r2dbc/client/FailoverClient.java new file mode 100644 index 00000000..e4309879 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/FailoverClient.java @@ -0,0 +1,594 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.client; + +import io.r2dbc.spi.R2dbcNonTransientException; +import io.r2dbc.spi.R2dbcTransientResourceException; +import io.r2dbc.spi.TransactionDefinition; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Predicate; +import org.mariadb.r2dbc.ExceptionFactory; +import org.mariadb.r2dbc.HaMode; +import org.mariadb.r2dbc.MariadbConnectionConfiguration; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.client.*; +import org.mariadb.r2dbc.message.server.*; +import org.mariadb.r2dbc.util.HostAddress; +import org.mariadb.r2dbc.util.PrepareCache; +import org.mariadb.r2dbc.util.ServerPrepareResult; +import org.mariadb.r2dbc.util.constants.Capabilities; +import org.mariadb.r2dbc.util.constants.ServerStatus; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.Sinks; + +public class FailoverClient implements Client { + + private static final Predicate FAIL_PREDICATE = + R2dbcNonTransientException.class::isInstance; + + private final AtomicReference client = new AtomicReference<>(); + private final MariadbConnectionConfiguration conf; + private final ReentrantLock lock; + + private static final Mono reconnectIfNeeded( + MariadbConnectionConfiguration conf, ReentrantLock lock, AtomicReference client) { + if (client.get().isConnected()) return Mono.just(Boolean.TRUE); + return reconnectFallbackReplay(null, conf, lock, client, true, false, null) + .then(Mono.just(Boolean.TRUE)); + } + + private static final Mono reconnectFallback( + Throwable t, + MariadbConnectionConfiguration conf, + ReentrantLock lock, + AtomicReference client) { + HaMode.failHost(client.get().getHostAddress()); + return conf.getHaMode() + .connectHost(conf, lock, false) + .flatMap( + c -> + syncNewState(client.get(), c, conf) + .flatMap( + v -> { + client.set(c); + return Mono.error( + new R2dbcTransientResourceException( + String.format( + "Driver has reconnect connection after a communications link failure with %s. In progress transaction was lost", + client.get().getHostAddress()), + "25S03")); + })); + } + + private static final Mono reconnectFallbackReplay( + Throwable throwable, + MariadbConnectionConfiguration conf, + ReentrantLock lock, + AtomicReference client, + boolean canSafelyBeReExecuted, + boolean firstMsgReceived, + ClientMessage request) { + HaMode.failHost(client.get().getHostAddress()); + return conf.getHaMode() + .connectHost(conf, lock, false) + .onErrorMap( + t -> + new R2dbcTransientResourceException( + String.format( + "Communications link failure with %s, failing to recreate new connection", + client.get().getHostAddress()), + "25S03", + t)) + .flatMap( + c -> { + Client oldcli = client.get(); + client.set(c); + return syncNewState(oldcli, c, conf) + .then( + replayIfPossible( + throwable, + oldcli, + c, + conf, + canSafelyBeReExecuted, + firstMsgReceived, + request)) + .thenReturn(c); + }); + } + + public FailoverClient(MariadbConnectionConfiguration conf, ReentrantLock lock, Client client) { + this.client.set(client); + this.conf = conf; + this.lock = lock; + } + + private static Mono syncNewState( + Client oldCli, Client currentClient, MariadbConnectionConfiguration conf) { + Context oldCtx = oldCli.getContext(); + + // sync database + Mono monoDatabase; + if ((oldCtx.getClientCapabilities() | Capabilities.CLIENT_SESSION_TRACK) > 0 + && oldCtx.getDatabase() != null + && oldCtx.getDatabase().equals(conf.getDatabase())) { + monoDatabase = Mono.empty(); + } else { + ExceptionFactory exceptionFactory = ExceptionFactory.withSql("COM_INIT_DB"); + monoDatabase = + currentClient + .sendCommand(new ChangeSchemaPacket(oldCtx.getDatabase()), true) + .handle(exceptionFactory::handleErrorResponse) + .then(); + } + + // sync transaction isolation + Mono monoIsolationLevel; + if (currentClient.getContext().getIsolationLevel() == oldCtx.getIsolationLevel()) { + monoIsolationLevel = Mono.empty(); + } else { + String sql = + String.format( + "SET SESSION TRANSACTION ISOLATION LEVEL %s", oldCtx.getIsolationLevel().asSql()); + ExceptionFactory exceptionFactory = ExceptionFactory.withSql(sql); + monoIsolationLevel = + currentClient + .sendCommand(new QueryPacket(sql), true) + .handle(exceptionFactory::handleErrorResponse) + .then(); + } + + // sync autoCommit + return currentClient + .setAutoCommit(oldCli.isAutoCommit()) + .then(monoDatabase) + .then(monoIsolationLevel) + .then(); + } + + private static Mono replayIfPossible( + Throwable throwable, + Client oldClient, + Client client, + MariadbConnectionConfiguration conf, + boolean canRedo, + boolean firstMsgReceived, + ClientMessage request) { + if ((oldClient.getContext().getServerStatus() & ServerStatus.IN_TRANSACTION) > 0) { + if (conf.isTransactionReplay()) { + if (firstMsgReceived) { + return Mono.error( + new R2dbcTransientResourceException( + String.format( + "Driver has reconnect connection after a communications link failure with %s during command.", + oldClient.getHostAddress()), + "25S03", + throwable)); + } + + return executeTransactionReplay(oldClient, client, request); + } else { + // transaction is lost, but connection is now up again. + // changing exception to SQLTransientConnectionException + return Mono.error( + new R2dbcTransientResourceException( + String.format( + "Driver has reconnect connection after a communications link failure with %s. In progress transaction was lost", + oldClient.getHostAddress()), + "25S03", + throwable)); + } + } + return canRedo + ? Mono.empty() + : Mono.error( + new R2dbcTransientResourceException( + String.format( + "Driver has reconnect connection after a communications link failure with %s", + oldClient.getHostAddress()), + "25S03", + throwable)); + } + + private static Mono executeTransactionReplay( + Client oldCli, Client client, ClientMessage request) { + // transaction replay + RedoContext ctx = (RedoContext) oldCli.getContext(); + if (ctx.getTransactionSaver().isDirty()) { + ctx.getTransactionSaver().clear(); + return Mono.error( + new R2dbcTransientResourceException( + String.format( + "Driver has reconnect connection after a communications link failure with %s. In progress transaction was too big to be replayed, and was lost", + oldCli.getHostAddress()), + "25S03")); + } + TransactionSaver transactionSaver = ctx.getTransactionSaver(); + + Queue endedCmdQueue = transactionSaver.getMessages(); + if (endedCmdQueue.isEmpty()) return Mono.empty(); + transactionSaver.forceDirty(); + Sinks.Many cmdSink = Sinks.many().unicast().onBackpressureBuffer(); + AtomicBoolean canceled = new AtomicBoolean(); + return cmdSink + .asFlux() + .map( + it -> { + it.resetSequencer(); + if (it instanceof PreparePacket) { + return client + .sendCommand(it, DecoderState.PREPARE_RESPONSE, false) + .doOnComplete(() -> tryNextCommand(endedCmdQueue, cmdSink, canceled, request)); + } else if (it instanceof ExecutePacket) { + // command is a prepare statement query + // redo on new connection need to re-prepare query + // and substitute statement id + Mono req = ((ExecutePacket) it).rePrepare(client); + return req.flatMapMany( + req2 -> + client + .sendCommand(req2, false) + .doOnComplete( + () -> tryNextCommand(endedCmdQueue, cmdSink, canceled, request))); + } else { + return client + .sendCommand(it, false) + .doOnComplete(() -> tryNextCommand(endedCmdQueue, cmdSink, canceled, request)); + } + }) + .flatMap(mariadbResultFlux -> mariadbResultFlux) + .doOnCancel(() -> canceled.set(true)) + .doOnDiscard(RowPacket.class, RowPacket::release) + .doOnError(e -> canceled.set(true)) + .doOnSubscribe(it -> tryNextCommand(endedCmdQueue, cmdSink, canceled, request)) + .onErrorMap( + e -> new R2dbcTransientResourceException("Socket error during transaction replay", e)) + .doOnComplete( + () -> { + ctx.getTransactionSaver().clear(); + ctx.getTransactionSaver().forceDirty(); + }) + .then(); + } + + private static void tryNextCommand( + Queue endedCmdQueue, + Sinks.Many cmdSink, + AtomicBoolean canceled, + ClientMessage request) { + + if (canceled.get()) { + return; + } + + try { + ClientMessage endedCmd = endedCmdQueue.poll(); + if (endedCmd != null && (request == null || !request.equals(endedCmd))) { + cmdSink.emitNext(endedCmd, Sinks.EmitFailureHandler.FAIL_FAST); + } else { + cmdSink.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST); + } + } catch (Exception e) { + cmdSink.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST); + } + } + + @Override + public Mono close() { + return client.get().close(); + } + + @Override + public void sendCommandWithoutResult(ClientMessage requests) { + client.get().sendCommandWithoutResult(requests); + } + + @Override + public Flux sendCommand(ClientMessage requests, boolean canSafelyBeReExecuted) { + return sendCommand(requests, DecoderState.QUERY_RESPONSE, null, canSafelyBeReExecuted); + } + + @Override + public Flux sendCommand( + ClientMessage requests, DecoderState initialState, boolean canSafelyBeReExecuted) { + return sendCommand(requests, initialState, null, canSafelyBeReExecuted); + } + + @Override + public Flux sendCommand( + ClientMessage requests, + DecoderState initialState, + String sql, + boolean canSafelyBeReExecuted) { + AtomicBoolean firstMsgReceived = new AtomicBoolean(false); + return reconnectIfNeeded(conf, lock, client) + .flatMapMany( + reconnected -> { + Mono clientMsg; + if (reconnected && requests instanceof ExecutePacket) { + // in case reconnection occurs during an ExecutePacket, need to re-prepare + clientMsg = ((ExecutePacket) requests).rePrepare(client.get()); + } else { + clientMsg = Mono.just(requests); + } + return clientMsg.flatMapMany( + req -> + client + .get() + .sendCommand(req, initialState, sql, canSafelyBeReExecuted) + .switchOnFirst( + (signal, serverMessageFlux) -> { + // Redo can only be done if subscriber has not began to receive data + // for UPSERT command would be ok in a transaction, + // but resulting operation would be wrong, having already handle + // some + // serverMessage. + // so all commands that fails before completion, and after receiving + // first + // message + // mustn't be replayed + if (signal.getType() == SignalType.ON_NEXT) + firstMsgReceived.set(true); + return serverMessageFlux; + }) + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay( + t, + conf, + lock, + client, + canSafelyBeReExecuted, + firstMsgReceived.get(), + req) + .map( + c -> { + req.resetSequencer(); + Mono clientMsg2; + if (reconnected && req instanceof ExecutePacket) { + // in case reconnection occurs during an + // ExecutePacket, need to re-prepare + clientMsg2 = + ((ExecutePacket) req).rePrepare(client.get()); + } else { + clientMsg2 = Mono.just(req); + } + return clientMsg.flatMapMany( + req2 -> + c.sendCommand( + req2, + initialState, + sql, + canSafelyBeReExecuted)); + }) + .flatMapMany(flux -> flux))); + }); + } + + public Mono sendPrepare( + ClientMessage requests, ExceptionFactory factory, String sql) { + return this.sendCommand(requests, DecoderState.PREPARE_RESPONSE, sql, true) + .handle( + (it, sink) -> { + if (it instanceof ErrorPacket) { + sink.error(factory.from((ErrorPacket) it)); + return; + } + if (it instanceof CompletePrepareResult) { + sink.next(((CompletePrepareResult) it).getPrepare()); + } + if (it.ending()) { + sink.complete(); + } + }) + .cast(ServerPrepareResult.class) + .singleOrEmpty(); + } + + @Override + public Flux sendCommand( + PreparePacket preparePacket, ExecutePacket executePacket, boolean canSafelyBeReExecuted) { + AtomicBoolean firstMsgReceived = new AtomicBoolean(false); + return reconnectIfNeeded(conf, lock, client) + .flatMapMany( + cc -> + client + .get() + .sendCommand(preparePacket, executePacket, canSafelyBeReExecuted) + .switchOnFirst( + (signal, serverMessageFlux) -> { + // Redo can only be done if subscriber has not begun to receive data + // for UPSERT command would be ok in a transaction, + // but resulting operation would be wrong, having already handle some + // serverMessage. + // so all commands that fails before completion, and after receiving first + // message + // mustn't be replayed + if (signal.getType() == SignalType.ON_NEXT) firstMsgReceived.set(true); + return serverMessageFlux; + }) + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay( + t, + conf, + lock, + client, + canSafelyBeReExecuted, + firstMsgReceived.get(), + executePacket) + .map( + c -> { + preparePacket.resetSequencer(); + executePacket.resetSequencer(); + return c.sendCommand( + preparePacket, executePacket, canSafelyBeReExecuted); + }) + .flatMapMany(flux -> flux))); + } + + @Override + public Mono sendSslRequest( + SslRequestPacket sslRequest, MariadbConnectionConfiguration configuration) { + return client.get().sendSslRequest(sslRequest, configuration); + } + + @Override + public boolean isAutoCommit() { + return client.get().isAutoCommit(); + } + + @Override + public boolean isInTransaction() { + return client.get().isInTransaction(); + } + + @Override + public boolean noBackslashEscapes() { + return client.get().noBackslashEscapes(); + } + + @Override + public ServerVersion getVersion() { + return client.get().getVersion(); + } + + @Override + public boolean isConnected() { + return client.get().isConnected(); + } + + @Override + public boolean isCloseRequested() { + return client.get().isCloseRequested(); + } + + @Override + public void setContext(InitialHandshakePacket packet, long clientCapabilities) { + client.get().setContext(packet, clientCapabilities); + } + + @Override + public Context getContext() { + return client.get().getContext(); + } + + @Override + public PrepareCache getPrepareCache() { + return client.get().getPrepareCache(); + } + + @Override + public Mono beginTransaction() { + return reconnectIfNeeded(conf, lock, client) + .flatMap( + cc -> + client + .get() + .beginTransaction() + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay(t, conf, lock, client, true, false, null) + .map(c -> c.beginTransaction()) + .flatMap(flux -> flux))); + } + + @Override + public Mono beginTransaction(TransactionDefinition definition) { + return reconnectIfNeeded(conf, lock, client) + .flatMap( + cc -> + client + .get() + .beginTransaction(definition) + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay(t, conf, lock, client, true, true, null) + .map(c -> c.beginTransaction(definition)) + .flatMap(flux -> flux))); + } + + @Override + public Mono commitTransaction() { + // just reconnect + return client + .get() + .commitTransaction() + .doOnError(FAIL_PREDICATE, t -> reconnectFallback(t, conf, lock, client)); + } + + @Override + public Mono rollbackTransaction() { + return reconnectIfNeeded(conf, lock, client) + .flatMap( + cc -> + client + .get() + .rollbackTransaction() + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay(t, conf, lock, client, true, true, null) + .map(c -> c.rollbackTransaction()) + .flatMap(flux -> flux))); + } + + @Override + public Mono setAutoCommit(boolean autoCommit) { + // setting autocommit to true will commit existing transaction, so if failing we cannot knows if + // was really committed + if (autoCommit) { + return client + .get() + .setAutoCommit(true) + .doOnError(FAIL_PREDICATE, t -> reconnectFallback(t, conf, lock, client)); + } + return reconnectIfNeeded(conf, lock, client) + .flatMap( + cc -> + client + .get() + .setAutoCommit(false) + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay(t, conf, lock, client, true, true, null) + .map(c -> c.setAutoCommit(false)) + .flatMap(flux -> flux))); + } + + @Override + public Mono rollbackTransactionToSavepoint(String name) { + return client + .get() + .rollbackTransactionToSavepoint(name) + .onErrorResume( + FAIL_PREDICATE, + t -> + reconnectFallbackReplay(t, conf, lock, client, true, true, null) + .map(c -> c.rollbackTransactionToSavepoint(name)) + .flatMap(flux -> flux)); + } + + @Override + public long getThreadId() { + return client.get().getThreadId(); + } + + @Override + public HostAddress getHostAddress() { + return client.get().getHostAddress(); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/MariadbFrameDecoder.java b/src/main/java/org/mariadb/r2dbc/client/MariadbFrameDecoder.java new file mode 100644 index 00000000..924634c7 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/MariadbFrameDecoder.java @@ -0,0 +1,67 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package org.mariadb.r2dbc.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import java.util.List; + +public class MariadbFrameDecoder extends ByteToMessageDecoder { + private CompositeByteBuf multipart = null; + + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf buf, List out) throws Exception { + while (buf.readableBytes() > 4) { + int length = buf.getUnsignedMediumLE(buf.readerIndex()); + + // packet not complete + if (buf.readableBytes() < length + 4) return; + + // extract packet + if (length == 0xffffff) { + // multipart packet + if (multipart == null) { + multipart = buf.alloc().compositeBuffer(); + } + buf.skipBytes(4); // skip length + header + multipart.addComponent(true, buf.readRetainedSlice(length)); + continue; + } + + // wait for complete packet + if (multipart != null) { + // last part of multipart packet + buf.skipBytes(3); // skip length + + // add sequence byte + multipart.addComponent(true, 0, Unpooled.wrappedBuffer(new byte[] {buf.readByte()})); + // add data + multipart.addComponent(true, buf.readRetainedSlice(length)); + out.add(multipart); + multipart = null; + continue; + } + + // create Object from packet + ByteBuf packet = buf.readRetainedSlice(4 + length); + packet.skipBytes(3); // skip length + out.add(packet); + } + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/MariadbPacketDecoder.java b/src/main/java/org/mariadb/r2dbc/client/MariadbPacketDecoder.java deleted file mode 100644 index 4941428f..00000000 --- a/src/main/java/org/mariadb/r2dbc/client/MariadbPacketDecoder.java +++ /dev/null @@ -1,185 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.client; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.ByteToMessageDecoder; -import io.r2dbc.spi.R2dbcNonTransientResourceException; -import java.util.List; -import java.util.Queue; -import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; -import org.mariadb.r2dbc.message.server.PrepareResultPacket; -import org.mariadb.r2dbc.message.server.Sequencer; -import org.mariadb.r2dbc.message.server.ServerMessage; -import org.mariadb.r2dbc.util.BufferUtils; -import org.mariadb.r2dbc.util.PrepareCache; -import org.mariadb.r2dbc.util.ServerPrepareResult; - -public class MariadbPacketDecoder extends ByteToMessageDecoder { - - private final Queue responseReceivers; - private final Client client; - - private Context context = null; - private boolean isMultipart = false; - private DecoderState state = DecoderState.INIT_HANDSHAKE; - private CmdElement cmdElement; - private CompositeByteBuf multipart; - private long serverCapabilities; - private int stateCounter = 0; - private PrepareResultPacket prepare; - private ColumnDefinitionPacket[] prepareColumns; - - public MariadbPacketDecoder(Queue responseReceivers, Client client) { - this.responseReceivers = responseReceivers; - this.client = client; - } - - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf buf, List out) throws Exception { - while (buf.readableBytes() > 4) { - int length = buf.getUnsignedMediumLE(buf.readerIndex()); - - // packet not complete - if (buf.readableBytes() < length + 4) return; - - // extract packet - if (length == 0xffffff) { - // multipart packet - if (!isMultipart) { - isMultipart = true; - multipart = buf.alloc().compositeBuffer(); - } - buf.skipBytes(4); // skip length + header - multipart.addComponent(true, buf.readRetainedSlice(length)); - continue; - } - - // wait for complete packet - if (isMultipart) { - // last part of multipart packet - buf.skipBytes(3); // skip length - Sequencer sequencer = new Sequencer(buf.readByte()); - multipart.addComponent(true, buf.readRetainedSlice(length)); - - handleBuffer(multipart, sequencer); - - multipart.release(); - isMultipart = false; - continue; - } - - // create Object from packet - ByteBuf packet = buf.readRetainedSlice(4 + length); - packet.skipBytes(3); // skip length - Sequencer sequencer = new Sequencer(packet.readByte()); - handleBuffer(packet, sequencer); - packet.release(); - } - } - - private void handleBuffer(ByteBuf packet, Sequencer sequencer) { - if (cmdElement == null && !loadNextResponse()) { - throw new R2dbcNonTransientResourceException( - String.format( - "unexpected message received when no command was send: 0x%s", - BufferUtils.toString(packet))); - } - - state = - state.decoder( - packet.getUnsignedByte(packet.readerIndex()), - packet.readableBytes(), - serverCapabilities); - ServerMessage msg = state.decode(packet, sequencer, this, cmdElement); - cmdElement.getSink().next(msg); - if (msg.ending()) { - if (cmdElement != null) { - // complete executed only after setting next element. - CmdElement element = cmdElement; - loadNextResponse(); - element.getSink().complete(); - } - client.sendNext(); - } else { - state = state.next(this); - } - } - - public void connectionError(Throwable err) { - if (cmdElement != null) { - cmdElement.getSink().error(err); - cmdElement = null; - state = null; - } - } - - public Context getContext() { - return context; - } - - public int getStateCounter() { - return stateCounter; - } - - public void setStateCounter(int counter) { - stateCounter = counter; - } - - public PrepareResultPacket getPrepare() { - return prepare; - } - - public void setPrepare(PrepareResultPacket prepare) { - this.prepare = prepare; - this.prepareColumns = new ColumnDefinitionPacket[prepare.getNumColumns()]; - } - - public ColumnDefinitionPacket[] getPrepareColumns() { - return prepareColumns; - } - - public ServerPrepareResult endPrepare() { - // this.prepareColumns = new ColumnDefinitionPacket[prepare.getNumColumns()]; - ServerPrepareResult prepareResult = - new ServerPrepareResult( - this.prepare.getStatementId(), this.prepare.getNumParams(), prepareColumns); - PrepareCache prepareCache = client.getPrepareCache(); - if (prepareCache != null) { - ServerPrepareResult cached = prepareCache.put(cmdElement.getSql(), prepareResult); - if (cached != null) { - // race condition, remove new one to get the one in cache - prepareResult.decrementUse(client); - prepareResult = cached; - } - } - this.prepare = null; - return prepareResult; - } - - public void decrementStateCounter() { - stateCounter--; - } - - public long getServerCapabilities() { - return serverCapabilities; - } - - private boolean loadNextResponse() { - this.cmdElement = responseReceivers.poll(); - if (cmdElement != null) { - state = cmdElement.getInitialState(); - return true; - } - state = null; - return false; - } - - public void setContext(Context context) { - this.context = context; - this.serverCapabilities = this.context.getServerCapabilities(); - } -} diff --git a/src/main/java/org/mariadb/r2dbc/client/MariadbPacketEncoder.java b/src/main/java/org/mariadb/r2dbc/client/MariadbPacketEncoder.java index c3f9b1f1..4d867814 100644 --- a/src/main/java/org/mariadb/r2dbc/client/MariadbPacketEncoder.java +++ b/src/main/java/org/mariadb/r2dbc/client/MariadbPacketEncoder.java @@ -1,53 +1,49 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.client; import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToByteEncoder; -import org.mariadb.r2dbc.message.client.ClientMessage; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; -public class MariadbPacketEncoder extends MessageToByteEncoder { +public class MariadbPacketEncoder { private Context context = null; - @Override - protected void encode(ChannelHandlerContext ctx, ClientMessage msg, ByteBuf out) - throws Exception { - - ByteBuf buf = null; - try { - buf = msg.encode(this.context, ctx.alloc()); - - // single mysql packet - if (buf.writerIndex() - buf.readerIndex() < 0xffffff) { - out.writeMediumLE(buf.writerIndex() - buf.readerIndex()); - out.writeByte(msg.getSequencer().next()); - out.writeBytes(buf); - // buf.release(); - return; - } - - // multiple mysql packet - split in 16mb packet - int readerIndex = buf.readerIndex(); - int packetLength = -1; - while (readerIndex < buf.writerIndex()) { - packetLength = Math.min(0xffffff, buf.writerIndex() - readerIndex); - out.writeMediumLE(packetLength); - out.writeByte(msg.getSequencer().next()); - out.writeBytes(buf.slice(readerIndex, packetLength)); - readerIndex += packetLength; - } - - if (packetLength == 0xffffff) { - // in case last packet is full, sending an empty packet to indicate that command is complete - out.writeMediumLE(packetLength); - out.writeByte(msg.getSequencer().next()); - } - - } finally { - if (buf != null) buf.release(); + public CompositeByteBuf encodeFlux(ClientMessage msg) { + ByteBufAllocator allocator = context.getByteBufAllocator(); + CompositeByteBuf out = allocator.compositeBuffer(); + + ByteBuf buf = msg.encode(context, allocator); + int initialReaderIndex = buf.readerIndex(); + int packetLength; + do { + packetLength = Math.min(0xffffff, buf.readableBytes()); + + ByteBuf header = Unpooled.buffer(4, 4); + header.writeMediumLE(packetLength); + header.writeByte(msg.getSequencer().next()); + + out.addComponent(true, header); + out.addComponent(true, buf.retainedSlice(buf.readerIndex(), packetLength)); + buf.skipBytes(packetLength); + } while (buf.readableBytes() > 0); + + if (packetLength == 0xffffff) { + // in case last packet is full, sending an empty packet to indicate that command is complete + ByteBuf header = Unpooled.buffer(4, 4); + header.writeMediumLE(0); + header.writeByte(msg.getSequencer().next()); + out.addComponent(true, header); } + + context.saveRedo(msg, buf, initialReaderIndex); + msg.releaseEncodedBinds(); + buf.release(); + return out; } public void setContext(Context context) { diff --git a/src/main/java/org/mariadb/r2dbc/client/MariadbResult.java b/src/main/java/org/mariadb/r2dbc/client/MariadbResult.java new file mode 100644 index 00000000..33aaf80c --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/MariadbResult.java @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.r2dbc.spi.R2dbcException; +import io.r2dbc.spi.Readable; +import io.r2dbc.spi.Result; +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Predicate; +import org.mariadb.r2dbc.*; +import org.mariadb.r2dbc.codec.BinaryRowDecoder; +import org.mariadb.r2dbc.codec.RowDecoder; +import org.mariadb.r2dbc.codec.TextRowDecoder; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.server.*; +import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.ServerPrepareResult; +import org.mariadb.r2dbc.util.constants.ServerStatus; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; + +public final class MariadbResult implements org.mariadb.r2dbc.api.MariadbResult { + + private final Flux dataRows; + private final ExceptionFactory factory; + + private final String[] generatedColumns; + private final boolean supportReturning; + private final boolean text; + private final MariadbConnectionConfiguration conf; + private final AtomicReference prepareResult; + private Predicate filter; + + private volatile MariadbDataSegment segment; + + public MariadbResult( + boolean text, + AtomicReference prepareResult, + Flux dataRows, + ExceptionFactory factory, + String[] generatedColumns, + boolean supportReturning, + MariadbConnectionConfiguration conf) { + this.text = text; + this.dataRows = dataRows; + this.factory = factory; + this.generatedColumns = generatedColumns; + this.supportReturning = supportReturning; + this.conf = conf; + this.prepareResult = prepareResult; + this.filter = null; + } + + @Override + public Flux getRowsUpdated() { + // Since CLIENT_DEPRECATE_EOF is not set in order to identify output parameter + // number of updated row can be identified either by OK_Packet or number of rows in case of + // RETURNING + final AtomicInteger rowCount = new AtomicInteger(0); + return this.dataRows + .takeUntil(ServerMessage::resultSetEnd) + .handle( + (serverMessage, sink) -> { + if (serverMessage instanceof ErrorPacket) { + sink.error(this.factory.from((ErrorPacket) serverMessage)); + return; + } + + if (serverMessage instanceof OkPacket) { + OkPacket okPacket = ((OkPacket) serverMessage); + sink.next((int) okPacket.value()); + sink.complete(); + return; + } + + if (serverMessage instanceof EofPacket) { + EofPacket eofPacket = ((EofPacket) serverMessage); + if (eofPacket.resultSetEnd()) { + sink.next(rowCount.get()); + rowCount.set(0); + sink.complete(); + } + return; + } + + if (serverMessage instanceof RowPacket) { + rowCount.incrementAndGet(); + ((RowPacket) serverMessage).release(); + return; + } + }); + } + + public Flux map(BiFunction mappingFunction) { + Assert.requireNonNull(mappingFunction, "mappingFunction must not be null"); + Flux flux = + this.dataRows + .takeUntil(ServerMessage::resultSetEnd) + .handle(this.handler(true)) + .filter(MariadbRowSegment.class::isInstance); + if (filter != null) flux = flux.filter(filter); + + return flux.cast(MariadbRowSegment.class) + .map( + it -> { + try { + return mappingFunction.apply(it.row(), it.getMetadata()); + } catch (IllegalArgumentException i) { + throw this.factory.createException(i.getMessage(), "HY000", -1); + } + }); + } + + @Override + public Flux map(Function mappingFunction) { + Assert.requireNonNull(mappingFunction, "mappingFunction must not be null"); + Flux flux = + this.dataRows.takeUntil(ServerMessage::resultSetEnd).handle(this.handler(true)); + if (filter != null) flux = flux.filter(filter); + + return flux.cast(MariadbRowSegment.class).map(it -> mappingFunction.apply(it.row())); + } + + @Override + public Result filter(Predicate filter) { + this.filter = filter; + return this; + } + + @Override + public Flux flatMap(Function> mappingFunction) { + Assert.requireNonNull(mappingFunction, "mappingFunction must not be null"); + Flux flux = + this.dataRows.takeUntil(ServerMessage::resultSetEnd).handle(this.handler(true)); + if (filter != null) flux = flux.filter(filter); + return flux.flatMap(it -> mappingFunction.apply(it)); + } + + private BiConsumer> handler(boolean throwError) { + final List columns = new ArrayList<>(); + final AtomicBoolean metaFollows = new AtomicBoolean(true); + return (serverMessage, sink) -> { + if (serverMessage instanceof ErrorPacket) { + if (throwError) { + sink.error(this.factory.from((ErrorPacket) serverMessage)); + } else { + sink.next((ErrorPacket) serverMessage); + sink.complete(); + } + return; + } + + if (serverMessage instanceof CompletePrepareResult) { + this.prepareResult.set(((CompletePrepareResult) serverMessage).getPrepare()); + return; + } + + if (serverMessage instanceof ColumnCountPacket) { + metaFollows.set(((ColumnCountPacket) serverMessage).isMetaFollows()); + if (!metaFollows.get()) { + columns.addAll(Arrays.asList(this.prepareResult.get().getColumns())); + } + return; + } + + if (serverMessage instanceof ColumnDefinitionPacket) { + columns.add((ColumnDefinitionPacket) serverMessage); + return; + } + + if (serverMessage instanceof OkPacket) { + OkPacket okPacket = ((OkPacket) serverMessage); + // This is for server that doesn't permit RETURNING: rely on OK_packet LastInsertId + // to retrieve the last generated ID. + if (generatedColumns != null && !supportReturning) { + String colName = generatedColumns.length > 0 ? generatedColumns[0] : "ID"; + List tmpCol = + Collections.singletonList(ColumnDefinitionPacket.fromGeneratedId(colName, conf)); + if (okPacket.value() > 1) { + sink.error( + this.factory.createException( + "Connector cannot get generated ID (using returnGeneratedValues) multiple rows before MariaDB 10.5.1", + "HY000", + -1)); + return; + } + + ByteBuf buf = getLongTextEncoded(okPacket.getLastInsertId()); + segment = + new MariadbRowSegment(new TextRowDecoder(tmpCol, this.conf, this.factory), tmpCol); + segment.updateRaw(buf); + sink.next(segment); + } else sink.next(okPacket); + return; + } + + if (serverMessage instanceof EofPacket) { + RowDecoder decoder = + text + ? new TextRowDecoder(columns, this.conf, this.factory) + : new BinaryRowDecoder(columns, this.conf, this.factory); + boolean outputParameter = + (((EofPacket) serverMessage).getServerStatus() & ServerStatus.PS_OUT_PARAMETERS) > 0; + // in case metadata follows and prepared statement, update meta + if (prepareResult != null && prepareResult.get() != null && metaFollows.get()) { + prepareResult.get().setColumns(columns.toArray(new ColumnDefinitionPacket[0])); + } + segment = + outputParameter + ? new MariadbOutSegment(decoder, columns) + : new MariadbRowSegment(decoder, columns); + return; + } + + if (serverMessage instanceof RowPacket) { + RowPacket row = ((RowPacket) serverMessage); + try { + segment.updateRaw(row.getRaw()); + sink.next(segment); + } catch (IllegalArgumentException i) { + sink.error(this.factory.createException(i.getMessage(), "HY000", -1)); + } catch (R2dbcException i) { + sink.error(i); + } finally { + row.release(); + } + return; + } + }; + } + + private ByteBuf getLongTextEncoded(long value) { + byte[] byteValue = Long.toString(value).getBytes(StandardCharsets.US_ASCII); + byte[] encodedLength; + int length = byteValue.length; + encodedLength = new byte[] {(byte) length}; + return Unpooled.copiedBuffer(encodedLength, byteValue); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/RedoContext.java b/src/main/java/org/mariadb/r2dbc/client/RedoContext.java new file mode 100644 index 00000000..1481265a --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/RedoContext.java @@ -0,0 +1,64 @@ +package org.mariadb.r2dbc.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.r2dbc.spi.IsolationLevel; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.util.constants.ServerStatus; + +public class RedoContext extends SimpleContext { + + private final TransactionSaver transactionSaver; + + public RedoContext( + String serverVersion, + long threadId, + long capabilities, + short serverStatus, + boolean mariaDBServer, + long clientCapabilities, + String database, + ByteBufAllocator byteBufAllocator, + IsolationLevel isolationLevel) { + super( + serverVersion, + threadId, + capabilities, + serverStatus, + mariaDBServer, + clientCapabilities, + database, + byteBufAllocator, + isolationLevel); + transactionSaver = new TransactionSaver(); + } + + /** + * Set server status + * + * @param serverStatus server status + */ + public void setServerStatus(short serverStatus) { + super.setServerStatus(serverStatus); + if ((serverStatus & ServerStatus.IN_TRANSACTION) == 0) transactionSaver.clear(); + } + + /** + * Save client message + * + * @param msg client message + */ + public void saveRedo(ClientMessage msg, ByteBuf buf, int initialReaderIndex) { + msg.save(buf, initialReaderIndex); + transactionSaver.add(msg); + } + + /** + * Get transaction saver cache + * + * @return transaction saver cache + */ + public TransactionSaver getTransactionSaver() { + return transactionSaver; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/ServerMsgDecoder.java b/src/main/java/org/mariadb/r2dbc/client/ServerMsgDecoder.java new file mode 100644 index 00000000..52012bdd --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/ServerMsgDecoder.java @@ -0,0 +1,104 @@ +package org.mariadb.r2dbc.client; + +import io.netty.buffer.ByteBuf; +import java.util.Queue; +import org.mariadb.r2dbc.MariadbConnectionConfiguration; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.message.server.PrepareResultPacket; +import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.util.PrepareCache; +import org.mariadb.r2dbc.util.ServerPrepareResult; +import reactor.util.concurrent.Queues; + +public class ServerMsgDecoder { + private final Client client; + private final MariadbConnectionConfiguration configuration; + private DecoderState state = null; + private final Queue prepareSql = Queues.small().get(); + private long clientCapabilities; + private int stateCounter = 0; + private PrepareResultPacket prepare; + private ColumnDefinitionPacket[] prepareColumns; + private Context context = null; + + public ServerMsgDecoder(Client client, MariadbConnectionConfiguration configuration) { + this.client = client; + this.configuration = configuration; + } + + public ServerMessage decode(ByteBuf packet, Exchange exchange) { + Sequencer sequencer = new Sequencer(packet.readByte()); + if (state == null) + state = exchange == null ? DecoderState.QUERY_RESPONSE : exchange.getInitialState(); + state = state.decoder(packet.getUnsignedByte(packet.readerIndex()), packet.readableBytes()); + ServerMessage msg = state.decode(packet, sequencer, this); + state = msg.ending() ? null : state.next(this); + return msg; + } + + public Context getContext() { + return context; + } + + public int getStateCounter() { + return stateCounter; + } + + public void setStateCounter(int counter) { + stateCounter = counter; + } + + public PrepareResultPacket getPrepare() { + return prepare; + } + + public void setPrepare(PrepareResultPacket prepare) { + this.prepare = prepare; + this.prepareColumns = + (prepare == null) ? null : new ColumnDefinitionPacket[prepare.getNumColumns()]; + } + + public ColumnDefinitionPacket[] getPrepareColumns() { + return prepareColumns; + } + + public MariadbConnectionConfiguration getConf() { + return configuration; + } + + public ServerPrepareResult endPrepare() { + ServerPrepareResult prepareResult = + new ServerPrepareResult( + this.prepare.getStatementId(), this.prepare.getNumParams(), prepareColumns); + String sql = prepareSql.poll(); + PrepareCache prepareCache = client.getPrepareCache(); + if (prepareCache != null) { + ServerPrepareResult cached = prepareCache.put(sql, prepareResult); + if (cached != null) { + // race condition, remove new one to get the one in cache + prepareResult.decrementUse(client); + prepareResult = cached; + } + } + return prepareResult; + } + + public void decrementStateCounter() { + stateCounter--; + } + + public long getClientCapabilities() { + return clientCapabilities; + } + + public boolean addPrepare(String sql) { + return this.prepareSql.offer(sql); + } + + public void setContext(Context context) { + this.context = context; + this.clientCapabilities = this.context.getClientCapabilities(); + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/ServerVersion.java b/src/main/java/org/mariadb/r2dbc/client/ServerVersion.java index 10372ff6..c046a02f 100644 --- a/src/main/java/org/mariadb/r2dbc/client/ServerVersion.java +++ b/src/main/java/org/mariadb/r2dbc/client/ServerVersion.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.client; diff --git a/src/main/java/org/mariadb/r2dbc/client/SimpleClient.java b/src/main/java/org/mariadb/r2dbc/client/SimpleClient.java new file mode 100644 index 00000000..8fd3b8df --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/SimpleClient.java @@ -0,0 +1,760 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelOption; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.r2dbc.spi.*; +import java.net.SocketAddress; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import java.util.function.Function; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import org.mariadb.r2dbc.*; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.message.client.*; +import org.mariadb.r2dbc.message.server.CompletePrepareResult; +import org.mariadb.r2dbc.message.server.ErrorPacket; +import org.mariadb.r2dbc.message.server.InitialHandshakePacket; +import org.mariadb.r2dbc.util.HostAddress; +import org.mariadb.r2dbc.util.PrepareCache; +import org.mariadb.r2dbc.util.ServerPrepareResult; +import org.mariadb.r2dbc.util.constants.ServerStatus; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.*; +import reactor.netty.Connection; +import reactor.netty.resources.ConnectionProvider; +import reactor.netty.tcp.TcpClient; +import reactor.util.Logger; +import reactor.util.Loggers; +import reactor.util.concurrent.Queues; + +public class SimpleClient implements Client { + + private static final Logger logger = Loggers.getLogger(SimpleClient.class); + protected final MariadbConnectionConfiguration configuration; + private final ServerMessageSubscriber messageSubscriber; + private final Sinks.Many> requestSink = + Sinks.many().unicast().onBackpressureBuffer(); + private final Queue exchangeQueue = + Queues.get(Queues.SMALL_BUFFER_SIZE).get(); + private final Queue receiverQueue = + Queues.get(Queues.SMALL_BUFFER_SIZE).get(); + + private final AtomicBoolean isClosed = new AtomicBoolean(false); + private final ServerMsgDecoder decoder; + private final MariadbPacketEncoder encoder; + private final PrepareCache prepareCache; + private final ByteBufAllocator byteBufAllocator; + + private volatile boolean closeRequested = false; + + protected final ReentrantLock lock; + protected final Connection connection; + protected final HostAddress hostAddress; + protected volatile Context context; + + @Override + public Context getContext() { + return context; + } + + protected SimpleClient( + Connection connection, + MariadbConnectionConfiguration configuration, + HostAddress hostAddress, + ReentrantLock lock) { + this.connection = connection; + this.configuration = configuration; + this.hostAddress = hostAddress; + this.lock = lock; + this.prepareCache = + new PrepareCache( + this.configuration.useServerPrepStmts() ? this.configuration.getPrepareCacheSize() : 0, + this); + this.decoder = new ServerMsgDecoder(this, configuration); + this.encoder = new MariadbPacketEncoder(); + this.byteBufAllocator = connection.outbound().alloc(); + this.messageSubscriber = + new ServerMessageSubscriber(this.lock, this.isClosed, exchangeQueue, receiverQueue); + connection.addHandler(new MariadbFrameDecoder()); + + if (logger.isTraceEnabled()) { + connection.addHandlerFirst( + LoggingHandler.class.getSimpleName(), + new LoggingHandler(SimpleClient.class, LogLevel.TRACE)); + } + + connection + .inbound() + .receive() + .doOnError(this::handleConnectionError) + .doOnComplete(this::handleConnectionEnd) + .subscribe(messageSubscriber); + + Mono request = + this.requestSink + .asFlux() + .concatMap(Function.identity()) + .cast(ClientMessage.class) + .map(encoder::encodeFlux) + .flatMap(b -> connection.outbound().send(Mono.just(b)), 1) + .then(); + + request.doAfterTerminate(this::handleConnectionEnd).subscribe(); + } + + public static Mono connect( + ConnectionProvider connectionProvider, + SocketAddress socketAddress, + HostAddress hostAddress, + MariadbConnectionConfiguration configuration, + ReentrantLock lock) { + TcpClient tcpClient = + TcpClient.create(connectionProvider) + .remoteAddress(() -> socketAddress) + .runOn(configuration.loopResources()); + tcpClient = setSocketOption(configuration, tcpClient); + return tcpClient + .connect() + .flatMap(it -> Mono.just(new SimpleClient(it, configuration, hostAddress, lock))); + } + + public static TcpClient setSocketOption( + MariadbConnectionConfiguration configuration, TcpClient tcpClient) { + if (configuration.getConnectTimeout() != null) { + tcpClient = + tcpClient.option( + ChannelOption.CONNECT_TIMEOUT_MILLIS, + Math.toIntExact(configuration.getConnectTimeout().toMillis())); + } + + if (configuration.isTcpKeepAlive()) { + tcpClient = tcpClient.option(ChannelOption.SO_KEEPALIVE, configuration.isTcpKeepAlive()); + } + + if (configuration.isTcpAbortiveClose()) { + tcpClient = tcpClient.option(ChannelOption.SO_LINGER, 0); + } + return tcpClient; + } + + private void sendClientMsgs(Publisher it) { + this.requestSink.emitNext(it, Sinks.EmitFailureHandler.FAIL_FAST); + } + + private void handleConnectionError(Throwable throwable) { + if (closeChannelIfNeeded()) { + logger.error("Connection unexpected error", throwable); + messageSubscriber.endExchanges( + new R2dbcNonTransientResourceException( + "Connection unexpected error", "08000", throwable)); + } else { + logger.error("Connection error", throwable); + messageSubscriber.endExchanges( + new R2dbcNonTransientResourceException("Connection error", "08000", throwable)); + } + } + + private void handleConnectionEnd() { + messageSubscriber.endExchanges( + new R2dbcNonTransientResourceException( + "Connection " + (closeChannelIfNeeded() ? "unexpectedly " : "") + "closed", "08000")); + } + + private boolean closeChannelIfNeeded() { + if (this.isClosed.compareAndSet(false, true)) { + Channel channel = this.connection.channel(); + if (channel.isOpen()) { + this.connection.dispose(); + } + return true; + } + return false; + } + + @Override + public Mono close() { + closeRequested = true; + return Mono.defer( + () -> { + if (this.isClosed.compareAndSet(false, true)) { + + Channel channel = this.connection.channel(); + if (!channel.isOpen()) { + this.connection.dispose(); + return this.connection.onDispose(); + } + + return Flux.just(QuitPacket.INSTANCE) + .doOnNext(message -> connection.channel().writeAndFlush(message)) + .then() + .doOnSuccess(v -> this.connection.dispose()) + .then(this.connection.onDispose()); + } + + return Mono.empty(); + }); + } + + @Override + public Mono sendSslRequest( + SslRequestPacket sslRequest, MariadbConnectionConfiguration configuration) { + CompletableFuture result = new CompletableFuture<>(); + try { + SSLEngine engine = + configuration.getSslConfig().getSslContext().newEngine(connection.channel().alloc()); + final SslHandler sslHandler = new SslHandler(engine); + + final GenericFutureListener> listener = + configuration + .getSslConfig() + .getHostNameVerifier( + result, + this.hostAddress == null ? null : this.hostAddress.getHost(), + context.getThreadId(), + engine); + + sslHandler.handshakeFuture().addListener(listener); + // send SSL request in clear + this.sendClientMsgs(Mono.just(sslRequest)); + + // add SSL handler + connection.addHandlerFirst(sslHandler); + return Mono.fromFuture(result); + + } catch (SSLException | R2dbcTransientResourceException e) { + result.completeExceptionally(e); + return Mono.fromFuture(result); + } + } + + private Flux execute(Consumer> s) { + return Flux.create( + sink -> { + if (!isConnected()) { + sink.error( + new R2dbcNonTransientResourceException( + "Connection is close. Cannot send anything")); + return; + } + try { + lock.lock(); + s.accept(sink); + } finally { + lock.unlock(); + } + }); + } + + public long getThreadId() { + return context.getThreadId(); + } + + /** + * Specific implementation, to avoid executing BEGIN if already in transaction + * + * @return publisher + */ + public Mono beginTransaction() { + try { + lock.lock(); + + return execute( + sink -> { + if (!exchangeQueue.isEmpty() + || (context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0) { + Exchange exchange = new Exchange(sink, DecoderState.QUERY_RESPONSE, "BEGIN"); + if (this.exchangeQueue.offer(exchange)) { + sendClientMsgs(Mono.just(new QueryPacket("BEGIN"))); + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + } else { + sink.error(new R2dbcTransientResourceException("Request queue limit reached")); + } + } else { + logger.debug("Skipping start transaction because already in transaction"); + sink.complete(); + } + }) + .handle(ExceptionFactory.withSql("BEGIN")::handleErrorResponse) + .then(); + + } finally { + lock.unlock(); + } + } + + /** + * Specific implementation, to avoid executing START TRANSACTION if already in transaction + * + * @param definition transaction definition + * @return publisher + */ + public Mono beginTransaction(TransactionDefinition definition) { + StringBuilder sb = new StringBuilder("START TRANSACTION"); + boolean first = true; + if (Boolean.TRUE.equals(definition.getAttribute(TransactionDefinition.READ_ONLY))) { + sb.append(" READ ONLY"); + first = false; + } + if (Boolean.TRUE.equals( + definition.getAttribute(MariadbTransactionDefinition.WITH_CONSISTENT_SNAPSHOT))) { + if (!first) sb.append(","); + sb.append(" WITH CONSISTENT SNAPSHOT"); + } + + String sql = sb.toString(); + try { + lock.lock(); + return execute( + sink -> { + if (!exchangeQueue.isEmpty() + || (context.getServerStatus() & ServerStatus.IN_TRANSACTION) == 0) { + Exchange exchange = new Exchange(sink, DecoderState.QUERY_RESPONSE, sql); + if (this.exchangeQueue.offer(exchange)) { + sendClientMsgs(Mono.just(new QueryPacket(sql))); + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + } else { + sink.error(new R2dbcTransientResourceException("Request queue limit reached")); + } + } else { + logger.debug("Skipping start transaction because already in transaction"); + sink.complete(); + } + }) + .handle(ExceptionFactory.withSql(sql)::handleErrorResponse) + .then(); + } finally { + lock.unlock(); + } + } + + /** + * Specific implementation, to avoid executing COMMIT if no transaction + * + * @return publisher + */ + public Mono commitTransaction() { + try { + lock.lock(); + return execute(sink -> executeWhenTransaction(sink, "COMMIT")) + .handle(ExceptionFactory.withSql("COMMIT")::handleErrorResponse) + .then(); + } finally { + lock.unlock(); + } + } + + private void executeWhenTransaction(FluxSink sink, String sql) { + if (!exchangeQueue.isEmpty() || (context.getServerStatus() & ServerStatus.IN_TRANSACTION) > 0) { + try { + lock.lock(); + Exchange exchange = new Exchange(sink, DecoderState.QUERY_RESPONSE, sql); + if (this.exchangeQueue.offer(exchange)) { + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + sendClientMsgs(Mono.just(new QueryPacket(sql))); + } else { + sink.error(new R2dbcTransientResourceException("Request queue limit reached")); + } + } catch (Throwable t) { + t.printStackTrace(); + throw t; + } finally { + lock.unlock(); + } + } else { + logger.debug(String.format("Skipping '%s' because no active transaction", sql)); + sink.complete(); + } + } + + /** + * Specific implementation, to avoid executing ROLLBACK if no transaction + * + * @return publisher + */ + public Mono rollbackTransaction() { + try { + lock.lock(); + return execute(sink -> executeWhenTransaction(sink, "ROLLBACK")) + .handle(ExceptionFactory.withSql("ROLLBACK")::handleErrorResponse) + .then(); + } finally { + lock.unlock(); + } + } + + /** + * Specific implementation, to avoid executing ROLLBACK TO TRANSACTION if no transaction + * + * @return publisher + */ + public Mono rollbackTransactionToSavepoint(String name) { + try { + lock.lock(); + String sql = String.format("ROLLBACK TO SAVEPOINT `%s`", name.replace("`", "``")); + return execute(sink -> executeWhenTransaction(sink, sql)) + .handle(ExceptionFactory.withSql(sql)::handleErrorResponse) + .then(); + } finally { + lock.unlock(); + } + } + + /** + * Specific implementation, to avoid changing autocommit mode if already in this autocommit mode + * + * @return publisher + */ + public Mono setAutoCommit(boolean autoCommit) { + try { + lock.lock(); + return execute( + sink -> { + String sql = "SET autocommit=" + (autoCommit ? '1' : '0'); + if (!this.exchangeQueue.isEmpty() || autoCommit != isAutoCommit()) { + + try { + Exchange exchange = new Exchange(sink, DecoderState.QUERY_RESPONSE, sql); + if (this.exchangeQueue.offer(exchange)) { + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + sendClientMsgs(Mono.just(new QueryPacket(sql))); + } else { + sink.error( + new R2dbcTransientResourceException("Request queue limit reached")); + } + } catch (Throwable t) { + t.printStackTrace(); + throw t; + } + + } else { + logger.debug("Skipping autocommit since already in that state"); + sink.complete(); + } + }) + .handle(ExceptionFactory.withSql(null)::handleErrorResponse) + .then(); + } finally { + lock.unlock(); + } + } + + public Flux receive(DecoderState initialState) { + return Flux.create( + sink -> { + try { + lock.lock(); + Exchange exchange = new Exchange(sink, initialState); + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + if (!this.exchangeQueue.offer(exchange)) { + sink.error( + new R2dbcTransientResourceException( + "Request queue limit reached during handshake")); + } + } catch (Throwable t) { + t.printStackTrace(); + throw t; + } finally { + lock.unlock(); + } + }); + } + + public void setContext(InitialHandshakePacket handshake, long clientCapabilities) { + this.context = + !HaMode.NONE.equals(configuration.getHaMode()) && configuration.isTransactionReplay() + ? new RedoContext( + handshake.getServerVersion(), + handshake.getThreadId(), + handshake.getCapabilities(), + handshake.getServerStatus(), + handshake.isMariaDBServer(), + clientCapabilities, + configuration.getDatabase(), + byteBufAllocator, + configuration.getIsolationLevel()) + : new SimpleContext( + handshake.getServerVersion(), + handshake.getThreadId(), + handshake.getCapabilities(), + handshake.getServerStatus(), + handshake.isMariaDBServer(), + clientCapabilities, + configuration.getDatabase(), + byteBufAllocator, + configuration.getIsolationLevel()); + decoder.setContext(context); + encoder.setContext(context); + } + + /** + * Get current server autocommit. + * + * @return autocommit current server value. + */ + @Override + public boolean isAutoCommit() { + return (this.context.getServerStatus() & ServerStatus.AUTOCOMMIT) > 0; + } + + @Override + public boolean isInTransaction() { + return (this.context.getServerStatus() & ServerStatus.IN_TRANSACTION) > 0; + } + + @Override + public boolean noBackslashEscapes() { + return (this.context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0; + } + + @Override + public ServerVersion getVersion() { + return (this.context != null) ? this.context.getVersion() : ServerVersion.UNKNOWN_VERSION; + } + + @Override + public boolean isConnected() { + if (this.isClosed.get()) { + return false; + } + return this.connection.channel().isOpen(); + } + + @Override + public boolean isCloseRequested() { + return this.closeRequested; + } + + protected class ServerMessageSubscriber implements CoreSubscriber { + private Subscription upstream; + private AtomicBoolean close; + private final AtomicLong receiverDemands = new AtomicLong(0); + private final ReentrantLock lock; + private final Queue exchangeQueue; + private final Queue receiverQueue; + + public ServerMessageSubscriber( + ReentrantLock lock, + AtomicBoolean close, + Queue exchangeQueue, + Queue receiverQueue) { + this.lock = lock; + this.close = close; + this.receiverQueue = receiverQueue; + this.exchangeQueue = exchangeQueue; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.upstream = subscription; + } + + public void onError(Throwable t) {} + + public void onComplete() {} + + @Override + public void onNext(ByteBuf message) { + if (this.close.get()) { + message.release(); + Operators.onNextDropped(message, currentContext()); + return; + } + + this.receiverDemands.decrementAndGet(); + Exchange exchange = this.exchangeQueue.peek(); + ServerMessage srvMsg = decoder.decode(message, exchange); + + if (this.receiverQueue.isEmpty() && exchange != null && exchange.hasDemand()) { + // nothing buffered => directly emit message + if (srvMsg.ending()) this.exchangeQueue.poll(); + exchange.emit(srvMsg); + return; + } + + // queue message + if (!this.receiverQueue.offer(srvMsg)) { + message.release(); + Operators.onNextDropped(message, currentContext()); + onError( + new R2dbcNonTransientResourceException("unexpected : server message queue is full")); + return; + } + + tryDrainQueue(); + } + + public void onRequest(Exchange exchange, long n) { + exchange.incrementDemand(n); + requestQueueFilling(); + tryDrainQueue(); + } + + private void requestQueueFilling() { + if (this.receiverQueue.isEmpty() + && this.receiverDemands.compareAndSet(0, Queues.SMALL_BUFFER_SIZE)) { + this.upstream.request(Queues.SMALL_BUFFER_SIZE); + } + } + + private void tryDrainQueue() { + Exchange exchange; + ServerMessage srvMsg; + while (!this.receiverQueue.isEmpty()) { + if (!lock.tryLock()) return; + try { + while (!this.receiverQueue.isEmpty()) { + if ((exchange = this.exchangeQueue.peek()) == null || !exchange.hasDemand()) return; + if ((srvMsg = this.receiverQueue.poll()) == null) return; + if (srvMsg.ending()) this.exchangeQueue.poll(); + exchange.emit(srvMsg); + } + } finally { + lock.unlock(); + } + + if ((exchange = this.exchangeQueue.peek()) == null || exchange.hasDemand()) { + requestQueueFilling(); + } + } + } + + public void endExchanges(Throwable exception) { + Exchange exchange; + while ((exchange = this.exchangeQueue.poll()) != null) { + exchange.getSink().error(exception); + } + } + } + + public void sendCommandWithoutResult(ClientMessage message) { + try { + lock.lock(); + sendClientMsgs(Mono.just(message)); + } finally { + lock.unlock(); + } + } + + public Flux sendCommand(ClientMessage message, boolean canSafelyBeReExecuted) { + return sendCommand(message, DecoderState.QUERY_RESPONSE, null, canSafelyBeReExecuted); + } + + public Flux sendCommand( + ClientMessage message, DecoderState initialState, boolean canSafelyBeReExecuted) { + return sendCommand(message, initialState, null, canSafelyBeReExecuted); + } + + public Flux sendCommand( + ClientMessage message, DecoderState initialState, String sql, boolean canSafelyBeReExecuted) { + return Flux.create( + sink -> { + if (!isConnected()) { + sink.error( + new R2dbcNonTransientResourceException( + "Connection is close. Cannot send anything")); + return; + } + try { + lock.lock(); + Exchange exchange = new Exchange(sink, initialState, sql); + if (this.exchangeQueue.offer(exchange)) { + if (message instanceof PreparePacket) { + decoder.addPrepare(((PreparePacket) message).getSql()); + } + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + sendClientMsgs(Mono.just(message)); + } else { + sink.error(new R2dbcTransientResourceException("Request queue limit reached")); + } + } catch (Throwable t) { + sink.error(t); + } finally { + lock.unlock(); + } + }); + } + + public Mono sendPrepare( + ClientMessage requests, ExceptionFactory factory, String sql) { + return sendCommand(requests, DecoderState.PREPARE_RESPONSE, sql, true) + .handle( + (it, sink) -> { + if (it instanceof ErrorPacket) { + sink.error(factory.from((ErrorPacket) it)); + return; + } + if (it instanceof CompletePrepareResult) { + sink.next(((CompletePrepareResult) it).getPrepare()); + } + if (it.ending()) { + sink.complete(); + } + }) + .cast(ServerPrepareResult.class) + .singleOrEmpty(); + } + + public Flux sendCommand( + PreparePacket preparePacket, ExecutePacket executePacket, boolean canSafelyBeReExecuted) { + return Flux.create( + sink -> { + if (!isConnected()) { + sink.error( + new R2dbcNonTransientResourceException( + "Connection is close. Cannot send anything")); + return; + } + try { + lock.lock(); + Exchange exchange = + new Exchange( + sink, DecoderState.PREPARE_AND_EXECUTE_RESPONSE, preparePacket.getSql()); + if (this.exchangeQueue.offer(exchange)) { + sink.onRequest(value -> messageSubscriber.onRequest(exchange, value)); + decoder.addPrepare(preparePacket.getSql()); + sendClientMsgs(Flux.just(preparePacket, executePacket)); + } else { + sink.error(new R2dbcTransientResourceException("Request queue limit reached")); + return; + } + } catch (Throwable t) { + t.printStackTrace(); + sink.error(t); + } finally { + lock.unlock(); + } + }); + } + + public HostAddress getHostAddress() { + return hostAddress; + } + + public PrepareCache getPrepareCache() { + return prepareCache; + } + + @Override + public String toString() { + return "Client{isClosed=" + isClosed + ", context=" + context + '}'; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/SimpleContext.java b/src/main/java/org/mariadb/r2dbc/client/SimpleContext.java new file mode 100644 index 00000000..c092a789 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/SimpleContext.java @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.client; + +import io.netty.buffer.ByteBufAllocator; +import io.r2dbc.spi.IsolationLevel; +import org.mariadb.r2dbc.message.Context; + +public class SimpleContext implements Context { + + private final long threadId; + private final long serverCapabilities; + private final long clientCapabilities; + private short serverStatus; + private final ServerVersion version; + private final ByteBufAllocator byteBufAllocator; + private IsolationLevel isolationLevel; + private String database; + + public SimpleContext( + String serverVersion, + long threadId, + long capabilities, + short serverStatus, + boolean mariaDBServer, + long clientCapabilities, + String database, + ByteBufAllocator byteBufAllocator, + IsolationLevel isolationLevel) { + + this.threadId = threadId; + this.serverCapabilities = capabilities; + this.clientCapabilities = clientCapabilities; + this.serverStatus = serverStatus; + this.version = new ServerVersion(serverVersion, mariaDBServer); + this.isolationLevel = isolationLevel == null ? IsolationLevel.REPEATABLE_READ : isolationLevel; + this.database = database; + this.byteBufAllocator = byteBufAllocator; + } + + public long getThreadId() { + return threadId; + } + + public long getServerCapabilities() { + return serverCapabilities; + } + + public long getClientCapabilities() { + return clientCapabilities; + } + + public short getServerStatus() { + return serverStatus; + } + + public void setServerStatus(short serverStatus) { + this.serverStatus = serverStatus; + } + + public String getDatabase() { + return database; + } + + public void setDatabase(String database) { + this.database = database; + } + + public IsolationLevel getIsolationLevel() { + return isolationLevel; + } + + public void setIsolationLevel(IsolationLevel isolationLevel) { + this.isolationLevel = isolationLevel; + } + + public ServerVersion getVersion() { + return version; + } + + public ByteBufAllocator getByteBufAllocator() { + return byteBufAllocator; + } + + @Override + public String toString() { + return "ConnectionContext{" + "threadId=" + threadId + ", version=" + version + '}'; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/client/TransactionSaver.java b/src/main/java/org/mariadb/r2dbc/client/TransactionSaver.java new file mode 100644 index 00000000..d90c63d0 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/client/TransactionSaver.java @@ -0,0 +1,54 @@ +package org.mariadb.r2dbc.client; + +import java.util.Queue; +import org.mariadb.r2dbc.message.ClientMessage; +import reactor.util.concurrent.Queues; + +/** + * Transaction cache Huge command are not cached, cache is limited to configuration + * transactionReplaySize commands + */ +public class TransactionSaver { + private final Queue messages = + Queues.get(Queues.SMALL_BUFFER_SIZE).get(); + private boolean dirty = false; + + /** + * Add a command to cache. + * + * @param clientMessage client message + */ + public void add(ClientMessage clientMessage) { + if (!messages.offer(clientMessage)) { + dirty = true; + } + } + + /** Transaction finished, clearing cache */ + public void clear() { + messages.clear(); + dirty = false; + } + + /** + * Is cache not valid (some commands have not been cached) + * + * @return is dirty + */ + public boolean isDirty() { + return dirty; + } + + public void forceDirty() { + dirty = true; + } + + /** + * cache buffer + * + * @return cached messages + */ + public Queue getMessages() { + return messages; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/codec/BinaryRowDecoder.java b/src/main/java/org/mariadb/r2dbc/codec/BinaryRowDecoder.java index 7e8be833..5de6d01f 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/BinaryRowDecoder.java +++ b/src/main/java/org/mariadb/r2dbc/codec/BinaryRowDecoder.java @@ -1,23 +1,28 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec; import io.netty.buffer.ByteBuf; +import java.util.List; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; public class BinaryRowDecoder extends RowDecoder { - private int columnNumber; - private ColumnDefinitionPacket[] columns; - private byte[] nullBitmap; + private final int columnNumber; + private final List columns; + private final byte[] nullBitmap; public BinaryRowDecoder( - int columnNumber, ColumnDefinitionPacket[] columns, MariadbConnectionConfiguration conf) { - super(conf); + List columns, + MariadbConnectionConfiguration conf, + ExceptionFactory factory) { + super(conf, factory); this.columns = columns; - this.columnNumber = columnNumber; + this.columnNumber = columns.size(); + nullBitmap = new byte[(columnNumber + 9) / 8]; } @SuppressWarnings("unchecked") @@ -26,7 +31,7 @@ public T get(int index, ColumnDefinitionPacket column, Class type) // check NULL-Bitmap that indicate if field is null if ((nullBitmap[(index + 2) / 8] & (1 << ((index + 2) % 8))) != 0) { - if (type.isPrimitive()) { + if (type != null && type.isPrimitive()) { throw new IllegalArgumentException( String.format("Cannot return null for primitive %s", type.getName())); } @@ -36,14 +41,14 @@ public T get(int index, ColumnDefinitionPacket column, Class type) setPosition(index); // type generic, return "natural" java type - if (Object.class == type || type == null) { - Codec defaultCodec = ((Codec) column.getDefaultCodec(conf)); - return defaultCodec.decodeBinary(buf, length, column, type); + if (Object.class == type) { + Codec defaultCodec = ((Codec) column.getType().getDefaultCodec()); + return defaultCodec.decodeBinary(buf, length, column, type, factory); } for (Codec codec : Codecs.LIST) { if (codec.canDecode(column, type)) { - return ((Codec) codec).decodeBinary(buf, length, column, type); + return ((Codec) codec).decodeBinary(buf, length, column, type, factory); } } @@ -55,7 +60,6 @@ public T get(int index, ColumnDefinitionPacket column, Class type) @Override public void resetRow(ByteBuf buf) { buf.skipBytes(1); // skip 0x00 header - nullBitmap = new byte[(columnNumber + 9) / 8]; buf.readBytes(nullBitmap); super.resetRow(buf); } @@ -77,7 +81,7 @@ public void setPosition(int newIndex) { for (; index < newIndex; index++) { if ((nullBitmap[(index + 2) / 8] & (1 << ((index + 2) % 8))) == 0) { // skip bytes - switch (columns[index].getType()) { + switch (columns.get(index).getDataType()) { case BIGINT: case DOUBLE: buf.skipBytes(8); @@ -126,7 +130,7 @@ public void setPosition(int newIndex) { } // read asked field position and length - switch (columns[index].getType()) { + switch (columns.get(index).getDataType()) { case BIGINT: case DOUBLE: length = 8; diff --git a/src/main/java/org/mariadb/r2dbc/codec/Codec.java b/src/main/java/org/mariadb/r2dbc/codec/Codec.java index 5ff3bc4a..9dfe59f3 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/Codec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/Codec.java @@ -1,11 +1,16 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec; import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; +import io.netty.buffer.ByteBufAllocator; +import java.util.function.Supplier; +import org.mariadb.r2dbc.ExceptionFactory; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; +import reactor.core.publisher.Mono; public interface Codec { @@ -13,14 +18,32 @@ public interface Codec { boolean canEncode(Class value); - T decodeText(ByteBuf buffer, int length, ColumnDefinitionPacket column, Class type); - - void encodeText(ByteBuf buf, Context context, T value); + T decodeText( + ByteBuf buffer, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory); T decodeBinary( - ByteBuf buffer, int length, ColumnDefinitionPacket column, Class type); + ByteBuf buffer, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory); + + BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory); - void encodeBinary(ByteBuf buf, Context context, T value); + BindValue encodeBinary(ByteBufAllocator allocator, Object value, ExceptionFactory factory); DataType getBinaryEncodeType(); + + default BindValue createEncodedValue(Supplier bufferSupplier) { + return new BindValue(this, Mono.fromSupplier(bufferSupplier)); + } + + default BindValue createEncodedValue(Mono value) { + return new BindValue(this, value); + } } diff --git a/src/main/java/org/mariadb/r2dbc/codec/Codecs.java b/src/main/java/org/mariadb/r2dbc/codec/Codecs.java index 08fa5a2d..d078701b 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/Codecs.java +++ b/src/main/java/org/mariadb/r2dbc/codec/Codecs.java @@ -1,19 +1,38 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec; +import io.r2dbc.spi.*; +import java.io.InputStream; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.BitSet; import java.util.HashMap; import java.util.Map; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.list.*; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.Protocol; +import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.BindValue; public final class Codecs { + private static final Map> r2dbcTypeMapper = r2dbcTypeToDataTypeMap(); + private static final Map, Codec> codecMapper = classToCodecMap(); + public static final Codec[] LIST = new Codec[] { BigDecimalCodec.INSTANCE, BigIntegerCodec.INSTANCE, BooleanCodec.INSTANCE, + ByteBufferCodec.INSTANCE, BlobCodec.INSTANCE, ByteArrayCodec.INSTANCE, ByteCodec.INSTANCE, @@ -32,35 +51,116 @@ public final class Codecs { StringCodec.INSTANCE }; - // association with enum, since doesn't supporting generics in enum :( - public static final Map> CODEC_LIST = new HashMap<>(); - - static { - CODEC_LIST.put(DataType.OLDDECIMAL, BigDecimalCodec.INSTANCE); - CODEC_LIST.put(DataType.TINYINT, ByteCodec.INSTANCE); - CODEC_LIST.put(DataType.SMALLINT, ShortCodec.INSTANCE); - CODEC_LIST.put(DataType.INTEGER, IntCodec.INSTANCE); - CODEC_LIST.put(DataType.DOUBLE, DoubleCodec.INSTANCE); - CODEC_LIST.put(DataType.NULL, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.TIMESTAMP, LocalDateTimeCodec.INSTANCE); - CODEC_LIST.put(DataType.BIGINT, LongCodec.INSTANCE); - CODEC_LIST.put(DataType.DATE, LocalDateTimeCodec.INSTANCE); - CODEC_LIST.put(DataType.TIME, DurationCodec.INSTANCE); - CODEC_LIST.put(DataType.DATETIME, LocalDateTimeCodec.INSTANCE); - CODEC_LIST.put(DataType.YEAR, ShortCodec.INSTANCE); - CODEC_LIST.put(DataType.NEWDATE, LocalDateTimeCodec.INSTANCE); - CODEC_LIST.put(DataType.VARCHAR, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.BIT, BitSetCodec.INSTANCE); - CODEC_LIST.put(DataType.JSON, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.DECIMAL, BigDecimalCodec.INSTANCE); - CODEC_LIST.put(DataType.ENUM, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.SET, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.TINYBLOB, ByteArrayCodec.INSTANCE); - CODEC_LIST.put(DataType.MEDIUMBLOB, ByteArrayCodec.INSTANCE); - CODEC_LIST.put(DataType.LONGBLOB, BlobCodec.INSTANCE); - CODEC_LIST.put(DataType.BLOB, ByteArrayCodec.INSTANCE); - CODEC_LIST.put(DataType.VARSTRING, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.STRING, StringCodec.INSTANCE); - CODEC_LIST.put(DataType.GEOMETRY, ByteArrayCodec.INSTANCE); + public static BindValue encodeNull(Class type, int index) { + return new BindValue(codecFromClass(type, index), BindValue.NULL_VALUE); + } + + public static BindValue encode( + Object value, int index, Protocol protocol, ExceptionFactory factory, Context context) { + + Codec codec = StringCodec.INSTANCE; + Object parameterValue = value; + + if (value instanceof Parameter) { + Parameter parameter = (Parameter) value; + parameterValue = parameter.getValue(); + + if (parameterValue == null) { + if (parameter.getType() instanceof R2dbcType) { + codec = codecFromR2dbcType((R2dbcType) parameter.getType()); + return new BindValue(codec, BindValue.NULL_VALUE); + } + return new BindValue(codec, BindValue.NULL_VALUE); + } + } + + codec = codecFromClass(parameterValue.getClass(), index); + + if (parameterValue == null) { + return new BindValue(codec, BindValue.NULL_VALUE); + } + + if (protocol == Protocol.TEXT) + return codec.encodeText(context.getByteBufAllocator(), parameterValue, context, factory); + return codec.encodeBinary(context.getByteBufAllocator(), parameterValue, factory); + } + + public static Codec codecFromClass(Class javaType, int index) { + if (javaType == null) return StringCodec.INSTANCE; + Codec codec = codecMapper.get(javaType); + if (codec != null) return codec; + codec = codecByClass(javaType, index); + if (codec != null) return codec; + return StringCodec.INSTANCE; + } + + public static Codec codecFromR2dbcType(R2dbcType type) { + Assert.requireNonNull(type, "type must not be null"); + Codec codec = r2dbcTypeMapper.get(type); + if (codec != null) return codec; + return StringCodec.INSTANCE; + } + + private static Map> r2dbcTypeToDataTypeMap() { + Map> myMap = new HashMap<>(); + myMap.put(R2dbcType.NCHAR, StringCodec.INSTANCE); + myMap.put(R2dbcType.CHAR, StringCodec.INSTANCE); + myMap.put(R2dbcType.NVARCHAR, StringCodec.INSTANCE); + myMap.put(R2dbcType.VARCHAR, StringCodec.INSTANCE); + myMap.put(R2dbcType.CLOB, ClobCodec.INSTANCE); + myMap.put(R2dbcType.NCLOB, ClobCodec.INSTANCE); + myMap.put(R2dbcType.BOOLEAN, BooleanCodec.INSTANCE); + myMap.put(R2dbcType.TINYINT, ByteCodec.INSTANCE); + myMap.put(R2dbcType.BINARY, ByteArrayCodec.INSTANCE); + myMap.put(R2dbcType.VARBINARY, ByteArrayCodec.INSTANCE); + myMap.put(R2dbcType.BLOB, BlobCodec.INSTANCE); + myMap.put(R2dbcType.INTEGER, IntCodec.INSTANCE); + myMap.put(R2dbcType.SMALLINT, ShortCodec.INSTANCE); + myMap.put(R2dbcType.BIGINT, BigIntegerCodec.INSTANCE); + myMap.put(R2dbcType.NUMERIC, BigDecimalCodec.INSTANCE); + myMap.put(R2dbcType.DECIMAL, BigDecimalCodec.INSTANCE); + myMap.put(R2dbcType.FLOAT, FloatCodec.INSTANCE); + myMap.put(R2dbcType.REAL, FloatCodec.INSTANCE); + myMap.put(R2dbcType.DOUBLE, DoubleCodec.INSTANCE); + myMap.put(R2dbcType.DATE, LocalDateCodec.INSTANCE); + myMap.put(R2dbcType.TIME, LocalTimeCodec.INSTANCE); + myMap.put(R2dbcType.TIME_WITH_TIME_ZONE, LocalTimeCodec.INSTANCE); + myMap.put(R2dbcType.TIMESTAMP, LocalDateTimeCodec.INSTANCE); + myMap.put(R2dbcType.TIMESTAMP_WITH_TIME_ZONE, LocalDateTimeCodec.INSTANCE); + return myMap; + } + + private static Map, Codec> classToCodecMap() { + Map, Codec> myMap = new HashMap<>(); + myMap.put(String.class, StringCodec.INSTANCE); + myMap.put(Clob.class, ClobCodec.INSTANCE); + myMap.put(InputStream.class, StreamCodec.INSTANCE); + myMap.put(Boolean.class, BooleanCodec.INSTANCE); + myMap.put(byte[].class, ByteArrayCodec.INSTANCE); + myMap.put(Blob.class, BlobCodec.INSTANCE); + myMap.put(ByteBuffer.class, ByteBufferCodec.INSTANCE); + myMap.put(BitSet.class, BitSetCodec.INSTANCE); + myMap.put(Integer.class, IntCodec.INSTANCE); + myMap.put(Short.class, ShortCodec.INSTANCE); + myMap.put(BigInteger.class, BigIntegerCodec.INSTANCE); + myMap.put(Long.class, LongCodec.INSTANCE); + myMap.put(BigDecimal.class, BigDecimalCodec.INSTANCE); + myMap.put(Float.class, FloatCodec.INSTANCE); + myMap.put(Double.class, DoubleCodec.INSTANCE); + myMap.put(LocalDate.class, LocalDateCodec.INSTANCE); + myMap.put(LocalTime.class, LocalTimeCodec.INSTANCE); + myMap.put(Duration.class, DurationCodec.INSTANCE); + myMap.put(LocalDateTime.class, LocalDateTimeCodec.INSTANCE); + return myMap; + } + + public static Codec codecByClass(Class value, int index) { + for (Codec codec : Codecs.LIST) { + if (codec.canEncode(value)) { + return codec; + } + } + throw new IllegalArgumentException( + String.format("No encoder for class %s (parameter at index %s) ", value.getName(), index)); } } diff --git a/src/main/java/org/mariadb/r2dbc/codec/DataType.java b/src/main/java/org/mariadb/r2dbc/codec/DataType.java index f0b2fcae..13d0c8b2 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/DataType.java +++ b/src/main/java/org/mariadb/r2dbc/codec/DataType.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec; @@ -19,7 +19,7 @@ public enum DataType { DATETIME(12), YEAR(13), NEWDATE(14), - VARCHAR(15), + TEXT(15), BIT(16), JSON(245), DECIMAL(246), @@ -65,7 +65,7 @@ public static DataType fromServer(int typeValue, int charsetNumber) { if (charsetNumber != 63 && typeValue >= 249 && typeValue <= 252) { // MariaDB Text dataType - return DataType.VARCHAR; + return DataType.TEXT; } return dataType; diff --git a/src/main/java/org/mariadb/r2dbc/codec/Parameter.java b/src/main/java/org/mariadb/r2dbc/codec/Parameter.java deleted file mode 100644 index 3039fc9e..00000000 --- a/src/main/java/org/mariadb/r2dbc/codec/Parameter.java +++ /dev/null @@ -1,58 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.codec; - -import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.util.BufferUtils; - -public class Parameter { - @SuppressWarnings({"rawtypes", "unchecked"}) - public static final Parameter NULL_PARAMETER = - new Parameter(null, null) { - @Override - public void encodeText(ByteBuf out, Context context) { - BufferUtils.writeAscii(out, "null"); - } - - @Override - public DataType getBinaryEncodeType() { - return DataType.VARCHAR; - } - - @Override - public boolean isNull() { - return true; - } - }; - - private final Codec codec; - private final T value; - - public Parameter(Codec codec, T value) { - this.codec = codec; - this.value = value; - } - - public void encodeText(ByteBuf out, Context context) { - codec.encodeText(out, context, this.value); - } - - public void encodeBinary(ByteBuf out, Context context) { - codec.encodeBinary(out, context, this.value); - } - - public DataType getBinaryEncodeType() { - return codec.getBinaryEncodeType(); - } - - public boolean isNull() { - return false; - } - - @Override - public String toString() { - return "Parameter{codec=" + codec.getClass().getSimpleName() + ", value=" + value + '}'; - } -} diff --git a/src/main/java/org/mariadb/r2dbc/codec/RowDecoder.java b/src/main/java/org/mariadb/r2dbc/codec/RowDecoder.java index f040392c..0309d44c 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/RowDecoder.java +++ b/src/main/java/org/mariadb/r2dbc/codec/RowDecoder.java @@ -1,12 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec; import io.netty.buffer.ByteBuf; import java.util.EnumSet; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.Assert; public abstract class RowDecoder { protected static final int NULL_LENGTH = -1; @@ -15,9 +17,12 @@ public abstract class RowDecoder { protected int length; protected int index; protected MariadbConnectionConfiguration conf; + protected ExceptionFactory factory; - public RowDecoder(MariadbConnectionConfiguration conf) { + public RowDecoder(MariadbConnectionConfiguration conf, ExceptionFactory factory) { + Assert.requireNonNull(factory, "missing factory parameter"); this.conf = conf; + this.factory = factory; } public void resetRow(ByteBuf buf) { @@ -36,18 +41,18 @@ protected IllegalArgumentException noDecoderException( DataType.MEDIUMINT, DataType.INTEGER, DataType.BIGINT) - .contains(column.getType())) { + .contains(column.getDataType())) { throw new IllegalArgumentException( String.format( "No decoder for type %s[] and column type %s(%s)", type.getComponentType().getName(), - column.getType().toString(), + column.getDataType().toString(), column.isSigned() ? "signed" : "unsigned")); } throw new IllegalArgumentException( String.format( "No decoder for type %s[] and column type %s", - type.getComponentType().getName(), column.getType().toString())); + type.getComponentType().getName(), column.getDataType().toString())); } if (EnumSet.of( DataType.TINYINT, @@ -55,18 +60,18 @@ protected IllegalArgumentException noDecoderException( DataType.MEDIUMINT, DataType.INTEGER, DataType.BIGINT) - .contains(column.getType())) { + .contains(column.getDataType())) { throw new IllegalArgumentException( String.format( "No decoder for type %s and column type %s(%s)", type.getName(), - column.getType().toString(), + column.getDataType().toString(), column.isSigned() ? "signed" : "unsigned")); } throw new IllegalArgumentException( String.format( "No decoder for type %s and column type %s", - type.getName(), column.getType().toString())); + type.getName(), column.getDataType().toString())); } public abstract void setPosition(int position); diff --git a/src/main/java/org/mariadb/r2dbc/codec/TextRowDecoder.java b/src/main/java/org/mariadb/r2dbc/codec/TextRowDecoder.java index 55239f16..5f859d3e 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/TextRowDecoder.java +++ b/src/main/java/org/mariadb/r2dbc/codec/TextRowDecoder.java @@ -1,16 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec; +import java.util.List; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; public class TextRowDecoder extends RowDecoder { public TextRowDecoder( - int columnNumber, ColumnDefinitionPacket[] columns, MariadbConnectionConfiguration conf) { - super(conf); + List columns, + MariadbConnectionConfiguration conf, + ExceptionFactory factory) { + super(conf, factory); } @SuppressWarnings("unchecked") @@ -28,13 +32,13 @@ public T get(int index, ColumnDefinitionPacket column, Class type) // type generic, return "natural" java type if (Object.class == type || type == null) { - Codec defaultCodec = ((Codec) column.getDefaultCodec(conf)); - return defaultCodec.decodeText(buf, length, column, type); + Codec defaultCodec = ((Codec) column.getType().getDefaultCodec()); + return defaultCodec.decodeText(buf, length, column, type, factory); } for (Codec codec : Codecs.LIST) { if (codec.canDecode(column, type)) { - return ((Codec) codec).decodeText(buf, length, column, type); + return ((Codec) codec).decodeText(buf, length, column, type, factory); } } diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/BigDecimalCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/BigDecimalCodec.java index 0ce27c1d..78fe5571 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/BigDecimalCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/BigDecimalCodec.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class BigDecimalCodec implements Codec { @@ -32,12 +34,13 @@ public class BigDecimalCodec implements Codec { DataType.OLDDECIMAL, DataType.YEAR, DataType.DECIMAL, - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(BigDecimal.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) + && type.isAssignableFrom(BigDecimal.class); } public boolean canEncode(Class value) { @@ -46,8 +49,12 @@ public boolean canEncode(Class value) { @Override public BigDecimal decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case TINYINT: case SMALLINT: case MEDIUMINT: @@ -74,7 +81,7 @@ public BigDecimal decodeText( try { return new BigDecimal(str); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as BigDecimal", str)); } } @@ -82,9 +89,13 @@ public BigDecimal decodeText( @Override public BigDecimal decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { - switch (column.getType()) { + switch (column.getDataType()) { case TINYINT: if (!column.isSigned()) { return BigDecimal.valueOf(buf.readUnsignedByte()); @@ -145,22 +156,24 @@ public BigDecimal decodeBinary( try { return new BigDecimal(str); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as BigDecimal", str)); } } } @Override - public void encodeText(ByteBuf buf, Context context, BigDecimal value) { - BufferUtils.writeAscii(buf, value.toPlainString()); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> BufferUtils.encodeAscii(allocator, ((BigDecimal) value).toPlainString())); } @Override - public void encodeBinary(ByteBuf buf, Context context, BigDecimal value) { - String asciiFormat = value.toPlainString(); - BufferUtils.writeLengthEncode(asciiFormat.length(), buf); - BufferUtils.writeAscii(buf, asciiFormat); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> BufferUtils.encodeLengthAscii(allocator, ((BigDecimal) value).toPlainString())); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/BigIntegerCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/BigIntegerCodec.java index dc0a765b..cc502757 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/BigIntegerCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/BigIntegerCodec.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class BigIntegerCodec implements Codec { @@ -32,12 +34,13 @@ public class BigIntegerCodec implements Codec { DataType.OLDDECIMAL, DataType.FLOAT, DataType.BIT, - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(BigInteger.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) + && type.isAssignableFrom(BigInteger.class); } public boolean canEncode(Class value) { @@ -46,9 +49,13 @@ public boolean canEncode(Class value) { @Override public BigInteger decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { - switch (column.getType()) { + switch (column.getDataType()) { case FLOAT: case DOUBLE: case DECIMAL: @@ -78,7 +85,7 @@ public BigInteger decodeText( try { return new BigDecimal(str2).toBigIntegerExact(); } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as BigInteger", str2)); } } @@ -86,9 +93,13 @@ public BigInteger decodeText( @Override public BigInteger decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { - switch (column.getType()) { + switch (column.getDataType()) { case BIT: long result = 0; for (int i = 0; i < length; i++) { @@ -147,22 +158,22 @@ public BigInteger decodeBinary( try { return new BigInteger(str); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as BigInteger", str)); } } } @Override - public void encodeText(ByteBuf buf, Context context, BigInteger value) { - BufferUtils.writeAscii(buf, value.toString()); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeAscii(allocator, value.toString())); } @Override - public void encodeBinary(ByteBuf buf, Context context, BigInteger value) { - String asciiFormat = value.toString(); - BufferUtils.writeLengthEncode(asciiFormat.length(), buf); - BufferUtils.writeAscii(buf, asciiFormat); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeLengthAscii(allocator, value.toString())); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/BitSetCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/BitSetCodec.java index 9830695f..1638fcc4 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/BitSetCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/BitSetCodec.java @@ -1,15 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import java.nio.charset.StandardCharsets; +import io.netty.buffer.ByteBufAllocator; import java.util.BitSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class BitSetCodec implements Codec { @@ -37,18 +39,26 @@ public static void revertOrder(byte[] array) { } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return column.getType() == DataType.BIT && type.isAssignableFrom(BitSet.class); + return column.getDataType() == DataType.BIT && type.isAssignableFrom(BitSet.class); } @Override public BitSet decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { return parseBit(buf, length); } @Override public BitSet decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { return parseBit(buf, length); } @@ -57,24 +67,31 @@ public boolean canEncode(Class value) { } @Override - public void encodeText(ByteBuf buf, Context context, BitSet value) { - byte[] bytes = value.toByteArray(); - revertOrder(bytes); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + byte[] bytes = ((BitSet) value).toByteArray(); + revertOrder(bytes); - StringBuilder sb = new StringBuilder(bytes.length * Byte.SIZE + 3); - sb.append("b'"); - for (int i = 0; i < Byte.SIZE * bytes.length; i++) - sb.append((bytes[i / Byte.SIZE] << i % Byte.SIZE & 0x80) == 0 ? '0' : '1'); - sb.append("'"); - buf.writeCharSequence(sb.toString(), StandardCharsets.US_ASCII); + StringBuilder sb = new StringBuilder(bytes.length * Byte.SIZE + 3); + sb.append("b'"); + for (int i = 0; i < Byte.SIZE * bytes.length; i++) + sb.append((bytes[i / Byte.SIZE] << i % Byte.SIZE & 0x80) == 0 ? '0' : '1'); + sb.append("'"); + return BufferUtils.encodeAscii(allocator, sb.toString()); + }); } @Override - public void encodeBinary(ByteBuf buf, Context context, BitSet value) { - byte[] bytes = value.toByteArray(); - revertOrder(bytes); - BufferUtils.writeLengthEncode(bytes.length, buf); - buf.writeBytes(bytes); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + byte[] bytes = ((BitSet) value).toByteArray(); + revertOrder(bytes); + return BufferUtils.encodeLengthBytes(allocator, bytes); + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/BlobCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/BlobCodec.java index 61cd51a9..0acb5ef3 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/BlobCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/BlobCodec.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; import io.r2dbc.spi.Blob; -import io.r2dbc.spi.R2dbcNonTransientResourceException; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -31,24 +33,28 @@ public class BlobCodec implements Codec { DataType.LONGBLOB, DataType.STRING, DataType.VARSTRING, - DataType.VARCHAR); + DataType.TEXT); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(Blob.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) && type.isAssignableFrom(Blob.class); } @Override public Blob decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case STRING: - case VARCHAR: + case TEXT: case VARSTRING: if (!column.isBinary()) { buf.skipBytes(length); - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format( - "Data type %s (not binary) cannot be decoded as Blob", column.getType())); + "Data type %s (not binary) cannot be decoded as Blob", column.getDataType())); } return new MariaDbBlob(buf.readRetainedSlice(length)); @@ -60,8 +66,12 @@ public Blob decodeText( @Override public Blob decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case BIT: case TINYBLOB: case MEDIUMBLOB: @@ -74,9 +84,9 @@ public Blob decodeBinary( // STRING, VARCHAR, VARSTRING: if (!column.isBinary()) { buf.skipBytes(length); - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format( - "Data type %s (not binary) cannot be decoded as Blob", column.getType())); + "Data type %s (not binary) cannot be decoded as Blob", column.getDataType())); } return new MariaDbBlob(buf.readRetainedSlice(length)); } @@ -87,49 +97,37 @@ public boolean canEncode(Class value) { } @Override - public void encodeText(ByteBuf buf, Context context, Blob value) { - buf.writeBytes("_binary '".getBytes(StandardCharsets.US_ASCII)); - Flux.from(value.stream()) - .handle( - (tempVal, sync) -> { - if (tempVal.hasArray()) { - BufferUtils.writeEscaped( - buf, tempVal.array(), tempVal.arrayOffset(), tempVal.remaining(), context); - } else { - byte[] intermediaryBuf = new byte[tempVal.remaining()]; - tempVal.get(intermediaryBuf); - BufferUtils.writeEscaped(buf, intermediaryBuf, 0, intermediaryBuf.length, context); - } - sync.next(buf); - }) - .doOnComplete( - () -> { - buf.writeByte((byte) '\''); - }) - .subscribe(); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + Flux.from(((Blob) value).stream()) + .reduce( + allocator.compositeBuffer(), + (a, b) -> a.addComponent(true, Unpooled.wrappedBuffer(b))) + .map( + b -> { + ByteBuf returnedBuf = BufferUtils.encodeEscapedBuffer(allocator, b, context); + b.release(); + return returnedBuf; + }) + .doOnSubscribe(e -> ((Blob) value).discard())); } @Override - public void encodeBinary(ByteBuf buf, Context context, Blob value) { - buf.writeByte(0xfe); - int initialPos = buf.writerIndex(); - buf.writerIndex(buf.writerIndex() + 8); // reserve length encoded length bytes - - Flux.from(value.stream()) - .handle( - (tempVal, sync) -> { - buf.writeBytes(tempVal); - sync.next(buf); - }) - .doOnComplete( - () -> { - // Write length - int endPos = buf.writerIndex(); - buf.writerIndex(initialPos); - buf.writeLongLE(endPos - (initialPos + 8)); - buf.writerIndex(endPos); - }) - .subscribe(); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + Flux.from(((Blob) value).stream()) + .reduce( + allocator.compositeBuffer(), + (a, b) -> a.addComponent(true, Unpooled.wrappedBuffer(b))) + .map( + c -> + c.addComponent( + true, + 0, + Unpooled.wrappedBuffer(BufferUtils.encodeLength(c.readableBytes())))) + .doOnSubscribe(e -> ((Blob) value).discard())); } private class MariaDbBlob implements Blob { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/BooleanCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/BooleanCodec.java index 0a6b03cd..9bba769f 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/BooleanCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/BooleanCodec.java @@ -1,16 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class BooleanCodec implements Codec { @@ -19,7 +22,7 @@ public class BooleanCodec implements Codec { private static final EnumSet COMPATIBLE_TYPES = EnumSet.of( - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING, DataType.BIGINT, @@ -34,7 +37,7 @@ public class BooleanCodec implements Codec { DataType.BIT); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Boolean.TYPE) || type.isAssignableFrom(Boolean.class)); } @@ -44,8 +47,12 @@ public boolean canEncode(Class value) { @Override public Boolean decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case BIT: return ByteCodec.parseBit(buf, length) != 0; @@ -75,9 +82,13 @@ public Boolean decodeText( @Override public Boolean decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { - switch (column.getType()) { + switch (column.getDataType()) { case BIT: return ByteCodec.parseBit(buf, length) != 0; @@ -113,13 +124,16 @@ public Boolean decodeBinary( } @Override - public void encodeText(ByteBuf buf, Context context, Boolean value) { - BufferUtils.writeAscii(buf, value ? "1" : "0"); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> BufferUtils.encodeAscii(allocator, ((Boolean) value) ? "1" : "0")); } @Override - public void encodeBinary(ByteBuf buf, Context context, Boolean value) { - buf.writeByte(value ? 1 : 0); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeByte(allocator, ((Boolean) value) ? 1 : 0)); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/ByteArrayCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/ByteArrayCodec.java index 43d5f26e..1954eaa3 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/ByteArrayCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/ByteArrayCodec.java @@ -1,20 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class ByteArrayCodec implements Codec { - public static final byte[] BINARY_PREFIX = {'_', 'b', 'i', 'n', 'a', 'r', 'y', ' ', '\''}; - public static final ByteArrayCodec INSTANCE = new ByteArrayCodec(); private static final EnumSet COMPATIBLE_TYPES = @@ -25,18 +26,22 @@ public class ByteArrayCodec implements Codec { DataType.LONGBLOB, DataType.GEOMETRY, DataType.VARSTRING, - DataType.VARCHAR, + DataType.TEXT, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Byte.TYPE && type.isArray()) || type.isAssignableFrom(byte[].class)); } @Override public byte[] decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { byte[] arr = new byte[length]; buf.readBytes(arr); @@ -45,7 +50,11 @@ public byte[] decodeText( @Override public byte[] decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { byte[] arr = new byte[length]; buf.readBytes(arr); return arr; @@ -56,17 +65,18 @@ public boolean canEncode(Class value) { } @Override - public void encodeText(ByteBuf buf, Context context, byte[] value) { - buf.writeBytes(BINARY_PREFIX); - BufferUtils.writeEscaped(buf, value, 0, value.length, context); - buf.writeByte('\''); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> + BufferUtils.encodeEscapedBytes( + allocator, BufferUtils.BINARY_PREFIX, (byte[]) value, context)); } @Override - public void encodeBinary(ByteBuf buf, Context context, byte[] value) { - - BufferUtils.writeLengthEncode(value.length, buf); - buf.writeBytes(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeLengthBytes(allocator, (byte[]) value)); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/ByteBufferCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/ByteBufferCodec.java new file mode 100644 index 00000000..b79ad620 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/codec/list/ByteBufferCodec.java @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.codec.list; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import java.nio.ByteBuffer; +import java.util.EnumSet; +import org.mariadb.r2dbc.ExceptionFactory; +import org.mariadb.r2dbc.codec.Codec; +import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; +import org.mariadb.r2dbc.util.BufferUtils; + +public class ByteBufferCodec implements Codec { + + public static final ByteBufferCodec INSTANCE = new ByteBufferCodec(); + + private static final EnumSet COMPATIBLE_TYPES = + EnumSet.of( + DataType.BIT, + DataType.BLOB, + DataType.TINYBLOB, + DataType.MEDIUMBLOB, + DataType.LONGBLOB, + DataType.STRING, + DataType.VARSTRING, + DataType.TEXT); + + public boolean canDecode(ColumnDefinitionPacket column, Class type) { + return COMPATIBLE_TYPES.contains(column.getDataType()) + && type.isAssignableFrom(ByteBuffer.class); + } + + @Override + public ByteBuffer decodeText( + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + return decode(buf, length, column, factory); + } + + @Override + public ByteBuffer decodeBinary( + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + return decode(buf, length, column, factory); + } + + private static ByteBuffer decode( + ByteBuf buf, int length, ColumnDefinitionPacket column, ExceptionFactory factory) { + switch (column.getDataType()) { + case STRING: + case TEXT: + case VARSTRING: + if (!column.isBinary()) { + buf.skipBytes(length); + throw factory.createParsingException( + String.format( + "Data type %s (not binary) cannot be decoded as Blob", column.getDataType())); + } + ByteBuffer value = ByteBuffer.allocate(length); + buf.readBytes(value); + return value; + + default: + // BIT, TINYBLOB, MEDIUMBLOB, LONGBLOB, BLOB, GEOMETRY + byte[] val = new byte[length]; + buf.readBytes(val); + return ByteBuffer.wrap(val); + } + } + + public boolean canEncode(Class value) { + return ByteBuf.class.isAssignableFrom(value); + } + + @Override + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuffer val = (ByteBuffer) value; + ByteBuf byteBuf = allocator.buffer(); + if (val.hasArray()) { + BufferUtils.escapedBytes(byteBuf, val.array(), val.remaining(), context); + } else { + byte[] arr = new byte[val.remaining()]; + val.get(arr); + BufferUtils.escapedBytes(byteBuf, arr, arr.length, context); + } + byteBuf.writeByte('\''); + return byteBuf; + }); + } + + @Override + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuffer val = (ByteBuffer) value; + CompositeByteBuf compositeByteBuf = allocator.compositeBuffer(); + ByteBuf buf = Unpooled.wrappedBuffer(val); + compositeByteBuf.addComponent( + true, Unpooled.wrappedBuffer(BufferUtils.encodeLength(val.remaining()))); + compositeByteBuf.addComponent(true, buf); + return compositeByteBuf; + }); + } + + public DataType getBinaryEncodeType() { + return DataType.BLOB; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/ByteCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/ByteCodec.java index d57ec54b..f8bc744d 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/ByteCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/ByteCodec.java @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class ByteCodec implements Codec { @@ -40,7 +42,7 @@ public class ByteCodec implements Codec { DataType.ENUM, DataType.VARSTRING, DataType.STRING, - DataType.VARCHAR); + DataType.TEXT); public static long parseBit(ByteBuf buf, int length) { if (length == 1) { @@ -56,7 +58,7 @@ public static long parseBit(ByteBuf buf, int length) { } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Byte.TYPE) || type.isAssignableFrom(Byte.class)); } @@ -66,10 +68,14 @@ public boolean canEncode(Class value) { @Override public Byte decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case TINYINT: case SMALLINT: case MEDIUMINT: @@ -100,14 +106,15 @@ public Byte decodeText( try { result = new BigDecimal(str).setScale(0, RoundingMode.DOWN).byteValueExact(); } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Byte", str, column.getType())); + throw factory.createParsingException( + String.format( + "value '%s' (%s) cannot be decoded as Byte", str, column.getDataType())); } break; } if ((byte) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException("byte overflow"); + throw factory.createParsingException("byte overflow"); } return (byte) result; @@ -115,10 +122,14 @@ public Byte decodeText( @Override public Byte decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case TINYINT: result = column.isSigned() ? buf.readByte() : buf.readUnsignedByte(); break; @@ -160,8 +171,8 @@ public Byte decodeBinary( float f = buf.readFloatLE(); result = (long) f; if ((byte) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Byte", f, column.getType())); + throw factory.createParsingException( + String.format("value '%s' (%s) cannot be decoded as Byte", f, column.getDataType())); } break; @@ -169,23 +180,24 @@ public Byte decodeBinary( double d = buf.readDoubleLE(); result = (long) d; if ((byte) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Byte", d, column.getType())); + throw factory.createParsingException( + String.format("value '%s' (%s) cannot be decoded as Byte", d, column.getDataType())); } break; case OLDDECIMAL: case DECIMAL: case ENUM: - case VARCHAR: + case TEXT: case VARSTRING: case STRING: String str = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); try { result = new BigDecimal(str).setScale(0, RoundingMode.DOWN).byteValueExact(); } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Byte", str, column.getType())); + throw factory.createParsingException( + String.format( + "value '%s' (%s) cannot be decoded as Byte", str, column.getDataType())); } break; @@ -198,20 +210,23 @@ public Byte decodeBinary( } if ((byte) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException("byte overflow"); + throw factory.createParsingException("byte overflow"); } return (byte) result; } @Override - public void encodeText(ByteBuf buf, Context context, Byte value) { - BufferUtils.writeAscii(buf, Integer.toString((int) value)); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> BufferUtils.encodeAscii(allocator, Integer.toString((Byte) value))); } @Override - public void encodeBinary(ByteBuf buf, Context context, Byte value) { - buf.writeByte(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeByte(allocator, (Byte) value)); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/ClobCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/ClobCodec.java index 9217cb50..c282457f 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/ClobCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/ClobCodec.java @@ -1,16 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.r2dbc.spi.Clob; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -20,10 +23,10 @@ public class ClobCodec implements Codec { public static final ClobCodec INSTANCE = new ClobCodec(); private static final EnumSet COMPATIBLE_TYPES = - EnumSet.of(DataType.VARCHAR, DataType.VARSTRING, DataType.STRING); + EnumSet.of(DataType.TEXT, DataType.VARSTRING, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && (type.isAssignableFrom(Clob.class)); + return COMPATIBLE_TYPES.contains(column.getDataType()) && (type.isAssignableFrom(Clob.class)); } public boolean canEncode(Class value) { @@ -32,51 +35,50 @@ public boolean canEncode(Class value) { @Override public Clob decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { String rawValue = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); return Clob.from(Mono.just(rawValue)); } @Override public Clob decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { String rawValue = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); return Clob.from(Mono.just(rawValue)); } @Override - public void encodeText(ByteBuf buf, Context context, Clob value) { - buf.writeByte('\''); - Flux.from(value.stream()) - .handle( - (tempVal, sync) -> { - BufferUtils.write(buf, tempVal.toString(), false, context); - sync.next(buf); - }) - .subscribe(); - buf.writeByte('\''); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + Flux.from(((Clob) value).stream()) + .reduce(new StringBuilder(), (a, b) -> a.append(b)) + .map( + b -> + BufferUtils.encodeEscapedBytes( + allocator, + BufferUtils.STRING_PREFIX, + b.toString().getBytes(StandardCharsets.UTF_8), + context)) + .doOnSubscribe(e -> ((Clob) value).discard())); } @Override - public void encodeBinary(ByteBuf buf, Context context, Clob value) { - buf.writeByte(0xfe); - int initialPos = buf.writerIndex(); - buf.writerIndex(buf.writerIndex() + 8); // reserve length encoded length bytes - Flux.from(value.stream()) - .handle( - (tempVal, sync) -> { - buf.writeCharSequence(tempVal, StandardCharsets.UTF_8); - sync.next(buf); - }) - .doOnComplete( - () -> { - // Write length - int endPos = buf.writerIndex(); - buf.writerIndex(initialPos); - buf.writeLongLE(endPos - (initialPos + 8)); - buf.writerIndex(endPos); - }) - .subscribe(); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + Flux.from(((Clob) value).stream()) + .reduce(new StringBuilder(), (a, b) -> a.append(b)) + .map(b -> BufferUtils.encodeLengthUtf8(allocator, b.toString())) + .doOnSubscribe(e -> ((Clob) value).discard())); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/DoubleCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/DoubleCodec.java index 5533499e..2a1ca03f 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/DoubleCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/DoubleCodec.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class DoubleCodec implements Codec { @@ -31,12 +33,12 @@ public class DoubleCodec implements Codec { DataType.YEAR, DataType.OLDDECIMAL, DataType.DECIMAL, - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Double.TYPE) || type.isAssignableFrom(Double.class)); } @@ -46,8 +48,12 @@ public boolean canEncode(Class value) { @Override public Double decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case TINYINT: case SMALLINT: case MEDIUMINT: @@ -66,7 +72,7 @@ public Double decodeText( try { return Double.valueOf(str2); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Double", str2)); } } @@ -74,8 +80,12 @@ public Double decodeText( @Override public Double decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case DOUBLE: return buf.readDoubleLE(); @@ -130,20 +140,27 @@ public Double decodeBinary( try { return Double.valueOf(str2); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Double", str2)); } } } @Override - public void encodeText(ByteBuf buf, Context context, Double value) { - BufferUtils.writeAscii(buf, String.valueOf(value)); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeAscii(allocator, value.toString())); } @Override - public void encodeBinary(ByteBuf buf, Context context, Double value) { - buf.writeDoubleLE(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(8, 8); + buf.writeDoubleLE((Double) value); + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/DurationCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/DurationCodec.java index 936c6142..86f320d3 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/DurationCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/DurationCodec.java @@ -1,16 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; public class DurationCodec implements Codec { @@ -22,11 +25,11 @@ public class DurationCodec implements Codec { DataType.DATETIME, DataType.TIMESTAMP, DataType.VARSTRING, - DataType.VARCHAR, + DataType.TEXT, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(Duration.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) && type.isAssignableFrom(Duration.class); } public boolean canEncode(Class value) { @@ -35,10 +38,14 @@ public boolean canEncode(Class value) { @Override public Duration decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int[] parts; - switch (column.getType()) { + switch (column.getDataType()) { case TIMESTAMP: case DATETIME: parts = @@ -54,7 +61,7 @@ public Duration decodeText( default: // TIME, VARCHAR, VARSTRING, STRING: - parts = LocalTimeCodec.parseTime(buf, length, column); + parts = LocalTimeCodec.parseTime(buf, length, column, factory); Duration d = Duration.ZERO .plusHours(parts[1]) @@ -68,7 +75,11 @@ public Duration decodeText( @Override public Duration decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long days = 0; int hours = 0; @@ -76,7 +87,7 @@ public Duration decodeBinary( int seconds = 0; long microseconds = 0; - switch (column.getType()) { + switch (column.getDataType()) { case TIME: boolean negate = false; if (length > 0) { @@ -124,7 +135,7 @@ public Duration decodeBinary( default: // VARCHAR, VARSTRING, STRING: - int[] parts = LocalTimeCodec.parseTime(buf, length, column); + int[] parts = LocalTimeCodec.parseTime(buf, length, column, factory); Duration d = Duration.ZERO .plusHours(parts[1]) @@ -137,68 +148,83 @@ public Duration decodeBinary( } @Override - public void encodeText(ByteBuf buf, Context context, Duration val) { - long s = val.getSeconds(); - boolean negate = false; - if (s < 0) { - negate = true; - s = -s; - } - - long microSecond = val.getNano() / 1000; - buf.writeByte('\''); - if (microSecond != 0) { - if (negate) { - s = s - 1; - buf.writeCharSequence( - String.format( - "-%d:%02d:%02d.%06d", s / 3600, (s % 3600) / 60, (s % 60), 1000000 - microSecond), - StandardCharsets.US_ASCII); - - } else { - buf.writeCharSequence( - String.format("%d:%02d:%02d.%06d", s / 3600, (s % 3600) / 60, (s % 60), microSecond), - StandardCharsets.US_ASCII); - } - } else { - String format = negate ? "-%d:%02d:%02d" : "%d:%02d:%02d"; - buf.writeCharSequence( - String.format(format, s / 3600, (s % 3600) / 60, (s % 60)), StandardCharsets.US_ASCII); - } - buf.writeByte('\''); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + Duration val = (Duration) value; + long s = val.getSeconds(); + boolean negate = false; + if (s < 0) { + negate = true; + s = -s; + } + ByteBuf buf = allocator.buffer(); + long microSecond = val.getNano() / 1000; + buf.writeByte('\''); + String durationStr; + if (microSecond != 0) { + if (negate) { + s = s - 1; + durationStr = + String.format( + "-%d:%02d:%02d.%06d", + s / 3600, (s % 3600) / 60, (s % 60), 1000000 - microSecond); + + } else { + durationStr = + String.format( + "%d:%02d:%02d.%06d", s / 3600, (s % 3600) / 60, (s % 60), microSecond); + } + } else { + durationStr = + String.format( + negate ? "-%d:%02d:%02d" : "%d:%02d:%02d", s / 3600, (s % 3600) / 60, (s % 60)); + } + buf.writeCharSequence(durationStr, StandardCharsets.US_ASCII); + buf.writeByte('\''); + return buf; + }); } @Override - public void encodeBinary(ByteBuf buf, Context context, Duration value) { - long microSecond = value.getNano() / 1000; - long s = Math.abs(value.getSeconds()); - if (microSecond > 0) { - if (value.isNegative()) { - s = s - 1; - buf.writeByte((byte) 12); - buf.writeByte((byte) (value.isNegative() ? 1 : 0)); - buf.writeIntLE((int) (s / (24 * 3600))); - buf.writeByte((int) (s % (24 * 3600)) / 3600); - buf.writeByte((int) (s % 3600) / 60); - buf.writeByte((int) (s % 60)); - buf.writeIntLE((int) (1000000 - microSecond)); - } else { - buf.writeByte((byte) 12); - buf.writeByte((byte) (value.isNegative() ? 1 : 0)); - buf.writeIntLE((int) (s / (24 * 3600))); - buf.writeByte((int) (s % (24 * 3600)) / 3600); - buf.writeByte((int) (s % 3600) / 60); - buf.writeByte((int) (s % 60)); - buf.writeIntLE((int) microSecond); - } - } else { - buf.writeByte((byte) 8); - buf.writeByte((byte) (value.isNegative() ? 1 : 0)); - buf.writeIntLE((int) (s / (24 * 3600))); - buf.writeByte((int) (s % (24 * 3600)) / 3600); - buf.writeByte((int) (s % 3600) / 60); - buf.writeByte((int) (s % 60)); - } + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + Duration val = (Duration) value; + ByteBuf buf = allocator.buffer(); + long microSecond = val.getNano() / 1000; + long s = Math.abs(val.getSeconds()); + if (microSecond > 0) { + if (val.isNegative()) { + s = s - 1; + buf.writeByte((byte) 12); + buf.writeByte((byte) (val.isNegative() ? 1 : 0)); + buf.writeIntLE((int) (s / (24 * 3600))); + buf.writeByte((int) (s % (24 * 3600)) / 3600); + buf.writeByte((int) (s % 3600) / 60); + buf.writeByte((int) (s % 60)); + buf.writeIntLE((int) (1000000 - microSecond)); + } else { + buf.writeByte((byte) 12); + buf.writeByte((byte) (val.isNegative() ? 1 : 0)); + buf.writeIntLE((int) (s / (24 * 3600))); + buf.writeByte((int) (s % (24 * 3600)) / 3600); + buf.writeByte((int) (s % 3600) / 60); + buf.writeByte((int) (s % 60)); + buf.writeIntLE((int) microSecond); + } + } else { + buf.writeByte((byte) 8); + buf.writeByte((byte) (val.isNegative() ? 1 : 0)); + buf.writeIntLE((int) (s / (24 * 3600))); + buf.writeByte((int) (s % (24 * 3600)) / 3600); + buf.writeByte((int) (s % 3600) / 60); + buf.writeByte((int) (s % 60)); + } + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/FloatCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/FloatCodec.java index 943a4947..afdc361b 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/FloatCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/FloatCodec.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class FloatCodec implements Codec { @@ -31,12 +33,12 @@ public class FloatCodec implements Codec { DataType.DECIMAL, DataType.YEAR, DataType.DOUBLE, - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Float.TYPE) || type.isAssignableFrom(Float.class)); } @@ -46,8 +48,12 @@ public boolean canEncode(Class value) { @Override public Float decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - switch (column.getType()) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + switch (column.getDataType()) { case TINYINT: case SMALLINT: case MEDIUMINT: @@ -66,7 +72,7 @@ public Float decodeText( try { return Float.valueOf(val); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Float", val)); } } @@ -74,9 +80,13 @@ public Float decodeText( @Override public Float decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { - switch (column.getType()) { + switch (column.getDataType()) { case FLOAT: return buf.readFloatLE(); @@ -130,20 +140,27 @@ public Float decodeBinary( try { return Float.valueOf(str2); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Float", str2)); } } } @Override - public void encodeText(ByteBuf buf, Context context, Float value) { - BufferUtils.writeAscii(buf, String.valueOf(value)); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeAscii(allocator, value.toString())); } @Override - public void encodeBinary(ByteBuf buf, Context context, Float value) { - buf.writeFloatLE(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(4, 4); + buf.writeFloatLE((Float) value); + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/IntCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/IntCodec.java index b53bb583..7682ab51 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/IntCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/IntCodec.java @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class IntCodec implements Codec { @@ -25,7 +27,7 @@ public class IntCodec implements Codec { DataType.FLOAT, DataType.DOUBLE, DataType.OLDDECIMAL, - DataType.VARCHAR, + DataType.TEXT, DataType.DECIMAL, DataType.ENUM, DataType.VARSTRING, @@ -39,7 +41,7 @@ public class IntCodec implements Codec { DataType.YEAR); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Integer.TYPE) || type.isAssignableFrom(Integer.class)); } @@ -49,9 +51,13 @@ public boolean canEncode(Class value) { @Override public Integer decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case TINYINT: case SMALLINT: case MEDIUMINT: @@ -63,7 +69,7 @@ public Integer decodeText( case BIGINT: result = LongCodec.parse(buf, length); if (result < 0 & !column.isSigned()) { - throw new R2dbcNonTransientResourceException("integer overflow"); + throw factory.createParsingException("integer overflow"); } break; @@ -82,13 +88,13 @@ public Integer decodeText( result = new BigDecimal(str).setScale(0, RoundingMode.DOWN).longValueExact(); break; } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Integer", str)); } } if ((int) result != result) { - throw new R2dbcNonTransientResourceException("integer overflow"); + throw factory.createParsingException("integer overflow"); } return (int) result; @@ -96,10 +102,14 @@ public Integer decodeText( @Override public Integer decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case INTEGER: result = column.isSigned() ? buf.readIntLE() : buf.readUnsignedIntLE(); break; @@ -131,7 +141,7 @@ public Integer decodeBinary( try { return val.intValueExact(); } catch (ArithmeticException ae) { - throw new R2dbcNonTransientResourceException("integer overflow"); + throw factory.createParsingException("integer overflow"); } } @@ -157,26 +167,33 @@ public Integer decodeBinary( result = new BigDecimal(str).setScale(0, RoundingMode.DOWN).longValueExact(); break; } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Integer", str)); } } if ((int) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException("integer overflow"); + throw factory.createParsingException("integer overflow"); } return (int) result; } @Override - public void encodeText(ByteBuf buf, Context context, Integer value) { - BufferUtils.writeAscii(buf, String.valueOf(value)); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeAscii(allocator, value.toString())); } @Override - public void encodeBinary(ByteBuf buf, Context context, Integer value) { - buf.writeIntLE(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(4, 4); + buf.writeIntLE((Integer) value); + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateCodec.java index f505aa01..9d5c75ce 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateCodec.java @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoField; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; public class LocalDateCodec implements Codec { @@ -27,7 +29,7 @@ public class LocalDateCodec implements Codec { DataType.TIMESTAMP, DataType.YEAR, DataType.VARSTRING, - DataType.VARCHAR, + DataType.TEXT, DataType.STRING); public static int[] parseDate(ByteBuf buf, int length) { @@ -51,7 +53,8 @@ public static int[] parseDate(ByteBuf buf, int length) { } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(LocalDate.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) + && type.isAssignableFrom(LocalDate.class); } public boolean canEncode(Class value) { @@ -60,10 +63,14 @@ public boolean canEncode(Class value) { @Override public LocalDate decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int[] parts; - switch (column.getType()) { + switch (column.getDataType()) { case YEAR: short y = (short) LongCodec.parse(buf, length); @@ -94,8 +101,9 @@ public LocalDate decodeText( String val = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); String[] stDatePart = val.split("-| "); if (stDatePart.length < 3) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Date", val, column.getType())); + throw factory.createParsingException( + String.format( + "value '%s' (%s) cannot be decoded as Date", val, column.getDataType())); } try { @@ -104,8 +112,9 @@ public LocalDate decodeText( int dayOfMonth = Integer.valueOf(stDatePart[2]); return LocalDate.of(year, month, dayOfMonth); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Date", val, column.getType())); + throw factory.createParsingException( + String.format( + "value '%s' (%s) cannot be decoded as Date", val, column.getDataType())); } } if (parts == null) return null; @@ -114,13 +123,17 @@ public LocalDate decodeText( @Override public LocalDate decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int year = 0; int month = 1; int dayOfMonth = 1; - switch (column.getType()) { + switch (column.getDataType()) { case TIMESTAMP: case DATETIME: if (length > 0) { @@ -162,8 +175,9 @@ public LocalDate decodeBinary( String val = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); String[] stDatePart = val.split("-| "); if (stDatePart.length < 3) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Date", val, column.getType())); + throw factory.createParsingException( + String.format( + "value '%s' (%s) cannot be decoded as Date", val, column.getDataType())); } try { @@ -172,27 +186,42 @@ public LocalDate decodeBinary( dayOfMonth = Integer.valueOf(stDatePart[2]); return LocalDate.of(year, month, dayOfMonth); } catch (NumberFormatException nfe) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' (%s) cannot be decoded as Date", val, column.getType())); + throw factory.createParsingException( + String.format( + "value '%s' (%s) cannot be decoded as Date", val, column.getDataType())); } } } @Override - public void encodeText(ByteBuf buf, Context context, LocalDate value) { - buf.writeByte('\''); - buf.writeCharSequence( - value.format(DateTimeFormatter.ISO_LOCAL_DATE), StandardCharsets.US_ASCII); - buf.writeByte('\''); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(); + buf.writeByte('\''); + buf.writeCharSequence( + ((LocalDate) value).format(DateTimeFormatter.ISO_LOCAL_DATE), + StandardCharsets.US_ASCII); + buf.writeByte('\''); + return buf; + }); } @Override - public void encodeBinary(ByteBuf buf, Context context, LocalDate value) { - buf.writeByte(7); // length - buf.writeShortLE((short) value.get(ChronoField.YEAR)); - buf.writeByte(value.get(ChronoField.MONTH_OF_YEAR)); - buf.writeByte(value.get(ChronoField.DAY_OF_MONTH)); - buf.writeBytes(new byte[] {0, 0, 0}); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + LocalDate val = (LocalDate) value; + ByteBuf buf = allocator.buffer(8, 8); + buf.writeByte(7); // length + buf.writeShortLE((short) val.get(ChronoField.YEAR)); + buf.writeByte(val.get(ChronoField.MONTH_OF_YEAR)); + buf.writeByte(val.get(ChronoField.DAY_OF_MONTH)); + buf.writeBytes(new byte[] {0, 0, 0}); + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateTimeCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateTimeCodec.java index 387a403f..3fdf4f11 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateTimeCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/LocalDateTimeCodec.java @@ -1,10 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; import java.time.DateTimeException; import java.time.LocalDateTime; @@ -12,10 +12,12 @@ import java.time.format.DateTimeFormatterBuilder; import java.time.temporal.ChronoField; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; public class LocalDateTimeCodec implements Codec { @@ -31,7 +33,7 @@ public class LocalDateTimeCodec implements Codec { DataType.DATETIME, DataType.TIMESTAMP, DataType.VARSTRING, - DataType.VARCHAR, + DataType.TEXT, DataType.STRING, DataType.TIME, DataType.DATE); @@ -84,7 +86,7 @@ public static int[] parseTimestamp(String raw) { } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && type.isAssignableFrom(LocalDateTime.class); } @@ -94,10 +96,14 @@ public boolean canEncode(Class value) { @Override public LocalDateTime decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int[] parts; - switch (column.getType()) { + switch (column.getDataType()) { case DATE: parts = LocalDateCodec.parseDate(buf, length); if (parts == null) return null; @@ -111,7 +117,7 @@ public LocalDateTime decodeText( .plusNanos(parts[6]); case TIME: - parts = LocalTimeCodec.parseTime(buf, length, column); + parts = LocalTimeCodec.parseTime(buf, length, column, factory); return LocalDateTime.of(1970, 1, 1, parts[1] % 24, parts[2], parts[3]).plusNanos(parts[4]); default: @@ -123,16 +129,20 @@ public LocalDateTime decodeText( return LocalDateTime.of(parts[0], parts[1], parts[2], parts[3], parts[4], parts[5]) .plusNanos(parts[6]); } catch (DateTimeException dte) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format( - "value '%s' (%s) cannot be decoded as LocalDateTime", val, column.getType())); + "value '%s' (%s) cannot be decoded as LocalDateTime", val, column.getDataType())); } } } @Override public LocalDateTime decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int year = 1970; int month = 1; @@ -142,7 +152,7 @@ public LocalDateTime decodeBinary( int seconds = 0; long microseconds = 0; - switch (column.getType()) { + switch (column.getDataType()) { case TIME: // specific case for TIME, to handle value not in 00:00:00-23:59:59 if (length > 0) { @@ -185,9 +195,9 @@ public LocalDateTime decodeBinary( return LocalDateTime.of(parts[0], parts[1], parts[2], parts[3], parts[4], parts[5]) .plusNanos(parts[6]); } catch (DateTimeException dte) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format( - "value '%s' (%s) cannot be decoded as LocalDateTime", val, column.getType())); + "value '%s' (%s) cannot be decoded as LocalDateTime", val, column.getDataType())); } } @@ -196,37 +206,51 @@ public LocalDateTime decodeBinary( } @Override - public void encodeText(ByteBuf buf, Context context, LocalDateTime val) { - - buf.writeByte('\''); - buf.writeCharSequence( - val.format(val.getNano() != 0 ? TIMESTAMP_FORMAT : TIMESTAMP_FORMAT_NO_FRACTIONAL), - StandardCharsets.US_ASCII); - buf.writeByte('\''); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + LocalDateTime val = (LocalDateTime) value; + ByteBuf buf = allocator.buffer(); + buf.writeByte('\''); + buf.writeCharSequence( + val.format(val.getNano() != 0 ? TIMESTAMP_FORMAT : TIMESTAMP_FORMAT_NO_FRACTIONAL), + StandardCharsets.US_ASCII); + buf.writeByte('\''); + return buf; + }); } @Override - public void encodeBinary(ByteBuf buf, Context context, LocalDateTime value) { - - int nano = value.getNano(); - if (nano > 0) { - buf.writeByte((byte) 11); - buf.writeShortLE((short) value.get(ChronoField.YEAR)); - buf.writeByte(value.get(ChronoField.MONTH_OF_YEAR)); - buf.writeByte(value.get(ChronoField.DAY_OF_MONTH)); - buf.writeByte(value.get(ChronoField.HOUR_OF_DAY)); - buf.writeByte(value.get(ChronoField.MINUTE_OF_HOUR)); - buf.writeByte(value.get(ChronoField.SECOND_OF_MINUTE)); - buf.writeIntLE(nano / 1000); - } else { - buf.writeByte((byte) 7); - buf.writeShortLE((short) value.get(ChronoField.YEAR)); - buf.writeByte(value.get(ChronoField.MONTH_OF_YEAR)); - buf.writeByte(value.get(ChronoField.DAY_OF_MONTH)); - buf.writeByte(value.get(ChronoField.HOUR_OF_DAY)); - buf.writeByte(value.get(ChronoField.MINUTE_OF_HOUR)); - buf.writeByte(value.get(ChronoField.SECOND_OF_MINUTE)); - } + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf; + LocalDateTime val = (LocalDateTime) value; + int nano = val.getNano(); + if (nano > 0) { + buf = allocator.buffer(12, 12); + buf.writeByte((byte) 11); + buf.writeShortLE((short) val.get(ChronoField.YEAR)); + buf.writeByte(val.get(ChronoField.MONTH_OF_YEAR)); + buf.writeByte(val.get(ChronoField.DAY_OF_MONTH)); + buf.writeByte(val.get(ChronoField.HOUR_OF_DAY)); + buf.writeByte(val.get(ChronoField.MINUTE_OF_HOUR)); + buf.writeByte(val.get(ChronoField.SECOND_OF_MINUTE)); + buf.writeIntLE(nano / 1000); + } else { + buf = allocator.buffer(8, 8); + buf.writeByte((byte) 7); + buf.writeShortLE((short) val.get(ChronoField.YEAR)); + buf.writeByte(val.get(ChronoField.MONTH_OF_YEAR)); + buf.writeByte(val.get(ChronoField.DAY_OF_MONTH)); + buf.writeByte(val.get(ChronoField.HOUR_OF_DAY)); + buf.writeByte(val.get(ChronoField.MINUTE_OF_HOUR)); + buf.writeByte(val.get(ChronoField.SECOND_OF_MINUTE)); + } + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/LocalTimeCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/LocalTimeCodec.java index 42194855..9c698337 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/LocalTimeCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/LocalTimeCodec.java @@ -1,20 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; import java.time.LocalDateTime; import java.time.LocalTime; import java.time.format.DateTimeParseException; import java.time.temporal.ChronoField; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; public class LocalTimeCodec implements Codec { @@ -26,10 +28,11 @@ public class LocalTimeCodec implements Codec { DataType.DATETIME, DataType.TIMESTAMP, DataType.VARSTRING, - DataType.VARCHAR, + DataType.TEXT, DataType.STRING); - public static int[] parseTime(ByteBuf buf, int length, ColumnDefinitionPacket column) { + public static int[] parseTime( + ByteBuf buf, int length, ColumnDefinitionPacket column, ExceptionFactory factory) { int initialPos = buf.readerIndex(); int[] parts = new int[5]; int idx = 1; @@ -52,8 +55,8 @@ public static int[] parseTime(ByteBuf buf, int length, ColumnDefinitionPacket co if (b < '0' || b > '9') { buf.readerIndex(initialPos); String val = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); - throw new R2dbcNonTransientResourceException( - String.format("%s value '%s' cannot be decoded as Time", column.getType(), val)); + throw factory.createParsingException( + String.format("%s value '%s' cannot be decoded as Time", column.getDataType(), val)); } partLength++; parts[idx] = parts[idx] * 10 + (b - '0'); @@ -62,8 +65,8 @@ public static int[] parseTime(ByteBuf buf, int length, ColumnDefinitionPacket co if (idx < 2) { buf.readerIndex(initialPos); String val = buf.readCharSequence(length, StandardCharsets.UTF_8).toString(); - throw new R2dbcNonTransientResourceException( - String.format("%s value '%s' cannot be decoded as Time", column.getType(), val)); + throw factory.createParsingException( + String.format("%s value '%s' cannot be decoded as Time", column.getDataType(), val)); } // set nano real value @@ -76,7 +79,8 @@ public static int[] parseTime(ByteBuf buf, int length, ColumnDefinitionPacket co } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(LocalTime.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) + && type.isAssignableFrom(LocalTime.class); } public boolean canEncode(Class value) { @@ -85,10 +89,14 @@ public boolean canEncode(Class value) { @Override public LocalTime decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int[] parts; - switch (column.getType()) { + switch (column.getDataType()) { case TIMESTAMP: case DATETIME: parts = @@ -98,7 +106,7 @@ public LocalTime decodeText( return LocalTime.of(parts[3], parts[4], parts[5], parts[6]); case TIME: - parts = parseTime(buf, length, column); + parts = parseTime(buf, length, column, factory); parts[1] = parts[1] % 24; if (parts[0] == 1) { // negative @@ -119,22 +127,26 @@ public LocalTime decodeText( return LocalTime.parse(val); } } catch (DateTimeParseException e) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format( - "value '%s' (%s) cannot be decoded as LocalTime", val, column.getType())); + "value '%s' (%s) cannot be decoded as LocalTime", val, column.getDataType())); } } } @Override public LocalTime decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { int hour = 0; int minutes = 0; int seconds = 0; long microseconds = 0; - switch (column.getType()) { + switch (column.getDataType()) { case TIMESTAMP: case DATETIME: if (length > 0) { @@ -182,59 +194,74 @@ public LocalTime decodeBinary( return LocalTime.parse(val); } } catch (DateTimeParseException e) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format( - "value '%s' (%s) cannot be decoded as LocalTime", val, column.getType())); + "value '%s' (%s) cannot be decoded as LocalTime", val, column.getDataType())); } } } @Override - public void encodeText(ByteBuf buf, Context context, LocalTime val) { - - StringBuilder dateString = new StringBuilder(15); - dateString - .append(val.getHour() < 10 ? "0" : "") - .append(val.getHour()) - .append(val.getMinute() < 10 ? ":0" : ":") - .append(val.getMinute()) - .append(val.getSecond() < 10 ? ":0" : ":") - .append(val.getSecond()); - - int microseconds = val.getNano() / 1000; - if (microseconds > 0) { - dateString.append("."); - if (microseconds % 1000 == 0) { - dateString.append(Integer.toString(microseconds / 1000 + 1000).substring(1)); - } else { - dateString.append(Integer.toString(microseconds + 1000000).substring(1)); - } - } + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + LocalTime val = (LocalTime) value; + ByteBuf buf = allocator.buffer(); + StringBuilder dateString = new StringBuilder(15); + dateString + .append(val.getHour() < 10 ? "0" : "") + .append(val.getHour()) + .append(val.getMinute() < 10 ? ":0" : ":") + .append(val.getMinute()) + .append(val.getSecond() < 10 ? ":0" : ":") + .append(val.getSecond()); + + int microseconds = val.getNano() / 1000; + if (microseconds > 0) { + dateString.append("."); + if (microseconds % 1000 == 0) { + dateString.append(Integer.toString(microseconds / 1000 + 1000).substring(1)); + } else { + dateString.append(Integer.toString(microseconds + 1000000).substring(1)); + } + } - buf.writeByte('\''); - buf.writeCharSequence(dateString.toString(), StandardCharsets.US_ASCII); - buf.writeByte('\''); + buf.writeByte('\''); + buf.writeCharSequence(dateString.toString(), StandardCharsets.US_ASCII); + buf.writeByte('\''); + return buf; + }); } @Override - public void encodeBinary(ByteBuf buf, Context context, LocalTime value) { - int nano = value.getNano(); - if (nano > 0) { - buf.writeByte((byte) 12); - buf.writeByte((byte) 0); - buf.writeIntLE(0); - buf.writeByte((byte) value.get(ChronoField.HOUR_OF_DAY)); - buf.writeByte((byte) value.get(ChronoField.MINUTE_OF_HOUR)); - buf.writeByte((byte) value.get(ChronoField.SECOND_OF_MINUTE)); - buf.writeIntLE(nano / 1000); - } else { - buf.writeByte((byte) 8); - buf.writeByte((byte) 0); - buf.writeIntLE(0); - buf.writeByte((byte) value.get(ChronoField.HOUR_OF_DAY)); - buf.writeByte((byte) value.get(ChronoField.MINUTE_OF_HOUR)); - buf.writeByte((byte) value.get(ChronoField.SECOND_OF_MINUTE)); - } + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf; + LocalTime val = (LocalTime) value; + int nano = val.getNano(); + if (nano > 0) { + buf = allocator.buffer(13, 13); + buf.writeByte((byte) 12); + buf.writeByte((byte) 0); + buf.writeIntLE(0); + buf.writeByte((byte) val.get(ChronoField.HOUR_OF_DAY)); + buf.writeByte((byte) val.get(ChronoField.MINUTE_OF_HOUR)); + buf.writeByte((byte) val.get(ChronoField.SECOND_OF_MINUTE)); + buf.writeIntLE(nano / 1000); + } else { + buf = allocator.buffer(9, 9); + buf.writeByte((byte) 8); + buf.writeByte((byte) 0); + buf.writeIntLE(0); + buf.writeByte((byte) val.get(ChronoField.HOUR_OF_DAY)); + buf.writeByte((byte) val.get(ChronoField.MINUTE_OF_HOUR)); + buf.writeByte((byte) val.get(ChronoField.SECOND_OF_MINUTE)); + } + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/LongCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/LongCodec.java index f36597a5..7deed712 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/LongCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/LongCodec.java @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class LongCodec implements Codec { @@ -25,7 +27,7 @@ public class LongCodec implements Codec { DataType.FLOAT, DataType.DOUBLE, DataType.OLDDECIMAL, - DataType.VARCHAR, + DataType.TEXT, DataType.DECIMAL, DataType.ENUM, DataType.VARSTRING, @@ -57,7 +59,7 @@ public static long parse(ByteBuf buf, int length) { } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Long.TYPE) || type.isAssignableFrom(Long.class)); } @@ -67,9 +69,13 @@ public boolean canEncode(Class value) { @Override public Long decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case DECIMAL: case OLDDECIMAL: case DOUBLE: @@ -78,7 +84,7 @@ public Long decodeText( try { return new BigDecimal(str1).setScale(0, RoundingMode.DOWN).longValueExact(); } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Long", str1)); } @@ -99,8 +105,8 @@ public Long decodeText( try { return val.longValueExact(); } catch (ArithmeticException ae) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' cannot be decoded as Long", val.toString())); + throw factory.createParsingException( + String.format("value '%s' cannot be decoded as Long", val)); } } @@ -118,7 +124,7 @@ public Long decodeText( try { return new BigInteger(str).longValueExact(); } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Long", str)); } } @@ -128,9 +134,13 @@ public Long decodeText( @Override public Long decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { - switch (column.getType()) { + switch (column.getDataType()) { case BIGINT: if (column.isSigned()) { return buf.readLongLE(); @@ -144,8 +154,8 @@ public Long decodeBinary( try { return val.longValueExact(); } catch (ArithmeticException ae) { - throw new R2dbcNonTransientResourceException( - String.format("value '%s' cannot be decoded as Long", val.toString())); + throw factory.createParsingException( + String.format("value '%s' cannot be decoded as Long", val)); } } @@ -194,20 +204,27 @@ public Long decodeBinary( try { return new BigDecimal(str).setScale(0, RoundingMode.DOWN).longValueExact(); } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Long", str)); } } } @Override - public void encodeText(ByteBuf buf, Context context, Long value) { - BufferUtils.writeAscii(buf, String.valueOf(value)); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeAscii(allocator, String.valueOf(value))); } @Override - public void encodeBinary(ByteBuf buf, Context context, Long value) { - buf.writeLongLE(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(8, 8); + buf.writeLongLE((Long) value); + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/ShortCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/ShortCodec.java index 1de2788e..6ca77cb1 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/ShortCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/ShortCodec.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.ByteBufAllocator; import java.math.BigDecimal; import java.math.RoundingMode; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class ShortCodec implements Codec { @@ -24,7 +26,7 @@ public class ShortCodec implements Codec { DataType.FLOAT, DataType.DOUBLE, DataType.OLDDECIMAL, - DataType.VARCHAR, + DataType.TEXT, DataType.DECIMAL, DataType.ENUM, DataType.VARSTRING, @@ -38,7 +40,7 @@ public class ShortCodec implements Codec { DataType.YEAR); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) + return COMPATIBLE_TYPES.contains(column.getDataType()) && ((type.isPrimitive() && type == Short.TYPE) || type.isAssignableFrom(Short.class)); } @@ -48,9 +50,13 @@ public boolean canEncode(Class value) { @Override public Short decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case TINYINT: case SMALLINT: case MEDIUMINT: @@ -75,13 +81,13 @@ public Short decodeText( result = new BigDecimal(str).setScale(0, RoundingMode.DOWN).longValueExact(); break; } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Short", str)); } } if ((short) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException("Short overflow"); + throw factory.createParsingException("Short overflow"); } return (short) result; @@ -89,10 +95,14 @@ public Short decodeText( @Override public Short decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { long result; - switch (column.getType()) { + switch (column.getDataType()) { case TINYINT: result = column.isSigned() ? buf.readByte() : buf.readUnsignedByte(); break; @@ -113,7 +123,7 @@ public Short decodeBinary( case BIGINT: result = buf.readLongLE(); if (result < 0 & !column.isSigned()) { - throw new R2dbcNonTransientResourceException("Short overflow"); + throw factory.createParsingException("Short overflow"); } break; @@ -140,26 +150,34 @@ public Short decodeBinary( result = new BigDecimal(str).setScale(0, RoundingMode.DOWN).longValueExact(); break; } catch (NumberFormatException | ArithmeticException nfe) { - throw new R2dbcNonTransientResourceException( + throw factory.createParsingException( String.format("value '%s' cannot be decoded as Short", str)); } } if ((short) result != result || (result < 0 && !column.isSigned())) { - throw new R2dbcNonTransientResourceException("Short overflow"); + throw factory.createParsingException("Short overflow"); } return (short) result; } @Override - public void encodeText(ByteBuf buf, Context context, Short value) { - BufferUtils.writeAscii(buf, String.valueOf(value)); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> BufferUtils.encodeAscii(allocator, Integer.toString((Short) value))); } @Override - public void encodeBinary(ByteBuf buf, Context context, Short value) { - buf.writeShortLE(value); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(2, 2); + buf.writeShortLE((Short) value); + return buf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/StreamCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/StreamCodec.java index 36c66e66..fe485d20 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/StreamCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/StreamCodec.java @@ -1,19 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufInputStream; -import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.netty.buffer.*; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class StreamCodec implements Codec { @@ -26,24 +26,33 @@ public class StreamCodec implements Codec { DataType.TINYBLOB, DataType.MEDIUMBLOB, DataType.LONGBLOB, - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING); public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(InputStream.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) + && type.isAssignableFrom(InputStream.class); } @Override public InputStream decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { // STRING, VARCHAR, VARSTRING, BLOB, TINYBLOB, MEDIUMBLOB, LONGBLOB: return new ByteBufInputStream(buf.readRetainedSlice(length), true); } @Override public InputStream decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { return new ByteBufInputStream(buf.readRetainedSlice(length), true); } @@ -52,43 +61,48 @@ public boolean canEncode(Class value) { } @Override - public void encodeText(ByteBuf buf, Context context, InputStream is) { - try { - buf.writeBytes("_binary '".getBytes(StandardCharsets.US_ASCII)); - byte[] array = new byte[4096]; - int len; - while ((len = is.read(array)) > 0) { - BufferUtils.writeEscaped(buf, array, 0, len, context); - } - buf.writeByte('\''); - } catch (IOException ioe) { - throw new R2dbcNonTransientResourceException("Failed to read InputStream", ioe); - } + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf buf = allocator.buffer(); + try { + buf.writeBytes("_binary '".getBytes(StandardCharsets.US_ASCII)); + byte[] array = new byte[4096]; + int len; + while ((len = ((InputStream) value).read(array)) > 0) { + BufferUtils.escapedBytes(buf, array, len, context); + } + buf.writeByte('\''); + } catch (IOException ioe) { + throw factory.createParsingException("Failed to read InputStream", ioe); + } + return buf; + }); } @Override - public void encodeBinary(ByteBuf buf, Context context, InputStream value) { - - // reserve place for length - buf.writeByte(0xfe); - int initialPos = buf.writerIndex(); - buf.writerIndex(buf.writerIndex() + 8); - - byte[] array = new byte[4096]; - int len; - try { - while ((len = value.read(array)) > 0) { - buf.writeBytes(array, 0, len); - } - } catch (IOException ioe) { - throw new R2dbcNonTransientResourceException("Failed to read InputStream", ioe); - } - - // Write length - int endPos = buf.writerIndex(); - buf.writerIndex(initialPos); - buf.writeLongLE(endPos - (initialPos + 8)); - buf.writerIndex(endPos); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue( + () -> { + ByteBuf val = allocator.buffer(); + try { + byte[] array = new byte[4096]; + int len; + while ((len = ((InputStream) value).read(array)) > 0) { + val.writeBytes(array, 0, len); + } + } catch (IOException ioe) { + throw factory.createParsingException("Failed to read InputStream", ioe); + } + CompositeByteBuf compositeByteBuf = allocator.compositeBuffer(); + ByteBuf buf = Unpooled.wrappedBuffer(val); + compositeByteBuf.addComponent( + true, Unpooled.wrappedBuffer(BufferUtils.encodeLength(buf.readableBytes()))); + compositeByteBuf.addComponent(true, buf); + return compositeByteBuf; + }); } public DataType getBinaryEncodeType() { diff --git a/src/main/java/org/mariadb/r2dbc/codec/list/StringCodec.java b/src/main/java/org/mariadb/r2dbc/codec/list/StringCodec.java index be4fbc2b..539c07fd 100644 --- a/src/main/java/org/mariadb/r2dbc/codec/list/StringCodec.java +++ b/src/main/java/org/mariadb/r2dbc/codec/list/StringCodec.java @@ -1,18 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.codec.list; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.time.LocalDate; import java.time.LocalDateTime; import java.util.EnumSet; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.BindValue; import org.mariadb.r2dbc.util.BufferUtils; public class StringCodec implements Codec { @@ -40,7 +43,7 @@ public class StringCodec implements Codec { DataType.DECIMAL, DataType.ENUM, DataType.SET, - DataType.VARCHAR, + DataType.TEXT, DataType.VARSTRING, DataType.STRING); @@ -54,7 +57,7 @@ public static String zeroFilling(String value, ColumnDefinitionPacket col) { } public boolean canDecode(ColumnDefinitionPacket column, Class type) { - return COMPATIBLE_TYPES.contains(column.getType()) && type.isAssignableFrom(String.class); + return COMPATIBLE_TYPES.contains(column.getDataType()) && type.isAssignableFrom(String.class); } public boolean canEncode(Class value) { @@ -63,8 +66,12 @@ public boolean canEncode(Class value) { @Override public String decodeText( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { - if (column.getType() == DataType.BIT) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { + if (column.getDataType() == DataType.BIT) { byte[] bytes = new byte[length]; buf.readBytes(bytes); @@ -89,9 +96,13 @@ public String decodeText( @Override public String decodeBinary( - ByteBuf buf, int length, ColumnDefinitionPacket column, Class type) { + ByteBuf buf, + int length, + ColumnDefinitionPacket column, + Class type, + ExceptionFactory factory) { String rawValue; - switch (column.getType()) { + switch (column.getDataType()) { case BIT: byte[] bytes = new byte[length]; buf.readBytes(bytes); @@ -133,6 +144,8 @@ public String decodeBinary( case MEDIUMINT: rawValue = String.valueOf(column.isSigned() ? buf.readMediumLE() : buf.readUnsignedMediumLE()); + // medium int is encoded on 3 bytes + one empty byte + buf.skipBytes(1); if (column.isZeroFill()) { return zeroFilling(rawValue, column); } @@ -247,18 +260,24 @@ public String decodeBinary( } @Override - public void encodeText(ByteBuf buf, Context context, String value) { - BufferUtils.write(buf, value, true, context); + public BindValue encodeText( + ByteBufAllocator allocator, Object value, Context context, ExceptionFactory factory) { + return createEncodedValue( + () -> + BufferUtils.encodeEscapedBytes( + allocator, + BufferUtils.STRING_PREFIX, + ((String) value).getBytes(StandardCharsets.UTF_8), + context)); } @Override - public void encodeBinary(ByteBuf buf, Context context, String value) { - byte[] b = value.getBytes(StandardCharsets.UTF_8); - BufferUtils.writeLengthEncode(b.length, buf); - buf.writeBytes(b); + public BindValue encodeBinary( + ByteBufAllocator allocator, Object value, ExceptionFactory factory) { + return createEncodedValue(() -> BufferUtils.encodeLengthUtf8(allocator, (String) value)); } public DataType getBinaryEncodeType() { - return DataType.VARSTRING; + return DataType.TEXT; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/AuthMoreData.java b/src/main/java/org/mariadb/r2dbc/message/AuthMoreData.java new file mode 100644 index 00000000..8b3e2c33 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/message/AuthMoreData.java @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.message; + +import io.netty.buffer.ByteBuf; + +public interface AuthMoreData { + + MessageSequence getSequencer(); + + ByteBuf getBuf(); +} diff --git a/src/main/java/org/mariadb/r2dbc/message/AuthSwitch.java b/src/main/java/org/mariadb/r2dbc/message/AuthSwitch.java new file mode 100644 index 00000000..042b6135 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/message/AuthSwitch.java @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.message; + +public interface AuthSwitch { + + String getPlugin(); + + byte[] getSeed(); + + MessageSequence getSequencer(); +} diff --git a/src/main/java/org/mariadb/r2dbc/message/client/ClientMessage.java b/src/main/java/org/mariadb/r2dbc/message/ClientMessage.java similarity index 54% rename from src/main/java/org/mariadb/r2dbc/message/client/ClientMessage.java rename to src/main/java/org/mariadb/r2dbc/message/ClientMessage.java index e5a3201a..febce47f 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/ClientMessage.java +++ b/src/main/java/org/mariadb/r2dbc/message/ClientMessage.java @@ -1,17 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.client; +package org.mariadb.r2dbc.message; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; import org.mariadb.r2dbc.message.server.Sequencer; public interface ClientMessage { - default Sequencer getSequencer() { + default MessageSequence getSequencer() { return new Sequencer((byte) 0xff); } + default void releaseEncodedBinds() {} + ByteBuf encode(Context context, ByteBufAllocator byteBufAllocator); + + default void save(ByteBuf buf, int initialReaderIndex) {} + + default void resetSequencer() {} } diff --git a/src/main/java/org/mariadb/r2dbc/message/Context.java b/src/main/java/org/mariadb/r2dbc/message/Context.java new file mode 100644 index 00000000..06b68d4b --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/message/Context.java @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.message; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.r2dbc.spi.IsolationLevel; +import org.mariadb.r2dbc.client.ServerVersion; + +public interface Context { + + long getThreadId(); + + long getServerCapabilities(); + + long getClientCapabilities(); + + short getServerStatus(); + + void setServerStatus(short serverStatus); + + IsolationLevel getIsolationLevel(); + + void setIsolationLevel(IsolationLevel isolationLevel); + + String getDatabase(); + + void setDatabase(String database); + + ServerVersion getVersion(); + + ByteBufAllocator getByteBufAllocator(); + + default void saveRedo(ClientMessage msg, ByteBuf buf, int initialReaderIndex) {} +} diff --git a/src/main/java/org/mariadb/r2dbc/message/MessageSequence.java b/src/main/java/org/mariadb/r2dbc/message/MessageSequence.java new file mode 100644 index 00000000..603f0a55 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/message/MessageSequence.java @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.message; + +public interface MessageSequence { + byte next(); + + void reset(); +} diff --git a/src/main/java/org/mariadb/r2dbc/message/Protocol.java b/src/main/java/org/mariadb/r2dbc/message/Protocol.java new file mode 100644 index 00000000..796674ba --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/message/Protocol.java @@ -0,0 +1,6 @@ +package org.mariadb.r2dbc.message; + +public enum Protocol { + BINARY, + TEXT +} diff --git a/src/main/java/org/mariadb/r2dbc/message/server/ServerMessage.java b/src/main/java/org/mariadb/r2dbc/message/ServerMessage.java similarity index 55% rename from src/main/java/org/mariadb/r2dbc/message/server/ServerMessage.java rename to src/main/java/org/mariadb/r2dbc/message/ServerMessage.java index 1586a6b6..297aa416 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/ServerMessage.java +++ b/src/main/java/org/mariadb/r2dbc/message/ServerMessage.java @@ -1,12 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab -package org.mariadb.r2dbc.message.server; +package org.mariadb.r2dbc.message; public interface ServerMessage { - default Sequencer getSequencer() { - return null; - } default boolean ending() { return false; @@ -15,4 +12,6 @@ default boolean ending() { default boolean resultSetEnd() { return false; } + + default void release() {} } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/AuthMoreRawPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/AuthMoreRawPacket.java index 21fbdf80..65f2c584 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/AuthMoreRawPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/AuthMoreRawPacket.java @@ -1,19 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class AuthMoreRawPacket implements ClientMessage { - private byte[] raw; - private Sequencer sequencer; + private final byte[] raw; + private final MessageSequence sequencer; - public AuthMoreRawPacket(Sequencer sequencer, byte[] raw) { + public AuthMoreRawPacket(MessageSequence sequencer, byte[] raw) { this.sequencer = sequencer; this.raw = raw; } @@ -26,7 +27,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/ChangeSchemaPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/ChangeSchemaPacket.java new file mode 100644 index 00000000..576d5972 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/message/client/ChangeSchemaPacket.java @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.message.client; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.nio.charset.StandardCharsets; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; + +/** + * see COM_INIT_DB https://mariadb.com/kb/en/com_init_db/ COM_INIT_DB is used to specify the default + * schema for the connection. + */ +public final class ChangeSchemaPacket implements ClientMessage { + private final String schema; + + /** + * Constructor + * + * @param schema new default schema + */ + public ChangeSchemaPacket(String schema) { + this.schema = schema; + } + + @Override + public ByteBuf encode(Context context, ByteBufAllocator allocator) { + ByteBuf buf = allocator.ioBuffer(); + buf.writeByte(0x02); + buf.writeCharSequence(this.schema, StandardCharsets.UTF_8); + return buf; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/message/client/ClearPasswordPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/ClearPasswordPacket.java index f9203d12..baf4dc40 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/ClearPasswordPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/ClearPasswordPacket.java @@ -1,20 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class ClearPasswordPacket implements ClientMessage { - private CharSequence password; - private Sequencer sequencer; + private final CharSequence password; + private final MessageSequence sequencer; - public ClearPasswordPacket(Sequencer sequencer, CharSequence password) { + public ClearPasswordPacket(MessageSequence sequencer, CharSequence password) { this.sequencer = sequencer; this.password = password; } @@ -29,7 +30,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/ClosePreparePacket.java b/src/main/java/org/mariadb/r2dbc/message/client/ClosePreparePacket.java index 3fba0344..fbf6351b 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/ClosePreparePacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/ClosePreparePacket.java @@ -1,11 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; /** * COM_STMT_CLOSE packet. See diff --git a/src/main/java/org/mariadb/r2dbc/message/client/Ed25519PasswordPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/Ed25519PasswordPacket.java index 0e31b73d..60ed05ea 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/Ed25519PasswordPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/Ed25519PasswordPacket.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; @@ -10,20 +10,21 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Arrays; -import org.mariadb.r2dbc.authentication.ed25519.math.GroupElement; -import org.mariadb.r2dbc.authentication.ed25519.math.ed25519.ScalarOps; -import org.mariadb.r2dbc.authentication.ed25519.spec.EdDSANamedCurveTable; -import org.mariadb.r2dbc.authentication.ed25519.spec.EdDSAParameterSpec; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.GroupElement; +import org.mariadb.r2dbc.authentication.standard.ed25519.math.ed25519.ScalarOps; +import org.mariadb.r2dbc.authentication.standard.ed25519.spec.EdDSANamedCurveTable; +import org.mariadb.r2dbc.authentication.standard.ed25519.spec.EdDSAParameterSpec; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class Ed25519PasswordPacket implements ClientMessage { - private Sequencer sequencer; - private CharSequence password; - private byte[] seed; + private final MessageSequence sequencer; + private final CharSequence password; + private final byte[] seed; - public Ed25519PasswordPacket(Sequencer sequencer, CharSequence password, byte[] seed) { + public Ed25519PasswordPacket(MessageSequence sequencer, CharSequence password, byte[] seed) { this.sequencer = sequencer; this.password = password; this.seed = seed; @@ -86,7 +87,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/ExecutePacket.java b/src/main/java/org/mariadb/r2dbc/message/client/ExecutePacket.java index 24f1dece..2d89c321 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/ExecutePacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/ExecutePacket.java @@ -1,76 +1,128 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import java.util.Map; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.codec.DataType; -import org.mariadb.r2dbc.codec.Parameter; +import java.util.List; +import org.mariadb.r2dbc.ExceptionFactory; +import org.mariadb.r2dbc.client.Client; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.util.BindEncodedValue; +import org.mariadb.r2dbc.util.ServerPrepareResult; +import reactor.core.publisher.Mono; public final class ExecutePacket implements ClientMessage { - private final Map> parameters; - private final int statementId; - private final Sequencer sequencer = new Sequencer((byte) 0xff); + private final List bindValues; + private int statementId; + private final int parameterCount; + private final String sql; + private final MessageSequence sequencer = new Sequencer((byte) 0xff); + private ByteBuf savedBuf = null; - public ExecutePacket(int statementId, Map> parameters) { - this.parameters = parameters; - this.statementId = statementId; - } - - public Sequencer getSequencer() { - return sequencer; + public ExecutePacket( + String sql, ServerPrepareResult prepareResult, List bindValues) { + this.sql = sql; + this.bindValues = bindValues; + this.statementId = prepareResult == null ? -1 : prepareResult.getStatementId(); + this.parameterCount = prepareResult == null ? bindValues.size() : prepareResult.getNumParams(); } @Override public ByteBuf encode(Context context, ByteBufAllocator allocator) { + if (savedBuf != null) return savedBuf; ByteBuf buf = allocator.ioBuffer(); buf.writeByte(0x17); buf.writeIntLE(statementId); buf.writeByte(0x00); // NO CURSOR buf.writeIntLE(1); // Iteration pos - Integer[] keys = parameters.keySet().toArray(new Integer[0]); - int parameterCount = 0; - for (Integer i : keys) { - if (i + 1 > parameterCount) parameterCount = i + 1; - } - // create null bitmap if (parameterCount > 0) { int nullCount = (parameterCount + 7) / 8; byte[] nullBitsBuffer = new byte[nullCount]; for (int i = 0; i < parameterCount; i++) { - Parameter p = parameters.get(i); - if (p == null || p.isNull()) { + if (bindValues.get(i).getValue() == null) { nullBitsBuffer[i / 8] |= (1 << (i % 8)); } } buf.writeBytes(nullBitsBuffer); buf.writeByte(0x01); // Send Parameter type flag - // Store types of parameters in first in first package that is sent to the server. + // Store types of parameters in first package that is sent to the server. for (int i = 0; i < parameterCount; i++) { - Parameter p = parameters.get(i); - if (p == null) { - buf.writeShortLE(DataType.VARCHAR.get()); - } else { - buf.writeShortLE(p.getBinaryEncodeType().get()); - } + buf.writeShortLE(bindValues.get(i).getCodec().getBinaryEncodeType().get()); } } - // TODO avoid to send long data here. for (int i = 0; i < parameterCount; i++) { - Parameter p = parameters.get(i); - if (p != null && !p.isNull()) { - p.encodeBinary(buf, context); + ByteBuf param = bindValues.get(i).getValue(); + if (param != null) { + buf.writeBytes(param); } } + return buf; } + + public Mono rePrepare(Client client) { + ServerPrepareResult res; + if (client.getPrepareCache() != null && (res = client.getPrepareCache().get(sql)) != null) { + this.forceStatementId(res.getStatementId()); + return Mono.just(this); + } + return client + .sendPrepare(new PreparePacket(sql), ExceptionFactory.INSTANCE, sql) + .flatMap( + serverPrepareResult -> { + this.forceStatementId(serverPrepareResult.getStatementId()); + return Mono.just(this); + }); + } + + public void save(ByteBuf buf, int initialReaderIndex) { + savedBuf = buf.readerIndex(initialReaderIndex).retain(); + } + + public void forceStatementId(int statementId) { + this.statementId = statementId; + if (savedBuf != null) { + // replace byte at position 1 with new statement id + int writerIndex = this.savedBuf.writerIndex(); + this.savedBuf.writerIndex(this.savedBuf.readerIndex() + 1); + this.savedBuf.writeIntLE(statementId); + this.savedBuf.writerIndex(writerIndex); + } + } + + public MessageSequence getSequencer() { + return sequencer; + } + + public void resetSequencer() { + sequencer.reset(); + } + + public String getSql() { + return sql; + } + + @Override + public void releaseEncodedBinds() { + bindValues.forEach( + b -> { + if (b.getValue() != null) b.getValue().release(); + }); + bindValues.clear(); + } + + @Override + public String toString() { + return "ExecutePacket{" + "sql='" + sql + '\'' + '}'; + } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/HandshakeResponse.java b/src/main/java/org/mariadb/r2dbc/message/client/HandshakeResponse.java index e50fc30d..fad1b996 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/HandshakeResponse.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/HandshakeResponse.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; @@ -10,26 +10,26 @@ import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Properties; -import java.util.function.Supplier; import org.mariadb.r2dbc.MariadbConnectionFactoryProvider; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.flow.ClearPasswordPluginFlow; -import org.mariadb.r2dbc.message.flow.NativePasswordPluginFlow; +import org.mariadb.r2dbc.authentication.addon.ClearPasswordPluginFlow; +import org.mariadb.r2dbc.authentication.standard.NativePasswordPluginFlow; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.message.server.InitialHandshakePacket; import org.mariadb.r2dbc.message.server.Sequencer; import org.mariadb.r2dbc.util.BufferUtils; -import org.mariadb.r2dbc.util.PidFactory; +import org.mariadb.r2dbc.util.HostAddress; import org.mariadb.r2dbc.util.constants.Capabilities; public final class HandshakeResponse implements ClientMessage { - private InitialHandshakePacket initialHandshakePacket; - private String username; - private CharSequence password; - private String database; - private Map connectionAttributes; - private String host; - private long clientCapabilities; + private final InitialHandshakePacket initialHandshakePacket; + private final String username; + private final CharSequence password; + private final String database; + private final Map connectionAttributes; + private final HostAddress hostAddress; + private final long clientCapabilities; public HandshakeResponse( InitialHandshakePacket initialHandshakePacket, @@ -37,14 +37,14 @@ public HandshakeResponse( CharSequence password, String database, Map connectionAttributes, - String host, + HostAddress hostAddress, long clientCapabilities) { this.initialHandshakePacket = initialHandshakePacket; this.username = username; this.password = password; this.database = database; this.connectionAttributes = connectionAttributes; - this.host = host; + this.hostAddress = hostAddress; this.clientCapabilities = clientCapabilities; } @@ -64,11 +64,7 @@ public static byte decideLanguage(short serverLanguage, int majorVersion, int mi || (serverLanguage >= 224 && serverLanguage <= 247)) { return (byte) serverLanguage; } - if (majorVersion == 5 && minorVersion <= 1) { - // 5.1 version doesn't know 4 bytes utf8 - return (byte) 33; // utf8_general_ci - } - return (byte) 224; // UTF8MB4_UNICODE_CI; + return (byte) ((majorVersion == 5 && minorVersion <= 1) ? 33 : 224); } @Override @@ -80,18 +76,15 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { initialHandshakePacket.getMajorServerVersion(), initialHandshakePacket.getMinorServerVersion()); - ByteBuf buf = allocator.ioBuffer(4096); + ByteBuf buf = allocator.buffer(4096); final byte[] authData; String authenticationPluginType = initialHandshakePacket.getAuthenticationPluginType(); switch (authenticationPluginType) { case ClearPasswordPluginFlow.TYPE: // TODO check that SSL is enable - if (password == null) { - authData = new byte[0]; - } else { - authData = password.toString().getBytes(StandardCharsets.UTF_8); - } + authData = + (password == null) ? new byte[0] : password.toString().getBytes(StandardCharsets.UTF_8); break; default: @@ -107,17 +100,15 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { buf.writeZero(19); // 19 buf.writeIntLE((int) (clientCapabilities >> 32)); // Maria extended flag - if (username != null && !username.isEmpty()) { - buf.writeCharSequence(username, StandardCharsets.UTF_8); - } else { - // to permit SSO - buf.writeCharSequence(System.getProperty("user.name"), StandardCharsets.UTF_8); - } + // to permit SSO + buf.writeCharSequence( + (username != null && !username.isEmpty()) ? username : System.getProperty("user.name"), + StandardCharsets.UTF_8); buf.writeZero(1); if ((initialHandshakePacket.getCapabilities() & Capabilities.PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) { - BufferUtils.writeLengthEncode(authData.length, buf); + buf.writeBytes(BufferUtils.encodeLength(authData.length)); buf.writeBytes(authData); } else if ((initialHandshakePacket.getCapabilities() & Capabilities.SECURE_CONNECTION) != 0) { buf.writeByte((byte) authData.length); @@ -139,8 +130,8 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { if ((initialHandshakePacket.getCapabilities() & Capabilities.CONNECT_ATTRS) != 0) { ByteBuf bufAttributes = allocator.buffer(2048); - writeConnectAttributes(bufAttributes, connectionAttributes, host); - BufferUtils.writeLengthEncode(bufAttributes.writerIndex(), buf); + writeConnectAttributes(bufAttributes, connectionAttributes, hostAddress); + buf.writeBytes(BufferUtils.encodeLength(bufAttributes.writerIndex())); buf.writeBytes(bufAttributes, 0, bufAttributes.writerIndex()); bufAttributes.release(); } @@ -154,7 +145,7 @@ public Sequencer getSequencer() { } private void writeConnectAttributes( - ByteBuf buf, Map connectionAttributes, String host) { + ByteBuf buf, Map connectionAttributes, HostAddress hostAddress) { BufferUtils.writeLengthEncode("_client_name", buf); BufferUtils.writeLengthEncode(MariadbConnectionFactoryProvider.MARIADB_DRIVER, buf); @@ -170,18 +161,11 @@ private void writeConnectAttributes( } BufferUtils.writeLengthEncode("_server_host", buf); - BufferUtils.writeLengthEncode(host != null ? host : "", buf); + BufferUtils.writeLengthEncode(hostAddress != null ? hostAddress.getHost() : "", buf); BufferUtils.writeLengthEncode("_os", buf); BufferUtils.writeLengthEncode(System.getProperty("os.name"), buf); - final Supplier pidRequest = PidFactory.getInstance(); - String pid = pidRequest.get(); - if (pid != null) { - BufferUtils.writeLengthEncode("_pid", buf); - BufferUtils.writeLengthEncode(pid, buf); - } - BufferUtils.writeLengthEncode("_thread", buf); BufferUtils.writeLengthEncode(Long.toString(Thread.currentThread().getId()), buf); diff --git a/src/main/java/org/mariadb/r2dbc/message/client/NativePasswordPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/NativePasswordPacket.java index 89e709e3..27414e47 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/NativePasswordPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/NativePasswordPacket.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; @@ -8,16 +8,17 @@ import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class NativePasswordPacket implements ClientMessage { - private Sequencer sequencer; - private CharSequence password; - private byte[] seed; + private final MessageSequence sequencer; + private final CharSequence password; + private final byte[] seed; - public NativePasswordPacket(Sequencer sequencer, CharSequence password, byte[] seed) { + public NativePasswordPacket(MessageSequence sequencer, CharSequence password, byte[] seed) { this.sequencer = sequencer; this.password = password; byte[] truncatedSeed = new byte[seed.length - 1]; @@ -62,7 +63,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/PingPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/PingPacket.java index 6f819860..6155ecb1 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/PingPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/PingPacket.java @@ -1,11 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; public final class PingPacket implements ClientMessage { @@ -15,4 +16,9 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { buf.writeByte(0x0e); return buf; } + + @Override + public String toString() { + return "PingPacket{}"; + } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/PreparePacket.java b/src/main/java/org/mariadb/r2dbc/message/client/PreparePacket.java index f66336c0..3d2e754e 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/PreparePacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/PreparePacket.java @@ -1,18 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; import org.mariadb.r2dbc.message.server.Sequencer; import org.mariadb.r2dbc.util.Assert; public final class PreparePacket implements ClientMessage { private final String sql; - private final Sequencer sequencer = new Sequencer((byte) 0xff); + private final MessageSequence sequencer = new Sequencer((byte) 0xff); public PreparePacket(String sql) { this.sql = Assert.requireNonNull(sql, "query must not be null"); @@ -22,10 +24,14 @@ public String getSql() { return sql; } - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } + public void resetSequencer() { + sequencer.reset(); + } + @Override public ByteBuf encode(Context context, ByteBufAllocator allocator) { ByteBuf buf = allocator.ioBuffer(this.sql.length() + 1); diff --git a/src/main/java/org/mariadb/r2dbc/message/client/QueryPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/QueryPacket.java index aaebb51c..60028b53 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/QueryPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/QueryPacket.java @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; import org.mariadb.r2dbc.message.server.Sequencer; import org.mariadb.r2dbc.util.Assert; public final class QueryPacket implements ClientMessage { private final String sql; - private final Sequencer sequencer = new Sequencer((byte) 0xff); + private final MessageSequence sequencer = new Sequencer((byte) 0xff); public QueryPacket(String sql) { this.sql = Assert.requireNonNull(sql, "query must not be null"); @@ -28,7 +30,16 @@ public ByteBuf encode(Context context, ByteBufAllocator byteBufAllocator) { return out; } - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } + + public void resetSequencer() { + sequencer.reset(); + } + + @Override + public String toString() { + return "QueryPacket{" + "sql='" + sql + '\'' + '}'; + } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/QueryWithParametersPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/QueryWithParametersPacket.java index 33d4090d..97cb4916 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/QueryWithParametersPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/QueryWithParametersPacket.java @@ -1,33 +1,40 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import java.nio.charset.StandardCharsets; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.codec.Parameter; +import java.util.List; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; import org.mariadb.r2dbc.message.server.Sequencer; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.BindEncodedValue; import org.mariadb.r2dbc.util.ClientPrepareResult; public final class QueryWithParametersPacket implements ClientMessage { private final ClientPrepareResult prepareResult; - private final Parameter[] parameters; + private final List bindValues; + private final MessageSequence sequencer = new Sequencer((byte) 0xff); private final String[] generatedColumns; - private final Sequencer sequencer = new Sequencer((byte) 0xff); + private ByteBuf savedBuf = null; public QueryWithParametersPacket( - ClientPrepareResult prepareResult, Parameter[] parameters, String[] generatedColumns) { + ClientPrepareResult prepareResult, + List bindValues, + String[] generatedColumns) { this.prepareResult = prepareResult; - this.parameters = parameters; + this.bindValues = bindValues; this.generatedColumns = generatedColumns; } @Override public ByteBuf encode(Context context, ByteBufAllocator byteBufAllocator) { + if (savedBuf != null) return savedBuf; Assert.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); String additionalReturningPart = null; if (generatedColumns != null) { @@ -47,7 +54,12 @@ public ByteBuf encode(Context context, ByteBufAllocator byteBufAllocator) { } else { out.writeBytes(prepareResult.getQueryParts().get(0)); for (int i = 0; i < prepareResult.getParamCount(); i++) { - parameters[i].encodeText(out, context); + BindEncodedValue param = bindValues.get(i); + if (param.getValue() == null) { + out.writeBytes("null".getBytes(StandardCharsets.US_ASCII)); + } else { + out.writeBytes(param.getValue()); + } out.writeBytes(prepareResult.getQueryParts().get(i + 1)); } if (additionalReturningPart != null) @@ -56,7 +68,24 @@ public ByteBuf encode(Context context, ByteBufAllocator byteBufAllocator) { return out; } - public Sequencer getSequencer() { + public void save(ByteBuf buf, int initialReaderIndex) { + savedBuf = buf.readerIndex(initialReaderIndex).retain(); + } + + @Override + public void releaseEncodedBinds() { + bindValues.forEach( + b -> { + if (b.getValue() != null) b.getValue().release(); + }); + bindValues.clear(); + } + + public void resetSequencer() { + sequencer.reset(); + } + + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/QuitPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/QuitPacket.java index a3005ea8..f9608039 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/QuitPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/QuitPacket.java @@ -1,11 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; public final class QuitPacket implements ClientMessage { public static final QuitPacket INSTANCE = new QuitPacket(); diff --git a/src/main/java/org/mariadb/r2dbc/message/client/RsaPublicKeyRequestPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/RsaPublicKeyRequestPacket.java index 59590fb3..d9f49d34 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/RsaPublicKeyRequestPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/RsaPublicKeyRequestPacket.java @@ -1,18 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class RsaPublicKeyRequestPacket implements ClientMessage { - private Sequencer sequencer; + private final MessageSequence sequencer; - public RsaPublicKeyRequestPacket(Sequencer sequencer) { + public RsaPublicKeyRequestPacket(MessageSequence sequencer) { this.sequencer = sequencer; } @@ -24,7 +25,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/Sha256PasswordPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/Sha256PasswordPacket.java index fa1d0952..cb2103e8 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/Sha256PasswordPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/Sha256PasswordPacket.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; @@ -11,18 +11,19 @@ import java.security.PublicKey; import java.util.Arrays; import javax.crypto.Cipher; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class Sha256PasswordPacket implements ClientMessage { - private Sequencer sequencer; - private CharSequence password; - private byte[] seed; - private PublicKey publicKey; + private final MessageSequence sequencer; + private final CharSequence password; + private final byte[] seed; + private final PublicKey publicKey; public Sha256PasswordPacket( - Sequencer sequencer, CharSequence password, byte[] seed, PublicKey publicKey) { + MessageSequence sequencer, CharSequence password, byte[] seed, PublicKey publicKey) { this.sequencer = sequencer; this.password = password; byte[] truncatedSeed = new byte[seed.length - 1]; @@ -72,7 +73,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/Sha2PublicKeyRequestPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/Sha2PublicKeyRequestPacket.java index 91590847..b45f61b9 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/Sha2PublicKeyRequestPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/Sha2PublicKeyRequestPacket.java @@ -1,18 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.message.server.Sequencer; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; public final class Sha2PublicKeyRequestPacket implements ClientMessage { - private Sequencer sequencer; + private final MessageSequence sequencer; - public Sha2PublicKeyRequestPacket(Sequencer sequencer) { + public Sha2PublicKeyRequestPacket(MessageSequence sequencer) { this.sequencer = sequencer; } @@ -24,7 +25,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } } diff --git a/src/main/java/org/mariadb/r2dbc/message/client/SslRequestPacket.java b/src/main/java/org/mariadb/r2dbc/message/client/SslRequestPacket.java index e88acc2a..5f595594 100644 --- a/src/main/java/org/mariadb/r2dbc/message/client/SslRequestPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/client/SslRequestPacket.java @@ -1,18 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.client; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; import org.mariadb.r2dbc.message.server.InitialHandshakePacket; -import org.mariadb.r2dbc.message.server.Sequencer; public final class SslRequestPacket implements ClientMessage { - private InitialHandshakePacket initialHandshakePacket; - private long clientCapabilities; + private final InitialHandshakePacket initialHandshakePacket; + private final long clientCapabilities; public SslRequestPacket(InitialHandshakePacket initialHandshakePacket, long clientCapabilities) { this.initialHandshakePacket = initialHandshakePacket; @@ -28,7 +29,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { initialHandshakePacket.getMajorServerVersion(), initialHandshakePacket.getMinorServerVersion()); - ByteBuf buf = allocator.ioBuffer(32); + ByteBuf buf = allocator.buffer(32, 32); buf.writeIntLE((int) clientCapabilities); buf.writeIntLE(1024 * 1024 * 1024); @@ -40,7 +41,7 @@ public ByteBuf encode(Context context, ByteBufAllocator allocator) { } @Override - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return initialHandshakePacket.getSequencer(); } } diff --git a/src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlow.java b/src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlow.java index 2cf2fafd..d5bae8c2 100644 --- a/src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlow.java +++ b/src/main/java/org/mariadb/r2dbc/message/flow/AuthenticationFlow.java @@ -1,21 +1,27 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.flow; import io.r2dbc.spi.R2dbcException; import io.r2dbc.spi.R2dbcNonTransientResourceException; +import io.r2dbc.spi.R2dbcPermissionDeniedException; +import java.util.Arrays; import org.mariadb.r2dbc.ExceptionFactory; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.SslMode; +import org.mariadb.r2dbc.authentication.AuthenticationFlowPluginLoader; import org.mariadb.r2dbc.authentication.AuthenticationPlugin; import org.mariadb.r2dbc.client.Client; import org.mariadb.r2dbc.client.DecoderState; -import org.mariadb.r2dbc.message.client.ClientMessage; +import org.mariadb.r2dbc.client.SimpleClient; +import org.mariadb.r2dbc.message.ClientMessage; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.message.client.HandshakeResponse; import org.mariadb.r2dbc.message.client.SslRequestPacket; import org.mariadb.r2dbc.message.server.*; import org.mariadb.r2dbc.util.Assert; +import org.mariadb.r2dbc.util.HostAddress; import org.mariadb.r2dbc.util.constants.Capabilities; import reactor.core.publisher.*; import reactor.util.Logger; @@ -24,22 +30,26 @@ public final class AuthenticationFlow { private static final Logger logger = Loggers.getLogger(AuthenticationFlow.class); - private MariadbConnectionConfiguration configuration; + private final MariadbConnectionConfiguration configuration; private InitialHandshakePacket initialHandshakePacket; private AuthenticationPlugin pluginHandler; private AuthSwitchPacket authSwitchPacket; private AuthMoreDataPacket authMoreDataPacket; - private Client client; + private final SimpleClient client; private FluxSink sink; + private final HostAddress hostAddress; private long clientCapabilities; - private AuthenticationFlow(Client client, MariadbConnectionConfiguration configuration) { + private AuthenticationFlow( + SimpleClient client, MariadbConnectionConfiguration configuration, HostAddress hostAddress) { this.client = client; this.configuration = configuration; + this.hostAddress = hostAddress; } - public static Mono exchange(Client client, MariadbConnectionConfiguration configuration) { - AuthenticationFlow flow = new AuthenticationFlow(client, configuration); + public static Mono exchange( + SimpleClient client, MariadbConnectionConfiguration configuration, HostAddress hostAddress) { + AuthenticationFlow flow = new AuthenticationFlow(client, configuration, hostAddress); Assert.requireNonNull(client, "client must not be null"); return Flux.create( @@ -50,7 +60,7 @@ public static Mono exchange(Client client, MariadbConnectionConfiguratio .doOnNext( state -> { if (State.COMPLETED == state) { - if (flow.authMoreDataPacket != null) flow.authMoreDataPacket.deallocate(); + if (flow.authMoreDataPacket != null) flow.authMoreDataPacket.release(); flow.sink.complete(); } else { if (logger.isTraceEnabled()) { @@ -93,14 +103,11 @@ private static long initializeClientCapabilities( capabilities |= Capabilities.MULTI_STATEMENTS; } - if ((serverCapabilities & Capabilities.CLIENT_DEPRECATE_EOF) != 0) { - capabilities |= Capabilities.CLIENT_DEPRECATE_EOF; - } - if (configuration.getDatabase() != null) { capabilities |= Capabilities.CONNECT_WITH_DB; } - return capabilities; + + return capabilities & serverCapabilities; } private HandshakeResponse createHandshakeResponse(long clientCapabilities) { @@ -110,7 +117,7 @@ private HandshakeResponse createHandshakeResponse(long clientCapabilities) { this.configuration.getPassword(), this.configuration.getDatabase(), configuration.getConnectionAttributes(), - configuration.getHost(), + this.hostAddress, clientCapabilities); } @@ -130,13 +137,12 @@ Mono handle(AuthenticationFlow flow) { if (message instanceof ErrorPacket) { sink.error(ExceptionFactory.INSTANCE.from((ErrorPacket) message)); } else if (message instanceof InitialHandshakePacket) { - flow.client.setContext((InitialHandshakePacket) message); - // TODO SET connection context with server data. InitialHandshakePacket packet = (InitialHandshakePacket) message; flow.initialHandshakePacket = packet; flow.clientCapabilities = initializeClientCapabilities( flow.initialHandshakePacket.getCapabilities(), flow.configuration); + flow.client.setContext(packet, flow.clientCapabilities); if (flow.configuration.getSslConfig().getSslMode() != SslMode.DISABLE) { if ((packet.getCapabilities() & Capabilities.SSL) == 0) { @@ -180,23 +186,32 @@ Mono handle(AuthenticationFlow flow) { return flow.client .sendCommand( flow.createHandshakeResponse(flow.clientCapabilities), - DecoderState.AUTHENTICATION_SWITCH_RESPONSE) + DecoderState.AUTHENTICATION_SWITCH_RESPONSE, + false) .handle( (message, sink) -> { if (message instanceof ErrorPacket) { - R2dbcException exception = - ExceptionFactory.createException((ErrorPacket) message, null); - sink.error( - new R2dbcNonTransientResourceException(exception.getMessage(), exception)); + sink.error(ExceptionFactory.createException((ErrorPacket) message, null)); } else if (message instanceof OkPacket) { sink.next(COMPLETED); } else if (message instanceof AuthSwitchPacket) { flow.authSwitchPacket = ((AuthSwitchPacket) message); String plugin = flow.authSwitchPacket.getPlugin(); - AuthenticationPlugin authPlugin = AuthenticationFlowPluginLoader.get(plugin); - flow.authMoreDataPacket = null; - flow.pluginHandler = authPlugin; - sink.next(AUTH_SWITCH); + if (flow.configuration.getRestrictedAuth() != null + && !Arrays.stream(flow.configuration.getRestrictedAuth()) + .anyMatch(s -> plugin.equals(s))) { + sink.error( + new R2dbcPermissionDeniedException( + String.format( + "Unsupported authentication plugin %s. Authorized plugin: %s", + plugin, + Arrays.toString(flow.configuration.getRestrictedAuth())))); + } else { + AuthenticationPlugin authPlugin = AuthenticationFlowPluginLoader.get(plugin); + flow.authMoreDataPacket = null; + flow.pluginHandler = authPlugin; + sink.next(AUTH_SWITCH); + } } else { sink.error( new IllegalStateException( @@ -226,7 +241,8 @@ Mono handle(AuthenticationFlow flow) { // this can occur when there is a "finishing" message for authentication plugin // example CachingSha2PasswordFlow that finish with a successful FAST_AUTH flux = - flow.client.sendCommand(clientMessage, DecoderState.AUTHENTICATION_SWITCH_RESPONSE); + flow.client.sendCommand( + clientMessage, DecoderState.AUTHENTICATION_SWITCH_RESPONSE, false); } else { flux = flow.client.receive(DecoderState.AUTHENTICATION_SWITCH_RESPONSE); } @@ -235,17 +251,27 @@ Mono handle(AuthenticationFlow flow) { (message, sink) -> { if (message instanceof ErrorPacket) { sink.error( - new R2dbcNonTransientResourceException( - ((ErrorPacket) message).getMessage())); + new R2dbcNonTransientResourceException(((ErrorPacket) message).message())); } else if (message instanceof OkPacket) { sink.next(COMPLETED); } else if (message instanceof AuthSwitchPacket) { flow.authSwitchPacket = ((AuthSwitchPacket) message); String plugin = flow.authSwitchPacket.getPlugin(); - AuthenticationPlugin authPlugin = AuthenticationFlowPluginLoader.get(plugin); - flow.authMoreDataPacket = null; - flow.pluginHandler = authPlugin; - sink.next(AUTH_SWITCH); + if (flow.configuration.getRestrictedAuth() != null + && !Arrays.stream(flow.configuration.getRestrictedAuth()) + .anyMatch(s -> plugin.equals(s))) { + sink.error( + new R2dbcPermissionDeniedException( + String.format( + "Unsupported authentication plugin %s. Authorized plugin: %s", + plugin, + Arrays.toString(flow.configuration.getRestrictedAuth())))); + } else { + AuthenticationPlugin authPlugin = AuthenticationFlowPluginLoader.get(plugin); + flow.authMoreDataPacket = null; + flow.pluginHandler = authPlugin; + sink.next(AUTH_SWITCH); + } } else if (message instanceof AuthMoreDataPacket) { flow.authMoreDataPacket = (AuthMoreDataPacket) message; sink.next(AUTH_SWITCH); diff --git a/src/main/java/org/mariadb/r2dbc/message/server/AuthMoreDataPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/AuthMoreDataPacket.java index b8e96459..db0b0002 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/AuthMoreDataPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/AuthMoreDataPacket.java @@ -1,35 +1,39 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.AuthMoreData; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.MessageSequence; +import org.mariadb.r2dbc.message.ServerMessage; -public class AuthMoreDataPacket implements ServerMessage { +public class AuthMoreDataPacket implements AuthMoreData, ServerMessage { - private Sequencer sequencer; + private final MessageSequence sequencer; private ByteBuf buf; - private AuthMoreDataPacket(Sequencer sequencer, ByteBuf buf) { + private AuthMoreDataPacket(MessageSequence sequencer, ByteBuf buf) { this.sequencer = sequencer; this.buf = buf; } - public static AuthMoreDataPacket decode(Sequencer sequencer, ByteBuf buf, Context context) { + public static AuthMoreDataPacket decode(MessageSequence sequencer, ByteBuf buf, Context context) { buf.skipBytes(1); + buf.retain(); ByteBuf data = buf.readRetainedSlice(buf.readableBytes()); return new AuthMoreDataPacket(sequencer, data); } - public void deallocate() { + public void release() { if (buf != null) { buf.release(); buf = null; } } - public Sequencer getSequencer() { + public MessageSequence getSequencer() { return sequencer; } diff --git a/src/main/java/org/mariadb/r2dbc/message/server/AuthSwitchPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/AuthSwitchPacket.java index 87f87dc7..efda7453 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/AuthSwitchPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/AuthSwitchPacket.java @@ -1,17 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; import java.nio.charset.StandardCharsets; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.AuthSwitch; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; -public class AuthSwitchPacket implements ServerMessage { +public class AuthSwitchPacket implements AuthSwitch, ServerMessage { - private Sequencer sequencer; - private String plugin; - private byte[] seed; + private final Sequencer sequencer; + private final String plugin; + private final byte[] seed; public AuthSwitchPacket(Sequencer sequencer, String plugin, byte[] seed) { this.sequencer = sequencer; diff --git a/src/main/java/org/mariadb/r2dbc/message/server/ColumnCountPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/ColumnCountPacket.java index 013292fb..df6d0653 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/ColumnCountPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/ColumnCountPacket.java @@ -1,17 +1,18 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.BufferUtils; import org.mariadb.r2dbc.util.constants.Capabilities; public class ColumnCountPacket implements ServerMessage { - private int columnCount; - private boolean metaFollows; + private final int columnCount; + private final boolean metaFollows; public ColumnCountPacket(int columnCount, boolean metaFollows) { this.columnCount = columnCount; @@ -20,11 +21,11 @@ public ColumnCountPacket(int columnCount, boolean metaFollows) { public static ColumnCountPacket decode(Sequencer sequencer, ByteBuf buf, Context context) { long columnCount = BufferUtils.readLengthEncodedInt(buf); + boolean metaFollow = true; if ((context.getServerCapabilities() & Capabilities.MARIADB_CLIENT_CACHE_METADATA) > 0) { - int metaFollow = buf.readByte(); - return new ColumnCountPacket((int) columnCount, metaFollow == 1); + metaFollow = buf.readByte() == 1; } - return new ColumnCountPacket((int) columnCount, true); + return new ColumnCountPacket((int) columnCount, metaFollow); } public int getColumnCount() { diff --git a/src/main/java/org/mariadb/r2dbc/message/server/ColumnDefinitionPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/ColumnDefinitionPacket.java index dfa97e11..3283b63a 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/ColumnDefinitionPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/ColumnDefinitionPacket.java @@ -1,29 +1,24 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; -import io.r2dbc.spi.Blob; +import io.r2dbc.spi.ColumnMetadata; import io.r2dbc.spi.Nullability; -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.ByteBuffer; +import io.r2dbc.spi.OutParameterMetadata; import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.LocalDate; -import java.time.LocalDateTime; -import java.util.BitSet; import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.client.Context; -import org.mariadb.r2dbc.codec.Codec; import org.mariadb.r2dbc.codec.DataType; -import org.mariadb.r2dbc.codec.list.*; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; +import org.mariadb.r2dbc.util.MariadbType; import org.mariadb.r2dbc.util.constants.ColumnFlags; import reactor.util.Logger; import reactor.util.Loggers; -public final class ColumnDefinitionPacket implements ServerMessage { +public final class ColumnDefinitionPacket + implements ServerMessage, ColumnMetadata, OutParameterMetadata { private static final Logger logger = Loggers.getLogger(ColumnDefinitionPacket.class); // This array stored character length for every collation id up to collation id 256 @@ -70,7 +65,8 @@ public final class ColumnDefinitionPacket implements ServerMessage { private final DataType dataType; private final byte decimals; private final int flags; - private boolean ending; + private final boolean ending; + private final MariadbConnectionConfiguration conf; private ColumnDefinitionPacket( byte[] meta, @@ -79,7 +75,8 @@ private ColumnDefinitionPacket( DataType dataType, byte decimals, int flags, - boolean ending) { + boolean ending, + MariadbConnectionConfiguration conf) { this.meta = meta; this.charset = charset; this.length = length; @@ -87,9 +84,10 @@ private ColumnDefinitionPacket( this.decimals = decimals; this.flags = flags; this.ending = ending; + this.conf = conf; } - private ColumnDefinitionPacket(String name) { + private ColumnDefinitionPacket(String name, MariadbConnectionConfiguration conf) { byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8); byte[] arr = new byte[6 + 2 * nameBytes.length]; int pos = 0; @@ -117,10 +115,15 @@ private ColumnDefinitionPacket(String name) { this.decimals = 0; this.flags = ColumnFlags.PRIMARY_KEY; this.ending = false; + this.conf = conf; } public static ColumnDefinitionPacket decode( - Sequencer sequencer, ByteBuf buf, Context context, boolean ending) { + Sequencer sequencer, + ByteBuf buf, + Context context, + boolean ending, + MariadbConnectionConfiguration conf) { byte[] meta = new byte[buf.readableBytes() - 12]; buf.readBytes(meta); int charset = buf.readUnsignedShortLE(); @@ -128,11 +131,13 @@ public static ColumnDefinitionPacket decode( DataType dataType = DataType.fromServer(buf.readUnsignedByte(), charset); int flags = buf.readUnsignedShortLE(); byte decimals = buf.readByte(); - return new ColumnDefinitionPacket(meta, charset, length, dataType, decimals, flags, ending); + return new ColumnDefinitionPacket( + meta, charset, length, dataType, decimals, flags, ending, conf); } - public static ColumnDefinitionPacket fromGeneratedId(String name) { - return new ColumnDefinitionPacket(name); + public static ColumnDefinitionPacket fromGeneratedId( + String name, MariadbConnectionConfiguration conf) { + return new ColumnDefinitionPacket(name, conf); } private String getString(int idx) { @@ -159,7 +164,8 @@ public String getTable() { return this.getString(3); } - public String getColumnAlias() { + @Override + public String getName() { return this.getString(4); } @@ -175,7 +181,7 @@ public long getLength() { return length; } - public DataType getType() { + public DataType getDataType() { return dataType; } @@ -188,7 +194,7 @@ public boolean isSigned() { } public int getDisplaySize() { - if (dataType == DataType.VARCHAR + if (dataType == DataType.TEXT || dataType == DataType.JSON || dataType == DataType.ENUM || dataType == DataType.SET @@ -229,111 +235,109 @@ public boolean isBinary() { return (charset == 63); } - public Class getJavaClass() { + public MariadbType getType() { switch (dataType) { case TINYINT: - return isSigned() ? Byte.class : Short.class; + // TINYINT(1) are considered as boolean + if (length == 1 && conf.tinyInt1isBit()) return MariadbType.BOOLEAN; + return isSigned() ? MariadbType.TINYINT : MariadbType.UNSIGNED_TINYINT; + case YEAR: + return MariadbType.SMALLINT; case SMALLINT: - return isSigned() ? Short.class : Integer.class; + return isSigned() ? MariadbType.SMALLINT : MariadbType.UNSIGNED_SMALLINT; case INTEGER: - return isSigned() ? Integer.class : Long.class; + return isSigned() ? MariadbType.INTEGER : MariadbType.UNSIGNED_INTEGER; case FLOAT: - return Float.class; + return MariadbType.FLOAT; case DOUBLE: - return Double.class; + return MariadbType.DOUBLE; case TIMESTAMP: case DATETIME: - return LocalDateTime.class; + return MariadbType.TIMESTAMP; case BIGINT: - return isSigned() ? Long.class : BigInteger.class; + return isSigned() ? MariadbType.BIGINT : MariadbType.UNSIGNED_BIGINT; case MEDIUMINT: - return Integer.class; + return MariadbType.INTEGER; case DATE: case NEWDATE: - return LocalDate.class; + return MariadbType.DATE; case TIME: - return Duration.class; - case YEAR: - return Short.class; - case VARCHAR: + return MariadbType.TIME; case JSON: case ENUM: case SET: - case VARSTRING: case STRING: - return isBinary() ? ByteBuffer.class : String.class; + case VARSTRING: + case NULL: + return isBinary() ? MariadbType.BYTES : MariadbType.VARCHAR; + case TEXT: + return MariadbType.CLOB; case OLDDECIMAL: case DECIMAL: - return BigDecimal.class; + return MariadbType.DECIMAL; case BIT: - return BitSet.class; + // BIT(1) are considered as boolean + if (length == 1 && conf.tinyInt1isBit()) return MariadbType.BOOLEAN; + return MariadbType.BIT; case TINYBLOB: case MEDIUMBLOB: case LONGBLOB: case BLOB: case GEOMETRY: - return Blob.class; + return MariadbType.BLOB; default: return null; } } - public Codec getDefaultCodec(MariadbConnectionConfiguration conf) { + @Override + public Integer getPrecision() { switch (dataType) { - case VARCHAR: - case JSON: - case ENUM: - case SET: - case VARSTRING: - case STRING: - case NULL: - return isBinary() ? ByteArrayCodec.INSTANCE : StringCodec.INSTANCE; + case OLDDECIMAL: + case DECIMAL: + // DECIMAL and OLDDECIMAL are "exact" fixed-point number. + // so : + // - if can be signed, 1 byte is saved for sign + // - if decimal > 0, one byte more for dot + if (isSigned()) { + return (int) (length - ((getDecimals() > 0) ? 2 : 1)); + } else { + return (int) (length - ((decimals > 0) ? 1 : 0)); + } + default: + return (int) length; + } + } + + @Override + public Integer getScale() { + switch (dataType) { + case OLDDECIMAL: case TINYINT: - // TINYINT(1) are considered as boolean - if (length == 1 && conf.tinyInt1isBit()) return BooleanCodec.INSTANCE; - return isSigned() ? ByteCodec.INSTANCE : ShortCodec.INSTANCE; case SMALLINT: - return isSigned() ? ShortCodec.INSTANCE : IntCodec.INSTANCE; case INTEGER: - return isSigned() ? IntCodec.INSTANCE : LongCodec.INSTANCE; case FLOAT: - return FloatCodec.INSTANCE; case DOUBLE: - return DoubleCodec.INSTANCE; - case TIMESTAMP: - case DATETIME: - return LocalDateTimeCodec.INSTANCE; case BIGINT: - return isSigned() ? LongCodec.INSTANCE : BigIntegerCodec.INSTANCE; case MEDIUMINT: - return IntCodec.INSTANCE; - case DATE: - case NEWDATE: - return LocalDateCodec.INSTANCE; - case TIME: - return DurationCodec.INSTANCE; - case YEAR: - return ShortCodec.INSTANCE; - case OLDDECIMAL: - case DECIMAL: - return BigDecimalCodec.INSTANCE; case BIT: - // BIT(1) are considered as boolean - if (length == 1 && conf.tinyInt1isBit()) return BooleanCodec.INSTANCE; - return BitSetCodec.INSTANCE; - case GEOMETRY: - return ByteArrayCodec.INSTANCE; - case TINYBLOB: - case MEDIUMBLOB: - case LONGBLOB: - case BLOB: - return BlobCodec.INSTANCE; + case DECIMAL: + return (int) decimals; default: - return null; + return 0; } } + @Override + public Class getJavaType() { + return getType().getJavaType(); + } + + public ColumnDefinitionPacket getNativeTypeMetadata() { + return this; + } + @Override public boolean ending() { return this.ending; diff --git a/src/main/java/org/mariadb/r2dbc/message/server/CompletePrepareResult.java b/src/main/java/org/mariadb/r2dbc/message/server/CompletePrepareResult.java index 546c08c1..c6568625 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/CompletePrepareResult.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/CompletePrepareResult.java @@ -1,14 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.ServerPrepareResult; public final class CompletePrepareResult implements ServerMessage { private final ServerPrepareResult prepare; - private boolean continueOnEnd; + private final boolean continueOnEnd; public CompletePrepareResult(final ServerPrepareResult prepare, boolean continueOnEnd) { this.prepare = prepare; @@ -17,7 +18,7 @@ public CompletePrepareResult(final ServerPrepareResult prepare, boolean continue @Override public boolean ending() { - return continueOnEnd; + return !continueOnEnd; } public ServerPrepareResult getPrepare() { diff --git a/src/main/java/org/mariadb/r2dbc/message/server/EofPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/EofPacket.java index debeac3e..477b7a29 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/EofPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/EofPacket.java @@ -1,10 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.constants.ServerStatus; public class EofPacket implements ServerMessage { @@ -42,6 +43,25 @@ public static EofPacket decode( resultSetEnd && (serverStatus & ServerStatus.MORE_RESULTS_EXISTS) == 0); } + /** + * This is for mysql that doesn't send MORE_RESULTS_EXISTS flag, but sending an OK_Packet after, + * breaking protocol. + * + * @param sequencer sequencer + * @param buf current EOF buf + * @param context current context + * @return Eof packet + */ + public static EofPacket decodeOutputParam(Sequencer sequencer, ByteBuf buf, Context context) { + buf.skipBytes(1); + short warningCount = buf.readShortLE(); + short serverStatus = + (short) + (buf.readShortLE() | ServerStatus.PS_OUT_PARAMETERS | ServerStatus.MORE_RESULTS_EXISTS); + context.setServerStatus(serverStatus); + return new EofPacket(sequencer, serverStatus, warningCount, false, false); + } + public short getServerStatus() { return serverStatus; } @@ -59,4 +79,18 @@ public boolean ending() { public boolean resultSetEnd() { return resultSetEnd; } + + @Override + public String toString() { + return "EofPacket{" + + "serverStatus=" + + serverStatus + + ", warningCount=" + + warningCount + + ", ending=" + + ending + + ", resultSetEnd=" + + resultSetEnd + + '}'; + } } diff --git a/src/main/java/org/mariadb/r2dbc/message/server/ErrorPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/ErrorPacket.java index b3320de7..5fa8b736 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/ErrorPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/ErrorPacket.java @@ -1,20 +1,24 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; +import io.r2dbc.spi.R2dbcException; +import io.r2dbc.spi.Result; import java.nio.charset.StandardCharsets; +import org.mariadb.r2dbc.ExceptionFactory; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.Assert; import reactor.util.Logger; import reactor.util.Loggers; -public final class ErrorPacket implements ServerMessage { +public final class ErrorPacket implements ServerMessage, Result.Message { private static final Logger logger = Loggers.getLogger(ErrorPacket.class); private final short errorCode; private final String message; private final String sqlState; - private Sequencer sequencer; + private final Sequencer sequencer; private final boolean ending; private ErrorPacket( @@ -47,18 +51,30 @@ public static ErrorPacket decode(Sequencer sequencer, ByteBuf buf, boolean endin return err; } - public short getErrorCode() { - return errorCode; - } - public String getMessage() { return message; } - public String getSqlState() { + @Override + public R2dbcException exception() { + return ExceptionFactory.createException(this, null); + } + + @Override + public int errorCode() { + return errorCode; + } + + @Override + public String sqlState() { return sqlState; } + @Override + public String message() { + return message; + } + @Override public boolean ending() { return ending; diff --git a/src/main/java/org/mariadb/r2dbc/message/server/InitialHandshakePacket.java b/src/main/java/org/mariadb/r2dbc/message/server/InitialHandshakePacket.java index a4056739..8a15907e 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/InitialHandshakePacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/InitialHandshakePacket.java @@ -1,25 +1,26 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; import java.nio.charset.StandardCharsets; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.constants.Capabilities; public final class InitialHandshakePacket implements ServerMessage { private static final String MARIADB_RPL_HACK_PREFIX = "5.5.5-"; - private Sequencer sequencer; - private String serverVersion; - private long threadId; - private byte[] seed; - private long capabilities; - private short defaultCollation; - private short serverStatus; - private boolean mariaDBServer; - private String authenticationPluginType; + private final Sequencer sequencer; + private final String serverVersion; + private final long threadId; + private final byte[] seed; + private final long capabilities; + private final short defaultCollation; + private final short serverStatus; + private final boolean mariaDBServer; + private final String authenticationPluginType; private int majorVersion; private int minorVersion; private int patchVersion; @@ -172,7 +173,6 @@ public String getAuthenticationPluginType() { return authenticationPluginType; } - @Override public Sequencer getSequencer() { return sequencer; } diff --git a/src/main/java/org/mariadb/r2dbc/message/server/OkPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/OkPacket.java index e572453f..94baf1c8 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/OkPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/OkPacket.java @@ -1,10 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; +import io.r2dbc.spi.IsolationLevel; +import io.r2dbc.spi.Result; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.BufferUtils; import org.mariadb.r2dbc.util.constants.Capabilities; import org.mariadb.r2dbc.util.constants.ServerStatus; @@ -12,7 +15,7 @@ import reactor.util.Logger; import reactor.util.Loggers; -public class OkPacket implements ServerMessage { +public class OkPacket implements ServerMessage, Result.UpdateCount { public static final byte TYPE = (byte) 0x00; private static final Logger logger = Loggers.getLogger(OkPacket.class); private final Sequencer sequencer; @@ -56,13 +59,35 @@ public static OkPacket decode(Sequencer sequencer, ByteBuf buf, Context context) String variable = BufferUtils.readLengthEncodedString(sessionVariableBuf); String value = BufferUtils.readLengthEncodedString(sessionVariableBuf); logger.debug("System variable change : {} = {}", variable, value); + + switch (variable) { + case "transaction_isolation": + case "tx_isolation": + switch (value) { + case "REPEATABLE-READ": + context.setIsolationLevel(IsolationLevel.REPEATABLE_READ); + break; + + case "READ-UNCOMMITTED": + context.setIsolationLevel(IsolationLevel.READ_UNCOMMITTED); + break; + + case "SERIALIZABLE": + context.setIsolationLevel(IsolationLevel.SERIALIZABLE); + break; + + default: + context.setIsolationLevel(IsolationLevel.READ_COMMITTED); + break; + } + } break; case StateChange.SESSION_TRACK_SCHEMA: ByteBuf sessionSchemaBuf = BufferUtils.readLengthEncodedBuffer(stateInfo); - String database = BufferUtils.readLengthEncodedString(sessionSchemaBuf); - // context.setDatabase(database); - logger.debug("Database change : now is '{}'", database); + String schema = BufferUtils.readLengthEncodedString(sessionSchemaBuf); + context.setDatabase(schema); + logger.debug("Schema change : now is '{}'", schema); break; } } @@ -78,10 +103,6 @@ public static OkPacket decode(Sequencer sequencer, ByteBuf buf, Context context) (serverStatus & ServerStatus.MORE_RESULTS_EXISTS) == 0); } - public long getAffectedRows() { - return affectedRows; - } - public long getLastInsertId() { return lastInsertId; } @@ -103,4 +124,25 @@ public boolean ending() { public boolean resultSetEnd() { return true; } + + @Override + public long value() { + return affectedRows; + } + + @Override + public String toString() { + return "OkPacket{" + + "affectedRows=" + + affectedRows + + ", lastInsertId=" + + lastInsertId + + ", serverStatus=" + + serverStatus + + ", warningCount=" + + warningCount + + ", ending=" + + ending + + '}'; + } } diff --git a/src/main/java/org/mariadb/r2dbc/message/server/PrepareResultPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/PrepareResultPacket.java index 041e8ac3..647b017c 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/PrepareResultPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/PrepareResultPacket.java @@ -1,10 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.Context; +import org.mariadb.r2dbc.message.ServerMessage; import org.mariadb.r2dbc.util.constants.Capabilities; public final class PrepareResultPacket implements ServerMessage { @@ -13,8 +14,8 @@ public final class PrepareResultPacket implements ServerMessage { private final int numColumns; private final int numParams; private final boolean eofDeprecated; - private Sequencer sequencer; - private boolean continueOnEnd; + private final Sequencer sequencer; + private final boolean continueOnEnd; private PrepareResultPacket( final Sequencer sequencer, @@ -52,7 +53,7 @@ public static PrepareResultPacket decode( statementId, numColumns, numParams, - ((context.getServerCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0), + ((context.getClientCapabilities() & Capabilities.CLIENT_DEPRECATE_EOF) > 0), continueOnEnd); } diff --git a/src/main/java/org/mariadb/r2dbc/message/server/RowPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/RowPacket.java index 37176875..43712901 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/RowPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/RowPacket.java @@ -1,13 +1,14 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; import io.netty.buffer.ByteBuf; +import org.mariadb.r2dbc.message.ServerMessage; public final class RowPacket implements ServerMessage { - private ByteBuf raw; + private final ByteBuf raw; public RowPacket(ByteBuf raw) { this.raw = raw.retain(); diff --git a/src/main/java/org/mariadb/r2dbc/message/server/Sequencer.java b/src/main/java/org/mariadb/r2dbc/message/server/Sequencer.java index 2f993247..cf8145f6 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/Sequencer.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/Sequencer.java @@ -1,15 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; -public class Sequencer { +import org.mariadb.r2dbc.message.MessageSequence; + +public class Sequencer implements MessageSequence { private byte sequenceId; public Sequencer(byte sequenceId) { this.sequenceId = sequenceId; } + public void reset() { + sequenceId = (byte) 0xff; + } + public byte next() { return ++sequenceId; } diff --git a/src/main/java/org/mariadb/r2dbc/message/server/SkipPacket.java b/src/main/java/org/mariadb/r2dbc/message/server/SkipPacket.java index 69482689..ef9fa71c 100644 --- a/src/main/java/org/mariadb/r2dbc/message/server/SkipPacket.java +++ b/src/main/java/org/mariadb/r2dbc/message/server/SkipPacket.java @@ -1,8 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.message.server; +import org.mariadb.r2dbc.message.ServerMessage; + public class SkipPacket implements ServerMessage { private final boolean ending; diff --git a/src/main/java/org/mariadb/r2dbc/util/Assert.java b/src/main/java/org/mariadb/r2dbc/util/Assert.java index cd28fa33..824a8d41 100644 --- a/src/main/java/org/mariadb/r2dbc/util/Assert.java +++ b/src/main/java/org/mariadb/r2dbc/util/Assert.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; diff --git a/src/main/java/org/mariadb/r2dbc/util/BindEncodedValue.java b/src/main/java/org/mariadb/r2dbc/util/BindEncodedValue.java new file mode 100644 index 00000000..878c9c9e --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/util/BindEncodedValue.java @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.util; + +import io.netty.buffer.ByteBuf; +import org.mariadb.r2dbc.codec.Codec; + +public class BindEncodedValue { + + private final Codec codec; + private final ByteBuf value; + + public BindEncodedValue(Codec codec, ByteBuf value) { + this.codec = codec; + this.value = value; + } + + public Codec getCodec() { + return codec; + } + + public ByteBuf getValue() { + return value; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/util/BindValue.java b/src/main/java/org/mariadb/r2dbc/util/BindValue.java new file mode 100644 index 00000000..ce1738c2 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/util/BindValue.java @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.util; + +import io.netty.buffer.ByteBuf; +import java.util.Objects; +import org.mariadb.r2dbc.codec.Codec; +import reactor.core.publisher.Mono; + +public class BindValue { + + public static final Mono NULL_VALUE = Mono.empty(); + private final Codec codec; + private final Mono value; + + public BindValue(Codec codec, Mono value) { + this.codec = codec; + this.value = Assert.requireNonNull(value, "value must not be null"); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BindValue that = (BindValue) o; + return Objects.equals(this.codec, that.codec) && Objects.equals(this.value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(this.codec, this.value); + } + + @Override + public String toString() { + return "BindValue{codec=" + this.codec.getClass().getSimpleName() + '}'; + } + + public Codec getCodec() { + return this.codec; + } + + public boolean isNull() { + return this.value == NULL_VALUE; + } + + public Mono getValue() { + return this.value; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/util/Binding.java b/src/main/java/org/mariadb/r2dbc/util/Binding.java new file mode 100644 index 00000000..29af6d58 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/util/Binding.java @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.util; + +import io.netty.util.ReferenceCountUtil; +import java.util.*; +import reactor.core.publisher.Flux; +import reactor.util.Logger; +import reactor.util.Loggers; + +public final class Binding { + private static final Logger LOGGER = Loggers.getLogger(Binding.class); + + private final int expectedSize; + private final Map binds; + + public Binding(int expectedSize) { + this.expectedSize = expectedSize; + this.binds = new HashMap<>(); + } + + public Binding add(int index, BindValue parameter) { + Assert.requireNonNull(parameter, "parameter must not be null"); + + if (index >= this.expectedSize) { + throw new IndexOutOfBoundsException( + String.format( + "Binding index %d when only %d parameters are expected", index, this.expectedSize)); + } + + this.binds.put(index, parameter); + + return this; + } + + public void clear() { + this.binds + .entrySet() + .forEach( + entry -> { + Flux.from(entry.getValue().getValue()) + .doOnNext(ReferenceCountUtil::release) + .subscribe( + ignore -> {}, + err -> + LOGGER.warn( + String.format("Cannot release parameter %s", entry.getValue()), err)); + }); + + this.binds.clear(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Binding that = (Binding) o; + return Objects.equals(this.binds, that.binds); + } + + @Override + public int hashCode() { + return Objects.hash(this.binds); + } + + public boolean isEmpty() { + return this.binds.isEmpty(); + } + + public int size() { + return this.binds.size(); + } + + @Override + public String toString() { + return "Binding{binds=" + this.binds + '}'; + } + + public void validate(int expectedSize) { + // valid parameters + for (int i = 0; i < expectedSize; i++) { + if (binds.get(i) == null) { + throw new IllegalStateException(String.format("Parameter at position %d is not set", i)); + } + } + } + + public List getBindResultParameters(int paramNumber) { + if (this.binds.isEmpty() && paramNumber == 0) { + return Collections.emptyList(); + } + List result = new ArrayList<>(paramNumber); + for (int i = 0; i < paramNumber; i++) { + BindValue parameter = this.binds.get(i); + if (parameter == null) { + throw new IllegalStateException(String.format("No parameter specified for index %d", i)); + } + result.add(parameter); + } + return result; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/util/BufferUtils.java b/src/main/java/org/mariadb/r2dbc/util/BufferUtils.java index cc5428bd..f1b46511 100644 --- a/src/main/java/org/mariadb/r2dbc/util/BufferUtils.java +++ b/src/main/java/org/mariadb/r2dbc/util/BufferUtils.java @@ -1,12 +1,16 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ByteProcessor; import java.nio.charset.StandardCharsets; import java.time.format.DateTimeFormatter; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.util.constants.ServerStatus; public final class BufferUtils { @@ -15,6 +19,9 @@ public final class BufferUtils { private static final byte DBL_QUOTE = (byte) '"'; private static final byte ZERO_BYTE = (byte) '\0'; private static final byte BACKSLASH = (byte) '\\'; + public static final byte[] BINARY_PREFIX = {'_', 'b', 'i', 'n', 'a', 'r', 'y', ' ', '\''}; + public static final byte[] STRING_PREFIX = {'\''}; + private static final DateTimeFormatter TIMESTAMP_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS"); private static final DateTimeFormatter TIMESTAMP_FORMAT_NO_FRACTIONAL = @@ -67,165 +74,161 @@ public static ByteBuf readLengthEncodedBuffer(ByteBuf buf) { return buf.slice(buf.readerIndex() - length, length); } - public static void writeLengthEncode(int length, ByteBuf buf) { + public static byte[] encodeLength(int length) { if (length < 251) { - buf.writeByte((byte) length); - return; + return new byte[] {(byte) length}; } if (length < 65536) { - buf.writeByte((byte) 0xfc); - buf.writeByte((byte) length); - buf.writeByte((byte) (length >>> 8)); - return; + return new byte[] {(byte) 0xfc, (byte) length, (byte) (length >>> 8)}; } if (length < 16777216) { - buf.writeByte((byte) 0xfd); - buf.writeByte((byte) length); - buf.writeByte((byte) (length >>> 8)); - buf.writeByte((byte) (length >>> 16)); - return; + return new byte[] {(byte) 0xfd, (byte) length, (byte) (length >>> 8), (byte) (length >>> 16)}; } - buf.writeByte((byte) 0xfe); - buf.writeByte((byte) length); - buf.writeByte((byte) (length >>> 8)); - buf.writeByte((byte) (length >>> 16)); - buf.writeByte((byte) (length >>> 24)); - buf.writeByte((byte) (length >>> 32)); - buf.writeByte((byte) (length >>> 40)); - buf.writeByte((byte) (length >>> 48)); - buf.writeByte((byte) (length >>> 54)); + return new byte[] { + (byte) 0xfe, + (byte) length, + (byte) (length >>> 8), + (byte) (length >>> 16), + (byte) (length >>> 24), + (byte) (length >>> 32), + (byte) (length >>> 40), + (byte) (length >>> 48), + (byte) (length >>> 54) + }; } public static void writeLengthEncode(String val, ByteBuf buf) { byte[] bytes = val.getBytes(StandardCharsets.UTF_8); - writeLengthEncode(bytes.length, buf); + buf.writeBytes(encodeLength(bytes.length)); buf.writeBytes(bytes); } - public static void writeAscii(ByteBuf buf, String str) { - buf.writeCharSequence(str, StandardCharsets.US_ASCII); + public static ByteBuf encodeByte(ByteBufAllocator allocator, int value) { + ByteBuf byteBuf = allocator.buffer(); + byteBuf.writeByte(value); + return byteBuf; + } + + public static ByteBuf encodeAscii(ByteBufAllocator allocator, String value) { + ByteBuf byteBuf = allocator.buffer(); + byteBuf.writeCharSequence(value, StandardCharsets.US_ASCII); + return byteBuf; + } + + public static ByteBuf encodeLengthAscii(ByteBufAllocator allocator, String value) { + int len = value.length(); + ByteBuf byteBuf = allocator.buffer(len + 9); + byteBuf.writeBytes(encodeLength(value.length())); + byteBuf.writeCharSequence(value, StandardCharsets.US_ASCII); + return byteBuf; + } + + public static ByteBuf encodeLengthUtf8(ByteBufAllocator allocator, String value) { + byte[] b = value.getBytes(StandardCharsets.UTF_8); + CompositeByteBuf byteBuf = allocator.compositeBuffer(); + byteBuf.addComponent(true, Unpooled.wrappedBuffer(encodeLength(b.length))); + byteBuf.addComponent(true, Unpooled.wrappedBuffer(b)); + return byteBuf; } - public static void writeEscaped(ByteBuf buf, byte[] bytes, int offset, int len, Context context) { - buf.ensureWritable(len * 2); + public static ByteBuf encodeLengthBytes(ByteBufAllocator allocator, byte[] value) { + CompositeByteBuf byteBuf = allocator.compositeBuffer(); + byteBuf.addComponent(true, Unpooled.wrappedBuffer(encodeLength(value.length))); + byteBuf.addComponent(true, Unpooled.wrappedBuffer(value)); + return byteBuf; + } + + public static ByteBuf encodeEscapedBuffer( + ByteBufAllocator allocator, ByteBuf value, Context context) { + ByteBuf buf = allocator.buffer(value.readableBytes() * 2); + buf.writeBytes(BINARY_PREFIX); boolean noBackslashEscapes = (context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0; + int fromIndex = value.readerIndex(); + int toIndex = value.writerIndex(); if (noBackslashEscapes) { - for (int i = offset; i < len + offset; i++) { - if (QUOTE == bytes[i]) { + while (true) { + int nextPos = value.indexOf(fromIndex, toIndex, QUOTE); + if (nextPos >= 0) { + buf.writeBytes(value, fromIndex, nextPos - fromIndex); + buf.writeByte(QUOTE); buf.writeByte(QUOTE); + fromIndex = nextPos + 1; + } else { + buf.writeBytes(value, fromIndex, toIndex); + break; } - buf.writeByte(bytes[i]); } } else { - for (int i = offset; i < len + offset; i++) { - if (bytes[i] == QUOTE - || bytes[i] == BACKSLASH - || bytes[i] == '"' - || bytes[i] == ZERO_BYTE) { - buf.writeByte(BACKSLASH); + ByteProcessor processor = + b -> (b != QUOTE && b != BACKSLASH && b != (byte) '"' && b != ZERO_BYTE); + + while (true) { + int nextPos = value.forEachByte(fromIndex, toIndex - fromIndex, processor); + if (nextPos == -1) { + buf.writeBytes(value, fromIndex, toIndex - fromIndex); + break; } - buf.writeByte(bytes[i]); + buf.writeBytes(value, fromIndex, nextPos - fromIndex); + buf.writeByte(BACKSLASH); + buf.writeByte(value.getByte(nextPos)); + fromIndex = nextPos + 1; } } + buf.writeByte('\''); + return buf; } - public static ByteBuf write(ByteBuf buf, String str, boolean quote, Context context) { - - int charsLength = str.length(); - buf.ensureWritable(charsLength * 3 + 2); + public static ByteBuf encodeEscapedBytes( + ByteBufAllocator allocator, byte[] prefix, byte[] value, Context context) { + ByteBuf stBuf = Unpooled.wrappedBuffer(value); + ByteBuf buf = allocator.buffer(stBuf.readableBytes() + 10); + buf.writeBytes(prefix); + escapedBytes(buf, value, value.length, context); + buf.writeByte('\''); + return buf; + } - // create UTF-8 byte array - // since java char are internally using UTF-16 using surrogate's pattern, 4 bytes unicode - // characters will - // represent 2 characters : example "\uD83C\uDFA4" = 🎤 unicode 8 "no microphones" - // so max size is 3 * charLength - // (escape characters are 1 byte encoded, so length might only be 2 when escape) - // + 2 for the quotes for text protocol - int charsOffset = 0; - char currChar; + public static void escapedBytes(ByteBuf buf, byte[] value, int len, Context context) { + ByteBuf stBuf = Unpooled.wrappedBuffer(value, 0, len); boolean noBackslashEscapes = (context.getServerStatus() & ServerStatus.NO_BACKSLASH_ESCAPES) > 0; - // quick loop if only ASCII chars for faster escape - if (quote) buf.writeByte(QUOTE); + + int fromIndex = stBuf.readerIndex(); + int toIndex = stBuf.writerIndex(); if (noBackslashEscapes) { - for (; - charsOffset < charsLength && (currChar = str.charAt(charsOffset)) < 0x80; - charsOffset++) { - if (currChar == QUOTE) { + while (true) { + int nextPos = stBuf.indexOf(fromIndex, toIndex, QUOTE); + if (nextPos >= 0) { + buf.writeBytes(stBuf, fromIndex, nextPos - fromIndex); + buf.writeByte(QUOTE); buf.writeByte(QUOTE); + fromIndex = nextPos + 1; + } else { + buf.writeBytes(stBuf, fromIndex, toIndex - fromIndex); + break; } - buf.writeByte((byte) currChar); } } else { - for (; - charsOffset < charsLength && (currChar = str.charAt(charsOffset)) < 0x80; - charsOffset++) { - if (currChar == BACKSLASH || currChar == QUOTE || currChar == 0 || currChar == DBL_QUOTE) { - buf.writeByte(BACKSLASH); + ByteProcessor processor = + b -> (b != QUOTE && b != BACKSLASH && b != (byte) '"' && b != ZERO_BYTE); + + while (true) { + int nextPos = stBuf.forEachByte(fromIndex, toIndex - fromIndex, processor); + if (nextPos == -1) { + buf.writeBytes(stBuf, fromIndex, toIndex - fromIndex); + break; } - buf.writeByte((byte) currChar); + buf.writeBytes(stBuf, fromIndex, nextPos - fromIndex); + buf.writeByte(BACKSLASH); + buf.writeByte(stBuf.getByte(nextPos)); + fromIndex = nextPos + 1; } } - - // if quick loop not finished - while (charsOffset < charsLength) { - currChar = str.charAt(charsOffset++); - if (currChar < 0x80) { - if (noBackslashEscapes) { - if (currChar == QUOTE) { - buf.writeByte(QUOTE); - } - } else if (currChar == BACKSLASH - || currChar == QUOTE - || currChar == ZERO_BYTE - || currChar == DBL_QUOTE) { - buf.writeByte(BACKSLASH); - } - buf.writeByte((byte) currChar); - } else if (currChar < 0x800) { - buf.writeByte((byte) (0xc0 | (currChar >> 6))); - buf.writeByte((byte) (0x80 | (currChar & 0x3f))); - } else if (currChar >= 0xD800 && currChar < 0xE000) { - // reserved for surrogate - see https://en.wikipedia.org/wiki/UTF-16 - if (currChar < 0xDC00) { - // is high surrogate - if (charsOffset + 1 > charsLength) { - buf.writeByte((byte) 0x63); - } else { - char nextChar = str.charAt(charsOffset); - if (nextChar >= 0xDC00 && nextChar < 0xE000) { - // is low surrogate - int surrogatePairs = - ((currChar << 10) + nextChar) + (0x010000 - (0xD800 << 10) - 0xDC00); - buf.writeByte((byte) (0xf0 | ((surrogatePairs >> 18)))); - buf.writeByte((byte) (0x80 | ((surrogatePairs >> 12) & 0x3f))); - buf.writeByte((byte) (0x80 | ((surrogatePairs >> 6) & 0x3f))); - buf.writeByte((byte) (0x80 | (surrogatePairs & 0x3f))); - charsOffset++; - } else { - // must have low surrogate - buf.writeByte((byte) 0x3f); - } - } - } else { - // low surrogate without high surrogate before - buf.writeByte((byte) 0x3f); - } - } else { - buf.writeByte((byte) (0xe0 | ((currChar >> 12)))); - buf.writeByte((byte) (0x80 | ((currChar >> 6) & 0x3f))); - buf.writeByte((byte) (0x80 | (currChar & 0x3f))); - } - } - if (quote) { - buf.writeByte(QUOTE); - } - return buf; } public static String toString(ByteBuf packet) { diff --git a/src/main/java/org/mariadb/r2dbc/util/ClientPrepareResult.java b/src/main/java/org/mariadb/r2dbc/util/ClientPrepareResult.java index 564b2d21..d30b35f2 100644 --- a/src/main/java/org/mariadb/r2dbc/util/ClientPrepareResult.java +++ b/src/main/java/org/mariadb/r2dbc/util/ClientPrepareResult.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; @@ -12,9 +12,9 @@ public class ClientPrepareResult implements PrepareResult { private final List queryParts; private final List paramNameList; private final int paramCount; - private boolean isQueryMultipleRewritable; - private boolean isReturning; - private boolean supportAddingReturning; + private final boolean isQueryMultipleRewritable; + private final boolean isReturning; + private final boolean supportAddingReturning; private ClientPrepareResult( List queryParts, diff --git a/src/main/java/org/mariadb/r2dbc/util/DefaultHostnameVerifier.java b/src/main/java/org/mariadb/r2dbc/util/DefaultHostnameVerifier.java index 4c07b204..4c940513 100644 --- a/src/main/java/org/mariadb/r2dbc/util/DefaultHostnameVerifier.java +++ b/src/main/java/org/mariadb/r2dbc/util/DefaultHostnameVerifier.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; diff --git a/src/main/java/org/mariadb/r2dbc/util/HostAddress.java b/src/main/java/org/mariadb/r2dbc/util/HostAddress.java new file mode 100644 index 00000000..db35aa7a --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/util/HostAddress.java @@ -0,0 +1,63 @@ +package org.mariadb.r2dbc.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class HostAddress { + String host; + int port; + + public HostAddress(String host, int port) { + this.host = host; + this.port = port; + } + + public static List parse(String hosts, int defaultPort) { + // parse host for multiple hosts. + if (hosts != null) { + List hostAddresses = new ArrayList<>(); + String[] tmpHosts = hosts.split(","); + for (String tmpHost : tmpHosts) { + if (tmpHost.contains(":")) { + hostAddresses.add( + new HostAddress( + tmpHost.substring(0, tmpHost.indexOf(":")), + Integer.parseInt(tmpHost.substring(tmpHost.indexOf(":") + 1)))); + } else { + hostAddresses.add(new HostAddress(tmpHost, defaultPort)); + } + } + return hostAddresses; + } else { + return Collections.singletonList(new HostAddress("localhost", defaultPort)); + } + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof HostAddress)) return false; + HostAddress that = (HostAddress) o; + return port == that.port && host.equals(that.host); + } + + @Override + public int hashCode() { + return Objects.hash(host, port); + } + + @Override + public String toString() { + return host + ':' + port; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/util/MariadbType.java b/src/main/java/org/mariadb/r2dbc/util/MariadbType.java new file mode 100644 index 00000000..87c09d1a --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/util/MariadbType.java @@ -0,0 +1,62 @@ +package org.mariadb.r2dbc.util; + +import io.r2dbc.spi.Blob; +import io.r2dbc.spi.R2dbcType; +import io.r2dbc.spi.Type; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.BitSet; +import org.mariadb.r2dbc.codec.Codec; +import org.mariadb.r2dbc.codec.list.*; + +public enum MariadbType implements Type { + TINYINT(R2dbcType.TINYINT.name(), Byte.class, ByteCodec.INSTANCE), + UNSIGNED_TINYINT(R2dbcType.TINYINT.name(), Short.class, ShortCodec.INSTANCE), + SMALLINT(R2dbcType.SMALLINT.name(), Short.class, ShortCodec.INSTANCE), + UNSIGNED_SMALLINT(R2dbcType.SMALLINT.name(), Integer.class, IntCodec.INSTANCE), + INTEGER(R2dbcType.INTEGER.name(), Integer.class, IntCodec.INSTANCE), + UNSIGNED_INTEGER(R2dbcType.INTEGER.name(), Long.class, LongCodec.INSTANCE), + FLOAT(R2dbcType.FLOAT.name(), Float.class, FloatCodec.INSTANCE), + DOUBLE(R2dbcType.DOUBLE.name(), Double.class, DoubleCodec.INSTANCE), + BIGINT(R2dbcType.BIGINT.name(), Long.class, LongCodec.INSTANCE), + UNSIGNED_BIGINT(R2dbcType.BIGINT.name(), BigInteger.class, BigIntegerCodec.INSTANCE), + TIME(R2dbcType.TIME.name(), LocalTime.class, LocalTimeCodec.INSTANCE), + TIMESTAMP(R2dbcType.TIMESTAMP.name(), LocalDateTime.class, LocalDateTimeCodec.INSTANCE), + DATE(R2dbcType.DATE.name(), LocalDate.class, LocalDateCodec.INSTANCE), + BIT("BIT", BitSet.class, BitSetCodec.INSTANCE), + BOOLEAN(R2dbcType.BOOLEAN.getName(), Boolean.class, BooleanCodec.INSTANCE), + BYTES("BYTES", byte[].class, ByteArrayCodec.INSTANCE), + BLOB(R2dbcType.BLOB.getName(), ByteBuffer.class, ByteBufferCodec.INSTANCE), + VARCHAR(R2dbcType.VARCHAR.getName(), String.class, StringCodec.INSTANCE), + CLOB(R2dbcType.CLOB.getName(), String.class, StringCodec.INSTANCE), + BINARY(R2dbcType.BINARY.getName(), Blob.class, BlobCodec.INSTANCE), + DECIMAL(R2dbcType.DECIMAL.getName(), BigDecimal.class, BigDecimalCodec.INSTANCE); + + private final String typeName; + private final Class classType; + private final Codec defaultCodec; + + MariadbType(String typeName, Class classType, Codec defaultCodec) { + this.typeName = typeName; + this.classType = classType; + this.defaultCodec = defaultCodec; + } + + @Override + public Class getJavaType() { + return classType; + } + + @Override + public String getName() { + return typeName; + } + + public Codec getDefaultCodec() { + return defaultCodec; + } +} diff --git a/src/main/java/org/mariadb/r2dbc/util/PidFactory.java b/src/main/java/org/mariadb/r2dbc/util/PidFactory.java deleted file mode 100644 index 597d6d1c..00000000 --- a/src/main/java/org/mariadb/r2dbc/util/PidFactory.java +++ /dev/null @@ -1,35 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab - -package org.mariadb.r2dbc.util; - -import java.lang.reflect.Method; -import java.util.function.Supplier; - -public final class PidFactory { - private static Supplier instance; - - static { - try { - // if java 9+ - Class processHandle = Class.forName("java.lang.ProcessHandle"); - instance = - () -> { - try { - Method currentProcessMethod = processHandle.getMethod("current"); - Object currentProcess = currentProcessMethod.invoke(null); - Method pidMethod = processHandle.getMethod("pid"); - return String.valueOf(pidMethod.invoke(currentProcess)); - } catch (Throwable throwable) { - return null; - } - }; - } catch (Throwable cle) { - instance = () -> null; - } - } - - public static Supplier getInstance() { - return instance; - } -} diff --git a/src/main/java/org/mariadb/r2dbc/util/PrepareCache.java b/src/main/java/org/mariadb/r2dbc/util/PrepareCache.java index a72335a4..17eab73f 100644 --- a/src/main/java/org/mariadb/r2dbc/util/PrepareCache.java +++ b/src/main/java/org/mariadb/r2dbc/util/PrepareCache.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; diff --git a/src/main/java/org/mariadb/r2dbc/util/PrepareResult.java b/src/main/java/org/mariadb/r2dbc/util/PrepareResult.java index cb608501..55651f70 100644 --- a/src/main/java/org/mariadb/r2dbc/util/PrepareResult.java +++ b/src/main/java/org/mariadb/r2dbc/util/PrepareResult.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; diff --git a/src/main/java/org/mariadb/r2dbc/util/ServerNamedParamParser.java b/src/main/java/org/mariadb/r2dbc/util/ServerNamedParamParser.java new file mode 100644 index 00000000..b57cee37 --- /dev/null +++ b/src/main/java/org/mariadb/r2dbc/util/ServerNamedParamParser.java @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.util; + +import java.util.ArrayList; +import java.util.List; + +public class ServerNamedParamParser implements PrepareResult { + + private final String realSql; + private final List paramNameList; + private final int paramCount; + + private ServerNamedParamParser(String realSql, List paramNameList) { + this.realSql = realSql; + this.paramNameList = paramNameList; + this.paramCount = paramNameList.size(); + } + + /** + * Separate query in a String list and set flag isQueryMultipleRewritable. The resulting string + * list is separated by ? or :name that are not in comments. + * + * @param queryString query + * @param noBackslashEscapes escape mode + * @return ClientPrepareResult + */ + public static ServerNamedParamParser parameterParts( + String queryString, boolean noBackslashEscapes) { + StringBuilder sb = new StringBuilder(); + List paramNameList = new ArrayList<>(); + + LexState state = LexState.Normal; + char lastChar = '\0'; + boolean endingSemicolon = false; + + boolean singleQuotes = false; + int lastParameterPosition = 0; + + char[] query = queryString.toCharArray(); + int queryLength = query.length; + for (int i = 0; i < queryLength; i++) { + + char car = query[i]; + if (state == LexState.Escape) { + state = LexState.String; + lastChar = car; + continue; + } + + switch (car) { + case '*': + if (state == LexState.Normal && lastChar == '/') { + state = LexState.SlashStarComment; + } + break; + + case '/': + if (state == LexState.SlashStarComment && lastChar == '*') { + state = LexState.Normal; + } else if (state == LexState.Normal && lastChar == '/') { + state = LexState.EOLComment; + } + break; + + case '#': + if (state == LexState.Normal) { + state = LexState.EOLComment; + } + break; + + case '-': + if (state == LexState.Normal && lastChar == '-') { + state = LexState.EOLComment; + } + break; + + case '\n': + if (state == LexState.EOLComment) { + state = LexState.Normal; + } + break; + + case '"': + if (state == LexState.Normal) { + state = LexState.String; + singleQuotes = false; + } else if (state == LexState.String && !singleQuotes) { + state = LexState.Normal; + } + break; + + case '\'': + if (state == LexState.Normal) { + state = LexState.String; + singleQuotes = true; + } else if (state == LexState.String && singleQuotes) { + state = LexState.Normal; + } + break; + + case '\\': + if (!noBackslashEscapes && state == LexState.String) { + state = LexState.Escape; + } + break; + case ';': + if (state == LexState.Normal) { + endingSemicolon = true; + } + break; + + case '?': + if (state == LexState.Normal) { + sb.append(queryString, lastParameterPosition, i).append("?"); + lastParameterPosition = i + 1; + paramNameList.add(null); + } + break; + + case ':': + if (state == LexState.Normal) { + sb.append(queryString, lastParameterPosition, i).append("?"); + String placeholderName = ""; + while (++i < queryLength + && (car = query[i]) != ' ' + && ((car >= '0' && car <= '9') + || (car >= 'A' && car <= 'Z') + || (car >= 'a' && car <= 'z') + || car == '-' + || car == '_')) { + placeholderName += car; + } + lastParameterPosition = i; + paramNameList.add(placeholderName); + } + break; + + case '`': + if (state == LexState.Backtick) { + state = LexState.Normal; + } else if (state == LexState.Normal) { + state = LexState.Backtick; + } + break; + + default: + // multiple queries + if (state == LexState.Normal && endingSemicolon && ((byte) car >= 40)) { + endingSemicolon = false; + } + break; + } + lastChar = car; + } + if (lastParameterPosition == 0) { + sb.append(queryString); + } else { + sb.append(queryString, lastParameterPosition, queryLength); + } + + return new ServerNamedParamParser(sb.toString(), paramNameList); + } + + public String getRealSql() { + return realSql; + } + + public List getParamNameList() { + return paramNameList; + } + + @Override + public int getParamCount() { + return paramCount; + } + + enum LexState { + Normal, /* inside query */ + String, /* inside string */ + SlashStarComment, /* inside slash-star comment */ + Escape, /* found backslash */ + EOLComment, /* # comment, or // comment, or -- comment */ + Backtick /* found backtick */ + } +} diff --git a/src/main/java/org/mariadb/r2dbc/util/ServerPrepareResult.java b/src/main/java/org/mariadb/r2dbc/util/ServerPrepareResult.java index 6d24dfab..c6491058 100644 --- a/src/main/java/org/mariadb/r2dbc/util/ServerPrepareResult.java +++ b/src/main/java/org/mariadb/r2dbc/util/ServerPrepareResult.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; @@ -14,7 +14,7 @@ public class ServerPrepareResult { private final int statementId; private final int numParams; - private final ColumnDefinitionPacket[] columns; + private ColumnDefinitionPacket[] columns; private final AtomicBoolean closing = new AtomicBoolean(); private final AtomicInteger use = new AtomicInteger(1); @@ -26,6 +26,10 @@ public ServerPrepareResult(int statementId, int numParams, ColumnDefinitionPacke this.columns = columns; } + public void setColumns(ColumnDefinitionPacket[] columns) { + this.columns = columns; + } + public int getStatementId() { return statementId; } diff --git a/src/main/java/org/mariadb/r2dbc/util/SslConfig.java b/src/main/java/org/mariadb/r2dbc/util/SslConfig.java index 0bc2bd2d..e6b31f98 100644 --- a/src/main/java/org/mariadb/r2dbc/util/SslConfig.java +++ b/src/main/java/org/mariadb/r2dbc/util/SslConfig.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util; @@ -27,7 +27,7 @@ public class SslConfig { public static final SslConfig DISABLE_INSTANCE = new SslConfig(SslMode.DISABLE); - private SslMode sslMode; + private final SslMode sslMode; private String serverSslCert; private String clientSslCert; private String clientSslKey; @@ -133,6 +133,7 @@ private InputStream loadCert(String path) throws FileNotFoundException { return inStream; } + @SuppressWarnings("static") public GenericFutureListener> getHostNameVerifier( CompletableFuture result, String host, long threadId, SSLEngine engine) { return future -> { @@ -150,7 +151,8 @@ public GenericFutureListener> getHostNa // of error. Certificate[] certs = session.getPeerCertificates(); X509Certificate cert = (X509Certificate) certs[0]; - hostnameVerifier.verify(host, cert, threadId); + + DefaultHostnameVerifier.verify(host, cert, threadId); } } catch (SSLException ex) { result.completeExceptionally( diff --git a/src/main/java/org/mariadb/r2dbc/util/constants/Capabilities.java b/src/main/java/org/mariadb/r2dbc/util/constants/Capabilities.java index 7486b48e..abba3acd 100644 --- a/src/main/java/org/mariadb/r2dbc/util/constants/Capabilities.java +++ b/src/main/java/org/mariadb/r2dbc/util/constants/Capabilities.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util.constants; @@ -40,6 +40,9 @@ public class Capabilities { public static final long MARIADB_CLIENT_COM_MULTI = 1L << 33; /* bundle command during connection */ + /** permit COM_STMT_BULK commands */ + public static final long MARIADB_CLIENT_STMT_BULK_OPERATIONS = 1L << 34; + // permit skipping metadata public static final long MARIADB_CLIENT_CACHE_METADATA = 1L << 36; } diff --git a/src/main/java/org/mariadb/r2dbc/util/constants/ColumnFlags.java b/src/main/java/org/mariadb/r2dbc/util/constants/ColumnFlags.java index 6e67476e..c80c873c 100644 --- a/src/main/java/org/mariadb/r2dbc/util/constants/ColumnFlags.java +++ b/src/main/java/org/mariadb/r2dbc/util/constants/ColumnFlags.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util.constants; diff --git a/src/main/java/org/mariadb/r2dbc/util/constants/ServerStatus.java b/src/main/java/org/mariadb/r2dbc/util/constants/ServerStatus.java index 33b1b8a5..46d86f69 100644 --- a/src/main/java/org/mariadb/r2dbc/util/constants/ServerStatus.java +++ b/src/main/java/org/mariadb/r2dbc/util/constants/ServerStatus.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util.constants; @@ -16,5 +16,6 @@ public class ServerStatus { public static final short METADATA_CHANGED = 1024; public static final short QUERY_WAS_SLOW = 2048; public static final short PS_OUT_PARAMETERS = 4096; - public static final short SERVER_SESSION_STATE_CHANGED = 1 << 14; + public static final short STATUS_IN_TRANS_READONLY = 1 << 13; + public static final short SESSION_STATE_CHANGED = 1 << 14; } diff --git a/src/main/java/org/mariadb/r2dbc/util/constants/StateChange.java b/src/main/java/org/mariadb/r2dbc/util/constants/StateChange.java index 346e0b48..aed75a37 100644 --- a/src/main/java/org/mariadb/r2dbc/util/constants/StateChange.java +++ b/src/main/java/org/mariadb/r2dbc/util/constants/StateChange.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.util.constants; diff --git a/src/main/java9/module-info.java b/src/main/java9/module-info.java new file mode 100644 index 00000000..546e9b0d --- /dev/null +++ b/src/main/java9/module-info.java @@ -0,0 +1,31 @@ +module r2dbc.mariadb { + requires transitive r2dbc.spi; + requires transitive reactor.core; + requires transitive io.netty.buffer; + requires transitive io.netty.handler; + requires transitive io.netty.transport.unix.common; + requires transitive io.netty.common; + requires transitive io.netty.transport; + requires transitive io.netty.codec; + requires transitive org.reactivestreams; + requires transitive reactor.netty.core; + requires transitive java.naming; + + exports org.mariadb.r2dbc; + exports org.mariadb.r2dbc.api; + exports org.mariadb.r2dbc.authentication; + exports org.mariadb.r2dbc.message; + + uses org.mariadb.r2dbc.authentication.AuthenticationPlugin; + uses io.r2dbc.spi.ConnectionFactoryProvider; + + provides io.r2dbc.spi.ConnectionFactoryProvider with + org.mariadb.r2dbc.MariadbConnectionFactoryProvider; + provides org.mariadb.r2dbc.authentication.AuthenticationPlugin with + org.mariadb.r2dbc.authentication.standard.NativePasswordPluginFlow, + org.mariadb.r2dbc.authentication.addon.ClearPasswordPluginFlow, + org.mariadb.r2dbc.authentication.standard.Ed25519PasswordPluginFlow, + org.mariadb.r2dbc.authentication.standard.Sha256PasswordPluginFlow, + org.mariadb.r2dbc.authentication.standard.CachingSha2PasswordFlow, + org.mariadb.r2dbc.authentication.standard.PamPluginFlow; +} diff --git a/src/main/resources/META-INF/services/io.r2dbc.spi.ConnectionFactoryProvider b/src/main/resources/META-INF/services/io.r2dbc.spi.ConnectionFactoryProvider index 380540cc..e6f1d5f3 100644 --- a/src/main/resources/META-INF/services/io.r2dbc.spi.ConnectionFactoryProvider +++ b/src/main/resources/META-INF/services/io.r2dbc.spi.ConnectionFactoryProvider @@ -1,4 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2020-2021 MariaDB Corporation Ab +# Copyright (c) 2020-2022 MariaDB Corporation Ab -org.mariadb.r2dbc.MariadbConnectionFactoryProvider \ No newline at end of file +org.mariadb.r2dbc.MariadbConnectionFactoryProvider diff --git a/src/main/resources/META-INF/services/org.mariadb.r2dbc.authentication.AuthenticationPlugin b/src/main/resources/META-INF/services/org.mariadb.r2dbc.authentication.AuthenticationPlugin index 05df7fb9..efaec9ef 100644 --- a/src/main/resources/META-INF/services/org.mariadb.r2dbc.authentication.AuthenticationPlugin +++ b/src/main/resources/META-INF/services/org.mariadb.r2dbc.authentication.AuthenticationPlugin @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2020-2021 MariaDB Corporation Ab +# Copyright (c) 2020-2022 MariaDB Corporation Ab -org.mariadb.r2dbc.message.flow.NativePasswordPluginFlow -org.mariadb.r2dbc.message.flow.ClearPasswordPluginFlow -org.mariadb.r2dbc.message.flow.Ed25519PasswordPluginFlow -org.mariadb.r2dbc.message.flow.Sha256PasswordPluginFlow -org.mariadb.r2dbc.message.flow.CachingSha2PasswordFlow -org.mariadb.r2dbc.message.flow.PamPluginFlow \ No newline at end of file +org.mariadb.r2dbc.authentication.standard.NativePasswordPluginFlow +org.mariadb.r2dbc.authentication.addon.ClearPasswordPluginFlow +org.mariadb.r2dbc.authentication.standard.Ed25519PasswordPluginFlow +org.mariadb.r2dbc.authentication.standard.Sha256PasswordPluginFlow +org.mariadb.r2dbc.authentication.standard.CachingSha2PasswordFlow +org.mariadb.r2dbc.authentication.standard.PamPluginFlow diff --git a/src/main/resources/project.properties b/src/main/resources/project.properties index cb8a0178..c1d8a4d3 100644 --- a/src/main/resources/project.properties +++ b/src/main/resources/project.properties @@ -1,4 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2020-2021 MariaDB Corporation Ab +# Copyright (c) 2020-2022 MariaDB Corporation Ab -version=${project.version} \ No newline at end of file +version=${project.version} diff --git a/src/test/java/org/mariadb/r2dbc/BaseConnectionTest.java b/src/test/java/org/mariadb/r2dbc/BaseConnectionTest.java index 3c176062..fa796196 100644 --- a/src/test/java/org/mariadb/r2dbc/BaseConnectionTest.java +++ b/src/test/java/org/mariadb/r2dbc/BaseConnectionTest.java @@ -1,19 +1,20 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; import io.r2dbc.spi.ValidationDepth; import java.io.IOException; -import java.sql.SQLException; import java.util.Random; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.extension.*; +import org.junit.jupiter.api.function.Executable; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; import org.mariadb.r2dbc.tools.TcpProxy; +import org.mariadb.r2dbc.util.HostAddress; import reactor.test.StepVerifier; public class BaseConnectionTest extends BaseTest { @@ -21,7 +22,7 @@ public class BaseConnectionTest extends BaseTest { public static MariadbConnection sharedConn; public static MariadbConnection sharedConnPrepare; public static TcpProxy proxy; - private static Random rand = new Random(); + private static final Random rand = new Random(); public static final Boolean backslashEscape = System.getenv("NO_BACKSLASH_ESCAPES") != null ? Boolean.valueOf(System.getenv("NO_BACKSLASH_ESCAPES")) @@ -58,21 +59,17 @@ public static void beforeAll() throws Exception { } public MariadbConnection createProxyCon() throws Exception { + HostAddress hostAddress = TestConfiguration.defaultConf.getHostAddresses().get(0); try { - proxy = - new TcpProxy( - TestConfiguration.defaultConf.getHost(), TestConfiguration.defaultConf.getPort()); + proxy = new TcpProxy(hostAddress.getHost(), hostAddress.getPort()); } catch (IOException i) { - throw new SQLException("proxy error", i); + throw new Exception("proxy error", i); } MariadbConnectionConfiguration confProxy = TestConfiguration.defaultBuilder .clone() .port(proxy.getLocalPort()) - .host( - System.getenv("TRAVIS") != null - ? TestConfiguration.defaultConf.getHost() - : "localhost") + .host(System.getenv("TRAVIS") != null ? hostAddress.getHost() : "localhost") .build(); return new MariadbConnectionFactory(confProxy).create().block(); } @@ -81,7 +78,7 @@ public MariadbConnection createProxyCon() throws Exception { public void afterEach1() { int i = rand.nextInt(); sharedConn - .createStatement("SELECT " + i) + .createStatement("SELECT " + i + ", 'a'") .execute() .flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class))) .as(StepVerifier::create) @@ -90,7 +87,7 @@ public void afterEach1() { int j = rand.nextInt(); sharedConnPrepare - .createStatement("SELECT " + j) + .createStatement("SELECT " + j + ", 'b'") .execute() .flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class))) .as(StepVerifier::create) @@ -104,6 +101,20 @@ public static void afterEAll() { sharedConnPrepare.close().block(); } + public static boolean runLongTest() { + String runLongTest = System.getenv("RUN_LONG_TEST"); + if (runLongTest != null) { + return Boolean.parseBoolean(runLongTest); + } + return false; + } + + public static void assertThrowsContains( + Class expectedType, Executable executable, String expected) { + Exception e = Assertions.assertThrows(expectedType, executable); + Assertions.assertTrue(e.getMessage().contains(expected), "real message:" + e.getMessage()); + } + public static boolean isMariaDBServer() { MariadbConnectionMetadata meta = sharedConn.getMetadata(); return meta.isMariaDBServer(); diff --git a/src/test/java/org/mariadb/r2dbc/BaseTest.java b/src/test/java/org/mariadb/r2dbc/BaseTest.java index d29efb80..e7aeb823 100644 --- a/src/test/java/org/mariadb/r2dbc/BaseTest.java +++ b/src/test/java/org/mariadb/r2dbc/BaseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/test/java/org/mariadb/r2dbc/TestConfiguration.java b/src/test/java/org/mariadb/r2dbc/TestConfiguration.java index 4bb60033..319026b6 100644 --- a/src/test/java/org/mariadb/r2dbc/TestConfiguration.java +++ b/src/test/java/org/mariadb/r2dbc/TestConfiguration.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc; diff --git a/src/test/java/org/mariadb/r2dbc/integration/BatchTest.java b/src/test/java/org/mariadb/r2dbc/integration/BatchTest.java index 096f35d5..1ab3f614 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/BatchTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/BatchTest.java @@ -1,9 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.MariadbConnectionConfiguration; @@ -11,6 +13,9 @@ import org.mariadb.r2dbc.TestConfiguration; import org.mariadb.r2dbc.api.MariadbBatch; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.api.MariadbResult; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; public class BatchTest extends BaseConnectionTest { @@ -50,9 +55,20 @@ void basicBatch() { @Test void multiQueriesBatch() throws Exception { + // error crashing maxscale 6.1.x + Assumptions.assumeTrue( + !sharedConn.getMetadata().getDatabaseVersion().contains("maxScale-6.1.") + && !"skysql-ha".equals(System.getenv("srv"))); MariadbConnectionConfiguration confMulti = TestConfiguration.defaultBuilder.clone().allowMultiQueries(true).build(); - MariadbConnection multiConn = new MariadbConnectionFactory(confMulti).create().block(); + batchTest(confMulti); + MariadbConnectionConfiguration confNoMulti = + TestConfiguration.defaultBuilder.clone().allowMultiQueries(false).build(); + batchTest(confNoMulti); + } + + private void batchTest(MariadbConnectionConfiguration conf) throws Exception { + MariadbConnection multiConn = new MariadbConnectionFactory(conf).create().block(); multiConn .createStatement("CREATE TEMPORARY TABLE multiBatch (id int, test varchar(10))") .execute() @@ -86,40 +102,37 @@ void multiQueriesBatch() throws Exception { } @Test - void noMultiQueriesBatch() throws Exception { - MariadbConnectionConfiguration confMulti = + void cancelBatch() throws Exception { + // error crashing maxscale 6.1.x + Assumptions.assumeTrue( + !sharedConn.getMetadata().getDatabaseVersion().contains("maxScale-6.1.") + && !"skysql-ha".equals(System.getenv("srv"))); + MariadbConnectionConfiguration confNoMulti = TestConfiguration.defaultBuilder.clone().allowMultiQueries(false).build(); - MariadbConnection multiConn = new MariadbConnectionFactory(confMulti).create().block(); + MariadbConnection multiConn = new MariadbConnectionFactory(confNoMulti).create().block(); multiConn .createStatement("CREATE TEMPORARY TABLE multiBatch (id int, test varchar(10))") .execute() .blockLast(); MariadbBatch batch = multiConn.createBatch(); - int[] res = new int[20]; - for (int i = 0; i < 20; i++) { + + int[] res = new int[10_000]; + for (int i = 0; i < res.length; i++) { batch.add("INSERT INTO multiBatch VALUES (" + i + ", 'test" + i + "')"); res[i] = i; } + AtomicInteger resultNb = new AtomicInteger(0); + Flux f = batch.execute(); + Disposable disp = + f.flatMap(it -> it.getRowsUpdated()).subscribe(i -> resultNb.incrementAndGet()); + Thread.sleep(1000); - batch - .execute() - .flatMap(it -> it.getRowsUpdated()) - .as(StepVerifier::create) - .expectNext(1, 1, 1, 1, 1) - .expectNextCount(15) - .then( - () -> { - multiConn - .createStatement("SELECT id FROM multiBatch") - .execute() - .flatMap(r -> r.map((row, metadata) -> row.get(0))) - .as(StepVerifier::create) - .expectNext(0, 1, 2, 3, 4) - .expectNextCount(15) - .verifyComplete(); - multiConn.close().block(); - }) - .verifyComplete(); + Assertions.assertTrue(resultNb.get() > 0); + disp.dispose(); + Thread.sleep(1000); + Assertions.assertTrue( + resultNb.get() > 1 && resultNb.get() < 10_000, + String.format("expected %s to be 0 < x < 10000", resultNb.get())); } @Test @@ -143,4 +156,31 @@ void batchWithParameter() { e.getMessage().contains("Statement with parameters cannot be batched (sql:'")); } } + + @Test + void batchError() { + batchError(sharedConn); + batchError(sharedConnPrepare); + } + + void batchError(MariadbConnection conn) { + conn.createStatement("CREATE TEMPORARY TABLE basicBatch2 (id int, test varchar(10))") + .execute() + .blockLast(); + conn.createStatement("INSERT INTO basicBatch2 VALUES (?, ?)") + .bind(0, 1) + .bind(1, "dd") + .execute() + .blockLast(); + assertThrows( + IllegalStateException.class, + () -> + conn.createStatement("INSERT INTO basicBatch2 VALUES (?, ?)") + .bind(0, 1) + .bind(1, "dd") + .add() + .bind(1, "dd") + .add(), + "Parameter at position 0 is not set"); + } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/BigResultSetTest.java b/src/test/java/org/mariadb/r2dbc/integration/BigResultSetTest.java index 5516e54d..efe18a9a 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/BigResultSetTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/BigResultSetTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -39,9 +39,7 @@ public void beforeEach() { @Test void BigResultSet() { - Assumptions.assumeTrue( - System.getenv("RUN_LONG_TEST") == null - || !Boolean.parseBoolean(System.getenv("RUN_LONG_TEST"))); + Assumptions.assumeTrue(runLongTest()); MariadbConnectionMetadata meta = sharedConn.getMetadata(); // sequence table requirement Assumptions.assumeTrue(meta.isMariaDBServer() && minVersion(10, 1, 0)); @@ -57,7 +55,6 @@ void BigResultSet() { @Test void multipleFluxSubscription() { - Assumptions.assumeTrue(Boolean.parseBoolean(System.getProperty("RUN_LONG_TEST", "true"))); MariadbConnectionMetadata meta = sharedConn.getMetadata(); // sequence table requirement Assumptions.assumeTrue(meta.isMariaDBServer() && minVersion(10, 1, 0)); @@ -69,10 +66,7 @@ void multipleFluxSubscription() { AtomicInteger total = new AtomicInteger(); for (int i = 0; i < 10; i++) { - flux1.subscribe( - s -> { - total.incrementAndGet(); - }); + flux1.subscribe(s -> total.incrementAndGet()); } flux1.blockLast(); @@ -81,17 +75,13 @@ void multipleFluxSubscription() { @Test void multiPacketRow() { - Assumptions.assumeTrue( - checkMaxAllowedPacketMore20m(sharedConn) - && Boolean.parseBoolean(System.getProperty("RUN_LONG_TEST", "true"))); + Assumptions.assumeTrue(checkMaxAllowedPacketMore20m(sharedConn) && runLongTest()); multiPacketRow(sharedConn); } @Test void multiPacketRowPrepare() { - Assumptions.assumeTrue( - checkMaxAllowedPacketMore20m(sharedConn) - && Boolean.parseBoolean(System.getProperty("RUN_LONG_TEST", "true"))); + Assumptions.assumeTrue(checkMaxAllowedPacketMore20m(sharedConn) && runLongTest()); multiPacketRow(sharedConnPrepare); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/ConfigurationTest.java b/src/test/java/org/mariadb/r2dbc/integration/ConfigurationTest.java index 24de73d9..c46cddc1 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/ConfigurationTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/ConfigurationTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -8,7 +8,6 @@ import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; -import java.time.Duration; import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.Assertions; @@ -61,6 +60,31 @@ void ensureUserInfoUrlEncoding() { "r2dbc:mariadb://root%40%C3%A5:p%40ssword@localhost:3305/%D1" + "%88db"); Assertions.assertTrue(factory.toString().contains("username='root@å'")); Assertions.assertTrue(factory.toString().contains("database='шdb'")); + Assertions.assertTrue(factory.toString().contains("isolationLevel=null")); + } + + @Test + void isolationLevel() { + MariadbConnectionFactory factory = + (MariadbConnectionFactory) + ConnectionFactories.get( + "r2dbc:mariadb://root:password@localhost:3305/db?isolationLevel=REPEATABLE-READ"); + Assertions.assertTrue(factory.toString().contains("username='root'")); + Assertions.assertTrue(factory.toString().contains("database='db'")); + Assertions.assertTrue( + factory.toString().contains("isolationLevel=IsolationLevel{sql='REPEATABLE READ'}")); + } + + @Test + void haMode() { + MariadbConnectionFactory factory = + (MariadbConnectionFactory) + ConnectionFactories.get( + "r2dbc:mariadb:failover://root:password@localhost:3305/db?isolationLevel=REPEATABLE-READ"); + Assertions.assertTrue(factory.toString().contains("username='root'")); + Assertions.assertTrue(factory.toString().contains("database='db'")); + Assertions.assertTrue( + factory.toString().contains("isolationLevel=IsolationLevel{sql='REPEATABLE READ'}")); } @Test @@ -93,7 +117,7 @@ void checkOptions() throws Exception { + "&clientSslKey=" + clientSslKey + "&allowPipelining=true&useServerPrepStmts" - + "=true&prepareCacheSize=2560&connectTimeout=PT10S&socketTimeout=PT1H&tcpKeepAlive=true" + + "=true&prepareCacheSize=2560&connectTimeout=PT10S&tcpKeepAlive=true" + "&tcpAbortiveClose=true&sslMode=TRUST" + "&connectionAttributes" + "=test=2," @@ -104,13 +128,12 @@ void checkOptions() throws Exception { Assertions.assertTrue(factory.toString().contains("serverSslCert=" + serverSslCert)); Assertions.assertTrue(factory.toString().contains("clientSslCert=" + clientSslCert)); Assertions.assertTrue(factory.toString().contains("allowPipelining=true")); - Assertions.assertTrue(factory.toString().contains("useServerPrepStmts=true")); + Assertions.assertTrue(factory.toString().contains("useServerPrepStmts=false")); Assertions.assertTrue(factory.toString().contains("prepareCacheSize=2560")); Assertions.assertTrue(factory.toString().contains("sslMode=TRUST")); Assertions.assertTrue(factory.toString().contains("connectionAttributes={test=2, h=4}")); Assertions.assertTrue(factory.toString().contains("pamOtherPwd=*,*")); Assertions.assertTrue(factory.toString().contains("connectTimeout=PT10S")); - Assertions.assertTrue(factory.toString().contains("socketTimeout=PT1H")); Assertions.assertTrue(factory.toString().contains("tcpKeepAlive=true")); Assertions.assertTrue(factory.toString().contains("tcpAbortiveClose=true")); } @@ -125,8 +148,26 @@ void checkNotConcerned() { } } + @Test + void checkDecoded() { + ConnectionFactoryOptions options = + ConnectionFactoryOptions.parse("r2dbc:mariadb://ro%3Aot:pw%3Ad@localhost:3306/db"); + MariadbConnectionConfiguration conf = + MariadbConnectionConfiguration.fromOptions(options).build(); + Assertions.assertEquals("ro:ot", conf.getUsername()); + Assertions.assertEquals("pw:d", conf.getPassword().toString()); + } + @Test void factory() { + + final ConnectionFactoryOptions option1s = ConnectionFactoryOptions.builder().build(); + + assertThrows( + NoSuchOptionException.class, + () -> MariadbConnectionConfiguration.fromOptions(option1s).build(), + ""); + ConnectionFactoryOptions options = ConnectionFactoryOptions.builder() .option(ConnectionFactoryOptions.DRIVER, "mariadb") @@ -135,7 +176,6 @@ void factory() { .option(ConnectionFactoryOptions.USER, "myUser") .option(ConnectionFactoryOptions.DATABASE, "myDb") .option(MariadbConnectionFactoryProvider.ALLOW_MULTI_QUERIES, true) - .option(MariadbConnectionFactoryProvider.SOCKET_TIMEOUT, Duration.ofSeconds(3600)) .option(MariadbConnectionFactoryProvider.TCP_KEEP_ALIVE, true) .option(MariadbConnectionFactoryProvider.TCP_ABORTIVE_CLOSE, true) .option(Option.valueOf("locale"), "en_US") @@ -145,13 +185,12 @@ void factory() { MariadbConnectionFactory factory = MariadbConnectionFactory.from(conf); Assertions.assertTrue(factory.toString().contains("database='myDb'")); - Assertions.assertTrue(factory.toString().contains("host='someHost'")); + Assertions.assertTrue( + factory.toString().contains("hosts={[someHost:43306]}"), factory.toString()); Assertions.assertTrue(factory.toString().contains("allowMultiQueries=true")); Assertions.assertTrue(factory.toString().contains("allowPipelining=true")); Assertions.assertTrue(factory.toString().contains("username='myUser'")); - Assertions.assertTrue(factory.toString().contains("port=43306")); Assertions.assertTrue(factory.toString().contains("connectTimeout=PT10S")); - Assertions.assertTrue(factory.toString().contains("socketTimeout=PT1H")); Assertions.assertTrue(factory.toString().contains("tcpKeepAlive=true")); Assertions.assertTrue(factory.toString().contains("tcpAbortiveClose=true")); } @@ -191,7 +230,7 @@ void checkOptionsPerOption() { .build(); MariadbConnectionConfiguration conf = MariadbConnectionConfiguration.fromOptions(options).build(); - Assertions.assertEquals("someHost", conf.getHost()); + Assertions.assertEquals("someHost", conf.getHostAddresses().get(0).getHost()); Assertions.assertEquals(43306, conf.getPort()); Assertions.assertEquals(true, conf.allowMultiQueries()); @@ -297,11 +336,11 @@ void confStringValue() { builder.pamOtherPwd(new String[] {"fff", "ddd"}); builder.tlsProtocol((String[]) null); Assertions.assertEquals( - "Builder{rsaPublicKey=null, cachingRsaPublicKey=null, allowPublicKeyRetrieval=false, username=admin, connectTimeout=null, socketTimeout=null, tcpKeepAlive=null, tcpAbortiveClose=null, database=dbname, host=localhost, sessionVariables=null, connectionAttributes=null, password=*, port=3306, socket=null, allowMultiQueries=false, allowPipelining=true, useServerPrepStmts=false, prepareCacheSize=null, tlsProtocol=null, serverSslCert=null, clientSslCert=null, clientSslKey=null, clientSslPassword=null, sslMode=TRUST, pamOtherPwd=*,*, tinyInt1isBit=false, autoCommit=true}", + "Builder{rsaPublicKey=null, cachingRsaPublicKey=null, allowPublicKeyRetrieval=false, username=admin, connectTimeout=null, tcpKeepAlive=null, tcpAbortiveClose=null, transactionReplay=null, database=dbname, host=localhost, sessionVariables=null, connectionAttributes=null, password=*, restrictedAuth=null, port=3306, hosts={}, socket=null, allowMultiQueries=false, allowPipelining=true, useServerPrepStmts=false, prepareCacheSize=null, isolationLevel=null, tlsProtocol=null, serverSslCert=null, clientSslCert=null, clientSslKey=null, clientSslPassword=null, sslMode=TRUST, pamOtherPwd=*,*, tinyInt1isBit=false, autoCommit=true}", builder.toString()); builder.tlsProtocol((String) null); Assertions.assertEquals( - "Builder{rsaPublicKey=null, cachingRsaPublicKey=null, allowPublicKeyRetrieval=false, username=admin, connectTimeout=null, socketTimeout=null, tcpKeepAlive=null, tcpAbortiveClose=null, database=dbname, host=localhost, sessionVariables=null, connectionAttributes=null, password=*, port=3306, socket=null, allowMultiQueries=false, allowPipelining=true, useServerPrepStmts=false, prepareCacheSize=null, tlsProtocol=null, serverSslCert=null, clientSslCert=null, clientSslKey=null, clientSslPassword=null, sslMode=TRUST, pamOtherPwd=*,*, tinyInt1isBit=false, autoCommit=true}", + "Builder{rsaPublicKey=null, cachingRsaPublicKey=null, allowPublicKeyRetrieval=false, username=admin, connectTimeout=null, tcpKeepAlive=null, tcpAbortiveClose=null, transactionReplay=null, database=dbname, host=localhost, sessionVariables=null, connectionAttributes=null, password=*, restrictedAuth=null, port=3306, hosts={}, socket=null, allowMultiQueries=false, allowPipelining=true, useServerPrepStmts=false, prepareCacheSize=null, isolationLevel=null, tlsProtocol=null, serverSslCert=null, clientSslCert=null, clientSslKey=null, clientSslPassword=null, sslMode=TRUST, pamOtherPwd=*,*, tinyInt1isBit=false, autoCommit=true}", builder.toString()); MariadbConnectionConfiguration conf = builder.build(); Assertions.assertEquals( diff --git a/src/test/java/org/mariadb/r2dbc/integration/ConnectionMetadataTest.java b/src/test/java/org/mariadb/r2dbc/integration/ConnectionMetadataTest.java index 75522cad..1ec31c38 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/ConnectionMetadataTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/ConnectionMetadataTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -27,8 +27,13 @@ void connectionMeta() { String version = System.getenv("v"); if (type != null && System.getenv("TRAVIS") != null) { if ("mariadb".equals(type) || "mysql".equals(type)) { - assertTrue(meta.getDatabaseVersion().contains(version)); - assertEquals(type.toLowerCase(), meta.getDatabaseProductName().toLowerCase()); + assertTrue( + meta.getDatabaseVersion().contains(version), + "Error " + meta.getDatabaseVersion() + " doesn't contains " + version); + assertEquals( + type.toLowerCase(), + meta.getDatabaseProductName().toLowerCase(), + "Error comparing " + type + " with " + meta.getDatabaseProductName()); } } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/ConnectionTest.java b/src/test/java/org/mariadb/r2dbc/integration/ConnectionTest.java index d41bdc58..12acfae1 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/ConnectionTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/ConnectionTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -14,21 +14,31 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Assumptions; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.*; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbResult; import org.mariadb.r2dbc.api.MariadbStatement; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.netty.resources.LoopResources; import reactor.test.StepVerifier; public class ConnectionTest extends BaseConnectionTest { private Level initialReactorLvl; private Level initialLvl; + @BeforeAll + public static void before2() { + dropAll(); + sharedConn.createStatement("CREATE DATABASE test_r2dbc").execute().blockLast(); + } + + @AfterAll + public static void dropAll() { + sharedConn.createStatement("DROP DATABASE test_r2dbc").execute().blockLast(); + } + @Test void localValidation() { sharedConn @@ -139,11 +149,11 @@ void connectionWithoutErrorOnClose() throws Exception { !"maxscale".equals(System.getenv("srv")) && !"skysql".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); - // disableLog(); + // disableLog(); MariadbConnection connection = createProxyCon(); proxy.stop(); connection.close().block(); - // reInitLog(); + // reInitLog(); } @Test @@ -152,7 +162,7 @@ void connectionDuringError() throws Exception { !"maxscale".equals(System.getenv("srv")) && !"skysql".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); - // disableLog(); + // disableLog(); MariadbConnection connection = createProxyCon(); new Timer() .schedule( @@ -165,7 +175,7 @@ public void run() { 200); assertTimeout( - Duration.ofSeconds(2), + Duration.ofSeconds(5), () -> { try { connection @@ -188,7 +198,7 @@ public void run() { .as(StepVerifier::create) .expectNext(Boolean.FALSE) .verifyComplete(); - // reInitLog(); + // reInitLog(); } }); } @@ -236,14 +246,64 @@ void connectTimeout() throws Exception { } @Test - void socketTimeoutTimeout() throws Exception { + void timeoutMultiHost() throws Exception { MariadbConnectionConfiguration conf = - TestConfiguration.defaultBuilder.clone().socketTimeout(Duration.ofSeconds(1)).build(); + TestConfiguration.defaultBuilder + .clone() + .host( + "128.2.2.2," + + TestConfiguration.defaultBuilder + .clone() + .build() + .getHostAddresses() + .get(0) + .getHost()) + .connectTimeout(Duration.ofMillis(500)) + .build(); MariadbConnection connection = new MariadbConnectionFactory(conf).create().block(); consume(connection); connection.close().block(); } + @Test + public void localSocket() throws Exception { + Assumptions.assumeTrue( + System.getenv("local") != null + && "1".equals(System.getenv("local")) + && !System.getProperty("os.name").toLowerCase(Locale.ROOT).contains("win")); + String socket = + sharedConn + .createStatement("select @@socket") + .execute() + .flatMap(r -> r.map((row, metadata) -> row.get(0, String.class))) + .blockLast(); + sharedConn.createStatement("DROP USER IF EXISTS testSocket@'localhost'").execute().blockLast(); + sharedConn + .createStatement("CREATE USER testSocket@'localhost' IDENTIFIED BY 'MySup5%rPassw@ord'") + .execute() + .blockLast(); + sharedConn + .createStatement( + "GRANT SELECT on *.* to testSocket@'localhost' IDENTIFIED BY 'MySup5%rPassw@ord'") + .execute() + .blockLast(); + sharedConn.createStatement("FLUSH PRIVILEGES").execute().blockLast(); + + MariadbConnectionConfiguration conf = + MariadbConnectionConfiguration.builder() + .username("testSocket") + .password("MySup5%rPassw@ord") + .database(TestConfiguration.database) + .socket(socket) + .build(); + + MariadbConnection connection = new MariadbConnectionFactory(conf).create().block(); + consume(connection); + connection.close().block(); + + sharedConn.createStatement("DROP USER testSocket@'localhost'").execute().blockLast(); + } + @Test void socketTcpKeepAlive() throws Exception { MariadbConnectionConfiguration conf = @@ -372,6 +432,67 @@ void multipleBegin(MariadbConnection con) throws Exception { con.beginTransaction().subscribe(); con.beginTransaction().block(); con.beginTransaction().block(); + con.rollbackTransaction().block(); + } + + @Test + void multipleBeginWithIsolation() throws Exception { + MariadbTransactionDefinition[] transactionDefinitions = { + MariadbTransactionDefinition.READ_ONLY, + MariadbTransactionDefinition.READ_WRITE, + MariadbTransactionDefinition.EMPTY, + MariadbTransactionDefinition.WITH_CONSISTENT_SNAPSHOT_READ_ONLY, + MariadbTransactionDefinition.WITH_CONSISTENT_SNAPSHOT_READ_WRITE + }; + + for (MariadbTransactionDefinition transactionDefinition : transactionDefinitions) { + MariadbConnection connection = factory.create().block(); + multipleBeginWithIsolation(connection, transactionDefinition); + connection.close().block(); + + connection = + new MariadbConnectionFactory( + TestConfiguration.defaultBuilder.clone().allowPipelining(false).build()) + .create() + .block(); + multipleBeginWithIsolation(connection, transactionDefinition); + connection.close().block(); + } + } + + void multipleBeginWithIsolation( + MariadbConnection con, MariadbTransactionDefinition transactionDefinition) throws Exception { + con.beginTransaction(transactionDefinition).subscribe(); + con.beginTransaction(transactionDefinition).block(); + con.beginTransaction(transactionDefinition).block(); + con.rollbackTransaction().block(); + } + + @Test + void beginTransactionWithIsolation() throws Exception { + TransactionDefinition transactionDefinition = + MariadbTransactionDefinition.READ_ONLY.isolationLevel(IsolationLevel.READ_COMMITTED); + TransactionDefinition transactionDefinition2 = + MariadbTransactionDefinition.READ_ONLY.isolationLevel(IsolationLevel.REPEATABLE_READ); + assertFalse(sharedConn.isInTransaction()); + + sharedConn.beginTransaction(transactionDefinition).block(); + assertEquals(IsolationLevel.READ_COMMITTED, sharedConn.getTransactionIsolationLevel()); + assertTrue(sharedConn.isInTransaction()); + assertTrue(sharedConn.isInReadOnlyTransaction()); + sharedConn.beginTransaction(transactionDefinition).block(); + sharedConn + .beginTransaction(transactionDefinition2) + .as(StepVerifier::create) + .expectErrorMatches( + throwable -> + throwable instanceof R2dbcPermissionDeniedException + && throwable + .getMessage() + .equals( + "Transaction characteristics can't be changed while a transaction is in progress")) + .verify(); + sharedConn.rollbackTransaction().block(); } @Test @@ -393,6 +514,7 @@ void multipleAutocommit(MariadbConnection con) throws Exception { con.setAutoCommit(true).subscribe(); con.setAutoCommit(true).block(); con.setAutoCommit(false).block(); + con.setAutoCommit(true).block(); } @Test @@ -429,6 +551,9 @@ private void consume(Connection connection) { @Test void multiThreading() throws Throwable { + Assumptions.assumeTrue( + !"skysql".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); + AtomicInteger completed = new AtomicInteger(0); ThreadPoolExecutor scheduler = new ThreadPoolExecutor(10, 20, 50, TimeUnit.SECONDS, new LinkedBlockingQueue<>()); @@ -520,7 +645,7 @@ void sessionVariables() throws Exception { } protected class ExecuteQueries implements Runnable { - private AtomicInteger i; + private final AtomicInteger i; public ExecuteQueries(AtomicInteger i) { this.i = i; @@ -548,8 +673,8 @@ public void run() { } protected class ExecuteQueriesOnSameConnection implements Runnable { - private AtomicInteger i; - private MariadbConnection connection; + private final AtomicInteger i; + private final MariadbConnection connection; public ExecuteQueriesOnSameConnection(AtomicInteger i, MariadbConnection connection) { this.i = i; @@ -579,13 +704,9 @@ void getTransactionIsolationLevel() { new MariadbConnectionFactory(TestConfiguration.defaultBuilder.build()).create().block(); try { IsolationLevel defaultValue = IsolationLevel.REPEATABLE_READ; - - if ("skysql".equals(System.getenv("srv")) || "skysql-ha".equals(System.getenv("srv"))) { - defaultValue = IsolationLevel.READ_COMMITTED; - } - Assertions.assertEquals(defaultValue, connection.getTransactionIsolationLevel()); connection.setTransactionIsolationLevel(IsolationLevel.READ_UNCOMMITTED).block(); + connection.createStatement("BEGIN").execute().blockLast(); Assertions.assertEquals( IsolationLevel.READ_UNCOMMITTED, connection.getTransactionIsolationLevel()); connection.setTransactionIsolationLevel(defaultValue).block(); @@ -594,6 +715,23 @@ void getTransactionIsolationLevel() { } } + @Test + void getDatabase() { + MariadbConnection connection = + new MariadbConnectionFactory(TestConfiguration.defaultBuilder.build()).create().block(); + assertEquals(TestConfiguration.database, connection.getDatabase()); + connection.setDatabase("test_r2dbc").block(); + assertEquals("test_r2dbc", connection.getDatabase()); + String db = + connection + .createStatement("SELECT DATABASE()") + .execute() + .flatMap(it -> it.map((row, rowMetadata) -> row.get(0, String.class))) + .blockFirst(); + assertEquals("test_r2dbc", db); + connection.close().block(); + } + @Test void rollbackTransaction() { sharedConn.createStatement("DROP TABLE IF EXISTS rollbackTable").execute().blockLast(); @@ -612,6 +750,7 @@ void rollbackTransaction() { .as(StepVerifier::create) .verifyComplete(); sharedConn.createStatement("DROP TABLE IF EXISTS rollbackTable").execute().blockLast(); + sharedConn.setAutoCommit(true).block(); } @Test @@ -745,11 +884,11 @@ void toStringTest() { "MariadbConnection{client=Client{isClosed=false, " + "context=ConnectionContext{")); if (!"skysql".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))) { - Assertions.assertTrue( connection .toString() - .contains(", isolationLevel=IsolationLevel{sql='REPEATABLE READ'}}")); + .contains(", isolationLevel=IsolationLevel{sql='REPEATABLE READ'}}"), + connection.toString()); } } finally { connection.close().block(); @@ -802,7 +941,7 @@ public void noDb() throws Throwable { } @Test - public void initialIsolationLevel() { + public void initialIsolationLevel() throws CloneNotSupportedException { Assumptions.assumeTrue( !"skysql".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); for (IsolationLevel level : levels) { @@ -812,7 +951,31 @@ public void initialIsolationLevel() { .blockLast(); MariadbConnection connection = new MariadbConnectionFactory(TestConfiguration.defaultBuilder.build()).create().block(); + assertEquals(IsolationLevel.REPEATABLE_READ, connection.getTransactionIsolationLevel()); + connection.close().block(); + + connection = + new MariadbConnectionFactory( + TestConfiguration.defaultBuilder.clone().isolationLevel(level).build()) + .create() + .block(); assertEquals(level, connection.getTransactionIsolationLevel()); + String sql = "SELECT @@tx_isolation"; + + if (!isMariaDBServer()) { + if ((minVersion(8, 0, 3)) + || (sharedConn.getMetadata().getMajorVersion() < 8 && minVersion(5, 7, 20))) { + sql = "SELECT @@transaction_isolation"; + } + } + + String iso = + connection + .createStatement(sql) + .execute() + .flatMap(it -> it.map((row, rowMetadata) -> row.get(0, String.class))) + .blockFirst(); + assertEquals(level, IsolationLevel.valueOf(iso.replace("-", " "))); connection.close().block(); } @@ -825,6 +988,11 @@ public void initialIsolationLevel() { @Test public void errorOnConnection() throws Throwable { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv"))); + BigInteger maxConn = sharedConn .createStatement("select @@max_connections") @@ -833,7 +1001,7 @@ public void errorOnConnection() throws Throwable { .blockLast(); Assumptions.assumeTrue(maxConn.intValue() < 600); - R2dbcTransientResourceException expected = null; + Throwable expected = null; Mono[] cons = new Mono[maxConn.intValue()]; for (int i = 0; i < maxConn.intValue(); i++) { cons[i] = new MariadbConnectionFactory(TestConfiguration.defaultBuilder.build()).create(); @@ -842,7 +1010,7 @@ public void errorOnConnection() throws Throwable { for (int i = 0; i < maxConn.intValue(); i++) { try { connections[i] = (MariadbConnection) cons[i].block(); - } catch (R2dbcTransientResourceException e) { + } catch (Throwable e) { expected = e; } } @@ -853,7 +1021,8 @@ public void errorOnConnection() throws Throwable { } } Assertions.assertNotNull(expected); - Assertions.assertTrue(expected.getMessage().contains("Too many connections")); + Assertions.assertTrue(expected.getMessage().contains("Fail to establish connection to")); + Assertions.assertTrue(expected.getCause().getMessage().contains("Too many connections")); Thread.sleep(1000); } @@ -865,7 +1034,9 @@ void killedConnection() { && !"skysql-ha".equals(System.getenv("srv"))); MariadbConnection connection = factory.create().block(); long threadId = connection.getThreadId(); - + assertNotNull(connection.getHost()); + assertEquals(TestConfiguration.defaultBuilder.build().getPort(), connection.getPort()); + connection.getPort(); Runnable runnable = () -> { try { @@ -910,4 +1081,71 @@ void killedConnection() { .expectNext(Boolean.FALSE) .verifyComplete(); } + + @Test + public void queryTimeout() throws Throwable { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + MariadbConnection connection = + new MariadbConnectionFactory(TestConfiguration.defaultBuilder.clone().build()) + .create() + .block(); + connection.setStatementTimeout(Duration.ofMillis(0500)).block(); + + try { + connection + .createStatement( + "select * from information_schema.columns as c1, " + + "information_schema.tables, information_schema.tables as t2") + .execute() + .flatMap(r -> r.map((rows, meta) -> "")) + .blockLast(); + Assertions.fail(); + } catch (R2dbcTimeoutException e) { + assertTrue( + e.getMessage().contains("Query execution was interrupted (max_statement_time exceeded)") + || e.getMessage() + .contains( + "Query execution was interrupted, maximum statement execution time exceeded")); + } finally { + connection.close().block(); + } + } + + @Test + public void setLockWaitTimeout() { + sharedConn.setLockWaitTimeout(Duration.ofMillis(1)).block(); + } + + @Test + public void testPools() throws Throwable { + boolean hasReactorTcp = false; + boolean hasMariaDbThreads = false; + Set threadSet = Thread.getAllStackTraces().keySet(); + for (Thread thread : threadSet) { + if (thread.getName().contains("reactor-tcp")) hasReactorTcp = true; + if (thread.getName().contains("mariadb")) hasMariaDbThreads = true; + } + assertTrue(hasReactorTcp); + assertFalse(hasMariaDbThreads); + + MariadbConnection connection = + new MariadbConnectionFactory( + TestConfiguration.defaultBuilder + .clone() + .loopResources(LoopResources.create("mariadb")) + .build()) + .create() + .block(); + + threadSet = Thread.getAllStackTraces().keySet(); + for (Thread thread : threadSet) { + if (thread.getName().contains("mariadb")) hasMariaDbThreads = true; + } + assertTrue(hasMariaDbThreads); + + connection.close().block(); + } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/ErrorTest.java b/src/test/java/org/mariadb/r2dbc/integration/ErrorTest.java index bdc4cb98..35ee379f 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/ErrorTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/ErrorTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -13,7 +13,7 @@ import org.mariadb.r2dbc.TestConfiguration; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; -import reactor.core.publisher.Mono; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; public class ErrorTest extends BaseConnectionTest { @@ -74,9 +74,11 @@ void permissionDenied() throws Exception { .expectErrorMatches( throwable -> throwable instanceof R2dbcNonTransientResourceException - && (throwable + && throwable.getMessage().contains("Fail to establish connection to") + && throwable + .getCause() .getMessage() - .contains("Access denied for user 'userWithoutRight'"))) + .contains("Access denied for user 'userWithoutRight'")) .verify(); } @@ -122,7 +124,7 @@ void rollbackException() { .createStatement("SET SESSION innodb_lock_wait_timeout=1") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); connection.beginTransaction().block(); connection diff --git a/src/test/java/org/mariadb/r2dbc/integration/FailoverConnectionTest.java b/src/test/java/org/mariadb/r2dbc/integration/FailoverConnectionTest.java new file mode 100644 index 00000000..c4ea91c6 --- /dev/null +++ b/src/test/java/org/mariadb/r2dbc/integration/FailoverConnectionTest.java @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.integration; + +import static org.junit.jupiter.api.Assertions.*; + +import ch.qos.logback.classic.Level; +import io.r2dbc.spi.*; +import java.io.IOException; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.*; +import org.mariadb.r2dbc.*; +import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.api.MariadbStatement; +import org.mariadb.r2dbc.tools.TcpProxy; +import org.mariadb.r2dbc.util.HostAddress; +import reactor.test.StepVerifier; + +public class FailoverConnectionTest extends BaseConnectionTest { + + + @BeforeAll + public static void before2() { + + sharedConn + .createStatement("CREATE TABLE IF NOT EXISTS sequence_1_to_10000 (t1 int)") + .execute() + .blockLast(); + if (sharedConn + .createStatement("SELECT COUNT(*) FROM sequence_1_to_10000") + .execute() + .flatMap(r -> r.map((row, metadata) -> row.get(0, Integer.class))) + .blockLast() + != 10000) { + sharedConn.createStatement("TRUNCATE TABLE sequence_1_to_10000").execute().blockLast(); + if (isMariaDBServer()) { + sharedConn + .createStatement("INSERT INTO sequence_1_to_10000 SELECT * from seq_1_to_10000") + .execute() + .blockLast(); + } else { + MariadbStatement stmt = + sharedConn.createStatement("INSERT INTO sequence_1_to_10000 VALUES (?)"); + stmt.bind(0, 1); + for (int i = 2; i <= 10_000; i++) { + stmt.add(); + stmt.bind(0, i); + } + stmt.execute().blockLast(); + } + } + } + + @Test + void multipleCommandStack() throws Exception { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + + MariadbConnection connection = createFailoverProxyConnection(HaMode.SEQUENTIAL, false, false); + try { + connection.createStatement("SET @con=1").execute().blockLast(); + assertTrue(connection.validate(ValidationDepth.REMOTE).block()); + proxy.restart(11000); + Thread.sleep(200); + assertFalse(connection.validate(ValidationDepth.REMOTE).block()); + Thread.sleep(200); + + assertTrue(connection.validate(ValidationDepth.REMOTE).block()); + connection.close().block(); + } finally { + proxy.forceClose(); + Thread.sleep(50); + } + } + + @Test + void transactionReplayFalse() throws Exception { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + + MariadbConnection connection = createFailoverProxyConnection(HaMode.SEQUENTIAL, false, false); + try { + connection.setAutoCommit(false).block(); + connection.beginTransaction().block(); + connection.createStatement("SET @con=1").execute().blockLast(); + + proxy.restartForce(); + try { + connection.createStatement("SELECT @con").execute().blockLast(); + fail(); + } catch (R2dbcException e) { + assertTrue(e.getMessage().contains("In progress transaction was lost")); + } + } finally { + proxy.forceClose(); + Thread.sleep(50); + } + } + + @Test + void transactionReplayFailingBetweenCmds() throws Exception { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + try { + transactionReplayFailingBetweenCmds( + createFailoverProxyConnection(HaMode.SEQUENTIAL, true, false)); + transactionReplayFailingBetweenCmds( + createFailoverProxyConnection(HaMode.SEQUENTIAL, true, true)); + } finally { + Thread.sleep(50); + } + } + + private void transactionReplayFailingBetweenCmds(MariadbConnection connection) throws Exception { + try { + connection.setAutoCommit(false).block(); + connection.beginTransaction().block(); + connection.createStatement("SET @con=1").execute().blockLast(); + + // proxy.restartForce(); + + Optional res = + connection + .createStatement("SELECT @con") + .execute() + .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) + .blockLast(); + assertTrue(res.isPresent()); + assertEquals(1L, res.get()); + + connection.createStatement("DROP TABLE IF EXISTS testReplay").execute().blockLast(); + connection.createStatement("CREATE TABLE testReplay(id INT)").execute().blockLast(); + connection.createStatement("INSERT INTO testReplay VALUE (1)").execute().blockLast(); + connection.setAutoCommit(false).block(); + connection.beginTransaction().block(); + + connection.createStatement("INSERT INTO testReplay VALUE (2)").execute().blockLast(); + connection + .createStatement("INSERT INTO testReplay VALUE (?)") + .bind(0, 3) + .execute() + .blockLast(); + + connection + .createStatement("INSERT INTO testReplay VALUE (?)") + .bind(0, 4) + .execute() + .blockLast(); + + proxy.restartForce(); + + connection + .createStatement("INSERT INTO testReplay VALUE (?)") + .bind(0, 5) + .execute() + .blockLast(); + + connection + .createStatement("SELECT id from testReplay") + .execute() + .flatMap(r -> r.map((row, metadata) -> row.get(0, Integer.class))) + .as(StepVerifier::create) + .expectNext(1, 2, 3, 4, 5) + .verifyComplete(); + connection.createStatement("DROP TABLE IF EXISTS testReplay").execute().blockLast(); + + } finally { + connection.close().block(); + proxy.forceClose(); + } + } + + @Test + void transactionReplayFailingDuringCmd() throws Exception { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + + transactionReplayFailingDuringCmd( + createFailoverProxyConnection(HaMode.SEQUENTIAL, true, false)); + transactionReplayFailingDuringCmd(createFailoverProxyConnection(HaMode.SEQUENTIAL, true, true)); + Thread.sleep(50); + } + + private void transactionReplayFailingDuringCmd(MariadbConnection connection) throws Exception { + connection.setAutoCommit(false).block(); + connection.beginTransaction().block(); + + AtomicInteger expectedResult = new AtomicInteger(1); + AtomicBoolean endedByError = new AtomicBoolean(false); + AtomicReference resultingError = new AtomicReference<>(); + connection + .createStatement("SELECT * from sequence_1_to_10000") + .execute() + .flatMap( + r -> + r.map( + (row, metadata) -> { + int i = row.get(0, Integer.class); + assertEquals(expectedResult.getAndIncrement(), i); + return i; + })) + .doOnError( + t -> { + endedByError.set(true); + resultingError.set(t); + }) + .subscribe(); + + Thread.sleep(10); + proxy.restartForce(); + + ScheduledThreadPoolExecutor waitingExecutor = new ScheduledThreadPoolExecutor(1); + Runnable runnable = + () -> { + while (true) { + if (expectedResult.get() >= 10000 || endedByError.get()) { + + assertNotNull(resultingError.get()); + assertTrue( + resultingError + .get() + .getMessage() + .contains( + "Driver has reconnect connection after a communications link failure with") + && resultingError.get().getMessage().contains("during command.")); + + return; + } + try { + Thread.sleep(250); + } catch (Throwable e) { + } + } + }; + + waitingExecutor.execute(runnable); + waitingExecutor.shutdown(); + waitingExecutor.awaitTermination(2, TimeUnit.MINUTES); + Thread.sleep(100); + connection.close().block(); + proxy.forceClose(); + Thread.sleep(50); + } + + private MariadbConnection createFailoverProxyConnection( + HaMode haMode, boolean transactionReplay, boolean usePrepare) throws Exception { + + HostAddress hostAddress = TestConfiguration.defaultConf.getHostAddresses().get(0); + try { + proxy = new TcpProxy(hostAddress.getHost(), hostAddress.getPort()); + } catch (IOException i) { + throw new Exception("proxy error", i); + } + + List hosts = new ArrayList<>(); + hosts.add(new HostAddress("localhost", 9999)); + hosts.add(new HostAddress("localhost", proxy.getLocalPort())); + MariadbConnectionConfiguration.Builder builder = + TestConfiguration.defaultBuilder + .clone() + .haMode(haMode.name()) + .transactionReplay(transactionReplay) + .connectTimeout(Duration.ofSeconds(5)) + .useServerPrepStmts(usePrepare) + .hostAddresses(hosts) + .host(System.getenv("TRAVIS") != null ? hostAddress.getHost() : "localhost"); + + if (TestConfiguration.defaultConf + .getSslConfig() + .getSslMode() + .equals(org.mariadb.jdbc.export.SslMode.VERIFY_FULL)) { + builder.sslMode(SslMode.VERIFY_CA); + } + + MariadbConnectionConfiguration confProxy = builder.build(); + + return new MariadbConnectionFactory(confProxy).create().block(); + } + +} diff --git a/src/test/java/org/mariadb/r2dbc/integration/LoggingTest.java b/src/test/java/org/mariadb/r2dbc/integration/LoggingTest.java index 6715c2fa..c8a13278 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/LoggingTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/LoggingTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -34,7 +34,7 @@ void basicLogging() throws IOException { Level initialLevel = logger.getLevel(); logger.setLevel(Level.TRACE); logger.setAdditive(false); - logger.detachAndStopAllAppenders(); + // logger.detachAndStopAllAppenders(); LoggerContext context = new LoggerContext(); FileAppender fa = new FileAppender(); @@ -61,66 +61,71 @@ void basicLogging() throws IOException { .flatMap(r -> r.map((row, metadata) -> row.get(0, Integer.class))) .as(StepVerifier::create) .expectNext(1) - .then( - () -> { - MariadbConnectionMetadata meta = connection.getMetadata(); - connection.close().block(); - try { - String contents = new String(Files.readAllBytes(Paths.get(tempFile.getPath()))); - String selectIsolation = - " +-------------------------------------------------+\r\n" - + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" - + "+--------+-------------------------------------------------+----------------+\r\n" - + "|00000000| 16 00 00 00 03 53 45 4c 45 43 54 20 40 40 74 78 |.....SELECT @@tx|\r\n" - + "|00000010| 5f 69 73 6f 6c 61 74 69 6f 6e |_isolation |\r\n" - + "+--------+-------------------------------------------------+----------------+"; - String mysqlIsolation = - " +-------------------------------------------------+\r\n" - + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" - + "+--------+-------------------------------------------------+----------------+\r\n" - + "|00000000| 1f 00 00 00 03 53 45 4c 45 43 54 20 40 40 74 72 |.....SELECT @@tr|\r\n" - + "|00000010| 61 6e 73 61 63 74 69 6f 6e 5f 69 73 6f 6c 61 74 |ansaction_isolat|\r\n" - + "|00000020| 69 6f 6e |ion |\r\n" - + "+--------+-------------------------------------------------+----------------+"; + .verifyComplete(); + MariadbConnectionMetadata meta = connection.getMetadata(); + connection.close().block(); - if (meta.isMariaDBServer() - || (meta.getMajorVersion() < 8 && !meta.minVersion(5, 7, 20)) - || (meta.getMajorVersion() >= 8 && !meta.minVersion(8, 0, 3))) { - Assertions.assertTrue( - contents.contains(selectIsolation) - || contents.contains(selectIsolation.replace("\r\n", "\n"))); - } else { - Assertions.assertTrue( - contents.contains(mysqlIsolation) - || contents.contains(mysqlIsolation.replace("\r\n", "\n"))); - } + try { + String contents = new String(Files.readAllBytes(Paths.get(tempFile.getPath()))); + String selectIsolation = + " +-------------------------------------------------+\r\n" + + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" + + "+--------+-------------------------------------------------+----------------+\r\n" + + "|00000000| 80 00 00 00 03 53 45 54 20 61 75 74 6f 63 6f 6d |.....SET autocom|\r\n" + + "|00000010| 6d 69 74 3d 31 2c 74 78 5f 69 73 6f 6c 61 74 69 |mit=1,tx_isolati|\r\n" + + "|00000020| 6f 6e 3d 27 52 45 50 45 41 54 41 42 4c 45 2d 52 |on='REPEATABLE-R|\r\n" + + "|00000030| 45 41 44 27 2c 73 65 73 73 69 6f 6e 5f 74 72 61 |EAD',session_tra|\r\n" + + "|00000040| 63 6b 5f 73 63 68 65 6d 61 3d 31 2c 73 65 73 73 |ck_schema=1,sess|\r\n" + + "|00000050| 69 6f 6e 5f 74 72 61 63 6b 5f 73 79 73 74 65 6d |ion_track_system|\r\n" + + "|00000060| 5f 76 61 72 69 61 62 6c 65 73 3d 27 61 75 74 6f |_variables='auto|\r\n" + + "|00000070| 63 6f 6d 6d 69 74 2c 74 78 5f 69 73 6f 6c 61 74 |commit,tx_isolat|\r\n" + + "|00000080| 69 6f 6e 27 |ion' |\r\n" + + "+--------+-------------------------------------------------+----------------+"; + String mysqlIsolation = + " +-------------------------------------------------+\r\n" + + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" + + "+--------+-------------------------------------------------+----------------+\r\n" + + "|00000000| 92 00 00 00 03 53 45 54 20 61 75 74 6f 63 6f 6d |.....SET autocom|\r\n" + + "|00000010| 6d 69 74 3d 31 2c 74 72 61 6e 73 61 63 74 69 6f |mit=1,transactio|\r\n" + + "|00000020| 6e 5f 69 73 6f 6c 61 74 69 6f 6e 3d 27 52 45 50 |n_isolation='REP|\r\n" + + "|00000030| 45 41 54 41 42 4c 45 2d 52 45 41 44 27 2c 73 65 |EATABLE-READ',se|\r\n" + + "|00000040| 73 73 69 6f 6e 5f 74 72 61 63 6b 5f 73 63 68 65 |ssion_track_sche|\r\n" + + "|00000050| 6d 61 3d 31 2c 73 65 73 73 69 6f 6e 5f 74 72 61 |ma=1,session_tra|\r\n" + + "|00000060| 63 6b 5f 73 79 73 74 65 6d 5f 76 61 72 69 61 62 |ck_system_variab|\r\n" + + "|00000070| 6c 65 73 3d 27 61 75 74 6f 63 6f 6d 6d 69 74 2c |les='autocommit,|\r\n" + + "|00000080| 74 72 61 6e 73 61 63 74 69 6f 6e 5f 69 73 6f 6c |transaction_isol|\r\n" + + "|00000090| 61 74 69 6f 6e 27 |ation' |\r\n" + + "+--------+-------------------------------------------------+----------------+"; - String selectOne = - " +-------------------------------------------------+\r\n" - + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" - + "+--------+-------------------------------------------------+----------------+\r\n" - + "|00000000| 09 00 00 00 03 53 45 4c 45 43 54 20 31 |.....SELECT 1 |\r\n" - + "+--------+-------------------------------------------------+----------------+"; - Assertions.assertTrue( - contents.contains(selectOne) - || contents.contains(selectOne.replace("\r\n", "\n"))); - String rowResult = - " +-------------------------------------------------+\r\n" - + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" - + "+--------+-------------------------------------------------+----------------+\r\n" - + "|00000000| 01 00 00 00 01 |..... |\r\n" - + "+--------+-------------------------------------------------+----------------+"; - Assertions.assertTrue( - contents.contains(rowResult) - || contents.contains(rowResult.replace("\r\n", "\n"))); - logger.setLevel(initialLevel); - logger.detachAppender(fa); - } catch (IOException e) { - e.printStackTrace(); - Assertions.fail(); - } - }) - .verifyComplete(); + if (meta.isMariaDBServer() + || (meta.getMajorVersion() < 8 && !meta.minVersion(5, 7, 20)) + || (meta.getMajorVersion() >= 8 && !meta.minVersion(8, 0, 3))) { + Assertions.assertTrue( + contents.contains(selectIsolation) + || contents.contains(selectIsolation.replace("\r\n", "\n")), + contents); + } else { + Assertions.assertTrue( + contents.contains(mysqlIsolation) + || contents.contains(mysqlIsolation.replace("\r\n", "\n")), + contents); + } + + String selectOne = + " +-------------------------------------------------+\r\n" + + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |\r\n" + + "+--------+-------------------------------------------------+----------------+\r\n" + + "|00000000| 09 00 00 00 03 53 45 4c 45 43 54 20 31 |.....SELECT 1 |\r\n" + + "+--------+-------------------------------------------------+----------------+"; + Assertions.assertTrue( + contents.contains(selectOne) || contents.contains(selectOne.replace("\r\n", "\n"))); + logger.setLevel(initialLevel); + logger.detachAppender(fa); + logger.setAdditive(true); + } catch (Throwable e) { + e.printStackTrace(); + Assertions.fail(); + } } public String encodeHexString(byte[] byteArray) { diff --git a/src/test/java/org/mariadb/r2dbc/integration/MariadbBinaryTestKit.java b/src/test/java/org/mariadb/r2dbc/integration/MariadbBinaryTestKit.java new file mode 100644 index 00000000..2d852ec6 --- /dev/null +++ b/src/test/java/org/mariadb/r2dbc/integration/MariadbBinaryTestKit.java @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.integration; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.test.TestKit; +import java.sql.SQLException; +import javax.sql.DataSource; +import org.junit.jupiter.api.Assumptions; +import org.mariadb.jdbc.MariaDbDataSource; +import org.mariadb.r2dbc.MariadbConnectionConfiguration; +import org.mariadb.r2dbc.MariadbConnectionFactory; +import org.mariadb.r2dbc.TestConfiguration; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; + +public class MariadbBinaryTestKit implements TestKit { + private static final DataSource jdbcDatasource; + + static { + String connString = + String.format( + "jdbc:mariadb://%s:%s/%s?user=%s&password=%s", + TestConfiguration.host, + TestConfiguration.port, + TestConfiguration.database, + TestConfiguration.username, + TestConfiguration.password); + try { + jdbcDatasource = new MariaDbDataSource(connString); + } catch (SQLException e) { + throw new IllegalArgumentException( + String.format("wrong initialization with %s", connString), e); + } + } + + @Override + public ConnectionFactory getConnectionFactory() { + // error crashing maxscale 6.1.x + try (java.sql.Connection con = jdbcDatasource.getConnection()) { + Assumptions.assumeTrue( + !con.getMetaData().getDatabaseProductVersion().contains("maxScale-6.1.") + && !"skysql-ha".equals(System.getenv("srv"))); + } catch (SQLException e) { + // eat + } + try { + MariadbConnectionConfiguration confMulti = + TestConfiguration.defaultBuilder + .clone() + .useServerPrepStmts(true) + .allowMultiQueries(true) + .build(); + return new MariadbConnectionFactory(confMulti); + } catch (CloneNotSupportedException e) { + throw new IllegalStateException("Unexpected error"); + } + } + + @Override + public String getPlaceholder(int i) { + return ":v" + i; + } + + @Override + public String getIdentifier(int i) { + return "v" + i; + } + + @Override + public JdbcOperations getJdbcOperations() { + return new JdbcTemplate(MariadbBinaryTestKit.jdbcDatasource); + } + + @Override + public String doGetSql(TestStatement statement) { + switch (statement) { + case CREATE_TABLE_AUTOGENERATED_KEY: + return TestStatement.CREATE_TABLE_AUTOGENERATED_KEY + .getSql() + .replaceAll("IDENTITY", "PRIMARY KEY AUTO_INCREMENT"); + case INSERT_VALUE_AUTOGENERATED_KEY: + case INSERT_VALUE100: + return "INSERT INTO test(test_value) VALUES (100)"; + case INSERT_VALUE200: + return "INSERT INTO test(test_value) VALUES (200)"; + default: + return statement.getSql(); + } + } + + @Override + public String clobType() { + return "TEXT"; + } +} diff --git a/src/test/java/org/mariadb/r2dbc/integration/MariadbTextTestKit.java b/src/test/java/org/mariadb/r2dbc/integration/MariadbTextTestKit.java new file mode 100644 index 00000000..131b67d5 --- /dev/null +++ b/src/test/java/org/mariadb/r2dbc/integration/MariadbTextTestKit.java @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.integration; + +import io.r2dbc.spi.ConnectionFactory; +import io.r2dbc.spi.test.TestKit; +import java.sql.SQLException; +import javax.sql.DataSource; +import org.junit.jupiter.api.Assumptions; +import org.mariadb.jdbc.MariaDbDataSource; +import org.mariadb.r2dbc.MariadbConnectionConfiguration; +import org.mariadb.r2dbc.MariadbConnectionFactory; +import org.mariadb.r2dbc.TestConfiguration; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; + +public class MariadbTextTestKit implements TestKit { + private static final DataSource jdbcDatasource; + + static { + String connString = + String.format( + "jdbc:mariadb://%s:%s/%s?user=%s&password=%s", + TestConfiguration.host, + TestConfiguration.port, + TestConfiguration.database, + TestConfiguration.username, + TestConfiguration.password); + try { + jdbcDatasource = new MariaDbDataSource(connString); + } catch (SQLException e) { + throw new IllegalArgumentException( + String.format("wrong initialization with %s", connString), e); + } + } + + @Override + public ConnectionFactory getConnectionFactory() { + // error crashing maxscale 6.1.x + try (java.sql.Connection con = jdbcDatasource.getConnection()) { + Assumptions.assumeTrue( + !con.getMetaData().getDatabaseProductVersion().contains("maxScale-6.1.") + && !"skysql-ha".equals(System.getenv("srv"))); + } catch (SQLException e) { + // eat + } + + try { + MariadbConnectionConfiguration confMulti = + TestConfiguration.defaultBuilder.clone().allowMultiQueries(true).build(); + return new MariadbConnectionFactory(confMulti); + } catch (CloneNotSupportedException e) { + throw new IllegalStateException("Unexpected error"); + } + } + + @Override + public String getPlaceholder(int i) { + return ":v" + i; + } + + @Override + public String getIdentifier(int i) { + return "v" + i; + } + + @Override + public JdbcOperations getJdbcOperations() { + return new JdbcTemplate(MariadbTextTestKit.jdbcDatasource); + } + + @Override + public String doGetSql(TestStatement statement) { + switch (statement) { + case CREATE_TABLE_AUTOGENERATED_KEY: + return TestStatement.CREATE_TABLE_AUTOGENERATED_KEY + .getSql() + .replaceAll("IDENTITY", "PRIMARY KEY AUTO_INCREMENT"); + case INSERT_VALUE_AUTOGENERATED_KEY: + case INSERT_VALUE100: + return "INSERT INTO test(test_value) VALUES (100)"; + case INSERT_VALUE200: + return "INSERT INTO test(test_value) VALUES (200)"; + default: + return statement.getSql(); + } + } + + @Override + public String clobType() { + return "TEXT"; + } +} diff --git a/src/test/java/org/mariadb/r2dbc/integration/MultiQueriesTest.java b/src/test/java/org/mariadb/r2dbc/integration/MultiQueriesTest.java index 8f1f83e1..2a75bb2d 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/MultiQueriesTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/MultiQueriesTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; diff --git a/src/test/java/org/mariadb/r2dbc/integration/NoPipelineTest.java b/src/test/java/org/mariadb/r2dbc/integration/NoPipelineTest.java index 503700ee..3aaa9006 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/NoPipelineTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/NoPipelineTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; diff --git a/src/test/java/org/mariadb/r2dbc/integration/PrepareResultSetTest.java b/src/test/java/org/mariadb/r2dbc/integration/PrepareResultSetTest.java index c3fc9b91..cf265060 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/PrepareResultSetTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/PrepareResultSetTest.java @@ -1,14 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; import io.r2dbc.spi.R2dbcTransientResourceException; import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.MariadbConnectionConfiguration; @@ -21,7 +19,7 @@ import reactor.test.StepVerifier; public class PrepareResultSetTest extends BaseConnectionTest { - private static List stringList = + private static final List stringList = Arrays.asList( "456", "789000002", @@ -43,7 +41,7 @@ public class PrepareResultSetTest extends BaseConnectionTest { @BeforeAll public static void before2() { - sharedConn.createStatement("DROP TABLE IF EXISTS PrepareResultSetTest").execute().blockLast(); + after2(); sharedConn .createStatement( "CREATE TABLE PrepareResultSetTest(" @@ -71,37 +69,52 @@ public static void before2() { "INSERT INTO PrepareResultSetTest VALUES (456,789000002,25,30, 456.45,127,2020,45,'ዩኒኮድ ወረጘ የጝ',65445681355454,987456,45000, 45.9, -2, 2045, 12, 'ዩኒኮድ What does this means ?')") .execute() .blockLast(); + sharedConn.createStatement("CREATE TABLE myTable(a varchar(10))").execute().blockLast(); + sharedConn + .createStatement( + "CREATE TABLE parameterLengthEncoded(t0 VARCHAR(1024),t1 MEDIUMTEXT) DEFAULT CHARSET=utf8mb4") + .execute() + .blockLast(); + sharedConn + .createStatement( + "CREATE TABLE parameterLengthEncodedLong (t0 LONGTEXT) DEFAULT CHARSET=utf8mb4") + .execute() + .blockLast(); + sharedConn.createStatement("CREATE TABLE validateParam(t0 VARCHAR(10))").execute().blockLast(); + sharedConnPrepare + .createStatement( + "CREATE TABLE missingParameter(t1 VARCHAR(256),t2 VARCHAR(256)) DEFAULT CHARSET=utf8mb4") + .execute() + .blockLast(); } @AfterAll public static void after2() { - sharedConn.createStatement("DROP TABLE PrepareResultSetTest").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS PrepareResultSetTest").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS myTable").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS parameterLengthEncoded").execute().blockLast(); + sharedConn + .createStatement("DROP TABLE IF EXISTS parameterLengthEncodedLong") + .execute() + .blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS validateParam").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS missingParameter").execute().blockLast(); } @Test void bindWithName() { - assertThrows( - Exception.class, - () -> - sharedConnPrepare - .createStatement("INSERT INTO myTable (a) VALUES (:var1)") - .bind("var1", "test"), - "Cannot use getColumn(name) with prepared statement"); - assertThrows( - Exception.class, - () -> - sharedConnPrepare - .createStatement("INSERT INTO myTable (a) VALUES (:var1)") - .bindNull("var1", String.class), - "Cannot use getColumn(name) with prepared statement"); + sharedConnPrepare + .createStatement("INSERT INTO myTable (a) VALUES (:var1)") + .bind("var1", "test") + .execute(); } @Test void validateParam() { - sharedConnPrepare - .createStatement("CREATE TEMPORARY TABLE validateParam(t0 VARCHAR(10))") - .execute() - .blockLast(); + // disabling with maxscale due to MXS-3956 + // to be re-enable when > 6.1.1 + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); Assertions.assertThrows( Exception.class, () -> @@ -118,35 +131,39 @@ void validateParam() { .execute() .flatMap(r -> r.getRowsUpdated()) .blockLast(), - "Parameter at position 0 is not set"); + "No parameters have been set"); } @Test void parameterLengthEncoded() { - Assumptions.assumeTrue(maxAllowedPacket() >= 16 * 1024 * 1024); - + Assumptions.assumeTrue(maxAllowedPacket() >= 17 * 1024 * 1024); + Assumptions.assumeTrue(runLongTest()); char[] arr1024 = new char[1024]; for (int i = 0; i < arr1024.length; i++) { arr1024[i] = (char) ('a' + (i % 10)); } - char[] arr = new char[16_000_000]; for (int i = 0; i < arr.length; i++) { arr[i] = (char) ('a' + (i % 10)); } + char[] arr2 = new char[17_000_000]; + for (int i = 0; i < arr2.length; i++) { + arr2[i] = (char) ('a' + (i % 10)); + } + String arr1024St = String.valueOf(arr1024); + String arrSt = String.valueOf(arr); + String arrSt2 = String.valueOf(arr2); - sharedConnPrepare - .createStatement( - "CREATE TEMPORARY TABLE parameterLengthEncoded" - + "(t0 VARCHAR(1024),t1 MEDIUMTEXT) DEFAULT CHARSET=utf8mb4") - .execute() - .blockLast(); sharedConnPrepare .createStatement("INSERT INTO parameterLengthEncoded VALUES (?, ?)") - .bind(0, String.valueOf(arr1024)) - .bind(1, String.valueOf(arr)) + .bind(0, arr1024St) + .bind(1, arrSt) + .add() + .bind(0, arr1024St) + .bind(1, arrSt2) .execute() .blockLast(); + AtomicBoolean first = new AtomicBoolean(true); sharedConnPrepare .createStatement("SELECT * FROM parameterLengthEncoded WHERE 1 = ?") .bind(0, 1) @@ -156,25 +173,38 @@ void parameterLengthEncoded() { r.map( (row, metadata) -> { String t0 = row.get(0, String.class); - String t1 = row.get(1, String.class); - Assertions.assertEquals(String.valueOf(arr1024), t0); - Assertions.assertEquals(String.valueOf(arr), t1); + if (first.get()) { + String t1 = row.get(1, String.class); + Assertions.assertEquals(arrSt, t1); + first.set(false); + } else { + String t1 = row.get(1, String.class); + Assertions.assertEquals(arrSt2, t1); + } + Assertions.assertEquals(arr1024St, t0); return t0; })) .as(StepVerifier::create) .expectNext(String.valueOf(arr1024)) .verifyComplete(); + first.set(true); sharedConnPrepare - .createStatement("SELECT * FROM parameterLengthEncoded /* ? */") + .createStatement("SELECT * FROM parameterLengthEncoded /* ? */ ") .execute() .flatMap( r -> r.map( (row, metadata) -> { String t0 = row.get(0, String.class); - String t1 = row.get(1, String.class); - Assertions.assertEquals(String.valueOf(arr1024), t0); - Assertions.assertEquals(String.valueOf(arr), t1); + if (first.get()) { + String t1 = row.get(1, String.class); + Assertions.assertEquals(arrSt, t1); + first.set(false); + } else { + String t1 = row.get(1, String.class); + Assertions.assertEquals(arrSt2, t1); + } + Assertions.assertEquals(arr1024St, t0); return t0; })) .as(StepVerifier::create) @@ -194,12 +224,7 @@ void parameterLengthEncodedLong() { arr[i] = (char) ('a' + (i % 10)); } String val = String.valueOf(arr); - sharedConnPrepare - .createStatement( - "CREATE TEMPORARY TABLE parameterLengthEncodedLong" - + "(t0 LONGTEXT) DEFAULT CHARSET=utf8mb4") - .execute() - .blockLast(); + sharedConnPrepare.beginTransaction().block(); sharedConnPrepare .createStatement("INSERT INTO parameterLengthEncodedLong VALUES (?)") @@ -218,27 +243,25 @@ void parameterLengthEncodedLong() { @Test void missingParameter() { - sharedConnPrepare - .createStatement( - "CREATE TEMPORARY TABLE missingParameter" - + "(t1 VARCHAR(256),t2 VARCHAR(256)) DEFAULT CHARSET=utf8mb4") - .execute() - .blockLast(); - // missing first parameter MariadbStatement stmt = sharedConnPrepare.createStatement("INSERT INTO missingParameter(t1, t2) VALUES (?, ?)"); - stmt.bind(1, "test").execute().blockLast(); - + assertThrows( + IllegalStateException.class, + () -> stmt.bind(1, "test").execute().blockLast(), + "Parameter at position 0 is not set"); assertThrows( IllegalArgumentException.class, () -> stmt.bind(null, null), "identifier cannot be null"); assertThrows( - IllegalArgumentException.class, + NoSuchElementException.class, () -> stmt.bind("test", null), - "Cannot use getColumn(name) with prepared statement"); + "No parameter with name 'test' found"); + stmt.bindNull(0, null).bind(1, "test").execute().blockLast(); + + stmt.bind(1, "test"); assertThrows( - IllegalArgumentException.class, () -> stmt.add(), "Parameter at position 0 is not set"); + IllegalStateException.class, () -> stmt.add(), "Parameter at position 0 is not set"); sharedConnPrepare .createStatement("SELECT * FROM missingParameter") .execute() @@ -262,12 +285,7 @@ void resultSetSkippingRes() { .createStatement("SELECT * FROM PrepareResultSetTest WHERE 1 = ?") .bind(0, 1) .execute() - .flatMap( - r -> - r.map( - (row, metadata) -> { - return row.get(finalI, String.class); - })) + .flatMap(r -> r.map((row, metadata) -> row.get(finalI, String.class))) .as(StepVerifier::create) .expectNext(stringList.get(i)) .verifyComplete(); @@ -398,7 +416,7 @@ void parameterVerification() { assertThrows( IndexOutOfBoundsException.class, () -> stmt.bind(-1, 1), - "index must be in 0-0 range but value is -1"); + "wrong index value -1, index must be positive"); stmt.bind(0, 1).execute().subscribe().dispose(); stmt.bind(0, 1) .execute() @@ -410,20 +428,30 @@ void parameterVerification() { assertThrows( IndexOutOfBoundsException.class, () -> stmt.bind(2, 1), - "index must be in 0-0 range but value is 2"); + "Binding index 2 when only 1 parameters are expected"); assertThrows( IllegalArgumentException.class, - () -> stmt.bind(0, this), + () -> stmt.bind(0, this).execute().blockLast(), "No encoder for class org.mariadb.r2dbc.integration.PrepareResultSetTest (parameter at index 0) "); assertThrows( IndexOutOfBoundsException.class, () -> stmt.bindNull(-1, Integer.class), - "index must be in 0-0 range but value is -1"); + "wrong index value -1, index must be positive"); assertThrows( IndexOutOfBoundsException.class, () -> stmt.bindNull(2, Integer.class), - "index must be in 0-0 range but value is 2"); - stmt.bindNull(0, this.getClass()); + "Cannot bind parameter 2, statement has 1 parameters"); + assertThrows( + IllegalArgumentException.class, + () -> stmt.bindNull(0, this.getClass()), + "No encoder for class org.mariadb.r2dbc.integration.PrepareResultSetTest"); + + // error crashing maxscale 6.1.x + Assumptions.assumeTrue( + !sharedConn.getMetadata().getDatabaseVersion().contains("maxScale-6.1.") + && !"skysql-ha".equals(System.getenv("srv"))); + + stmt.bindNull(0, String.class); stmt.execute() .flatMap(r -> r.map((row, metadata) -> row.get(0, Long.class))) .as(StepVerifier::create) @@ -431,11 +459,10 @@ void parameterVerification() { .verifyComplete(); // no parameter assertThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> stmt.execute().blockLast(), - "Parameter at position 0 is not " + "set"); - assertThrows( - IllegalArgumentException.class, () -> stmt.add(), "Parameter at position 0 is not set"); + "No parameters have been set"); + Assertions.assertThrows(IllegalArgumentException.class, () -> stmt.add()); } @Test @@ -486,8 +513,11 @@ void parameterNull() { .as(StepVerifier::create) .expectNext(Optional.of("2"), Optional.empty()) .verifyComplete(); - - stmt.bindNull(0, this.getClass()) + assertThrows( + IllegalArgumentException.class, + () -> stmt.bindNull(0, this.getClass()), + "No encoder for class org.mariadb.r2dbc.integration.PrepareResultSetTest"); + stmt.bindNull(0, String.class) .execute() .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) .as(StepVerifier::create) @@ -498,9 +528,9 @@ void parameterNull() { () -> stmt.bindNull(null, String.class), "identifier cannot be null"); assertThrows( - IllegalArgumentException.class, + NoSuchElementException.class, () -> stmt.bindNull("fff", String.class), - "Cannot use getColumn(name) with prepared statement"); + "No parameter with name 'fff' found (possible values [null])"); } @Test @@ -514,10 +544,10 @@ void prepareReuse() { assertThrows( IndexOutOfBoundsException.class, () -> stmt.bind(2, 1), - "index must be in 0-0 range but value is 2"); + "Binding index 2 when only 1 parameters are expected"); assertThrows( IllegalArgumentException.class, - () -> stmt.bind(0, this), + () -> stmt.bind(0, this).execute().blockLast(), "No encoder for class org.mariadb.r2dbc.integration.PrepareResultSetTest (parameter at index 0) "); assertThrows( IndexOutOfBoundsException.class, @@ -526,14 +556,19 @@ void prepareReuse() { assertThrows( IndexOutOfBoundsException.class, () -> stmt.bindNull(2, Integer.class), - "index must be in 0-0 range but value is 2"); - stmt.bindNull(0, this.getClass()); + "Cannot bind parameter 2, statement has 1 parameters"); + assertThrows( + IllegalArgumentException.class, + () -> stmt.bindNull(0, this.getClass()), + "No encoder for class org.mariadb.r2dbc.integration.PrepareResultSetTest (parameter at index 0)"); + stmt.bindNull(0, String.class); + stmt.bind(0, 1); stmt.execute().blockLast(); // no parameter assertThrows( - IllegalArgumentException.class, + IllegalStateException.class, () -> stmt.execute().blockLast(), - "Parameter at position 0 is not set"); + "No parameters have been set"); } private List prepareInfo(MariadbConnection connection) { diff --git a/src/test/java/org/mariadb/r2dbc/integration/ProcedureResultsetTest.java b/src/test/java/org/mariadb/r2dbc/integration/ProcedureResultsetTest.java new file mode 100644 index 00000000..efc49c91 --- /dev/null +++ b/src/test/java/org/mariadb/r2dbc/integration/ProcedureResultsetTest.java @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.integration; + +import io.r2dbc.spi.OutParametersMetadata; +import io.r2dbc.spi.Parameters; +import io.r2dbc.spi.R2dbcType; +import io.r2dbc.spi.Result; +import java.time.LocalDateTime; +import java.util.List; +import org.junit.jupiter.api.*; +import org.mariadb.r2dbc.BaseConnectionTest; +import org.mariadb.r2dbc.api.MariadbResult; +import reactor.core.publisher.Flux; + +public class ProcedureResultsetTest extends BaseConnectionTest { + + @BeforeAll + public static void before2() { + dropAll(); + sharedConn + .createStatement( + "CREATE PROCEDURE basic_proc (IN t1 INT, INOUT t2 INT unsigned, OUT t3 INT, IN t4 INT, OUT t5 VARCHAR(20) CHARSET utf8mb4, OUT t6 TIMESTAMP, OUT t7 VARCHAR(20) CHARSET utf8mb4) BEGIN \n" + + "SELECT 1;\n" + + "set t3 = t1 * t4;\n" + + "set t2 = t2 * t1;\n" + + "set t5 = 'http://test';\n" + + "set t6 = TIMESTAMP('2003-12-31 12:00:00');\n" + + "set t7 = 'test';\n" + + "END") + .execute() + .blockLast(); + } + + @AfterAll + public static void dropAll() { + sharedConn.createStatement("DROP PROCEDURE IF EXISTS basic_proc").execute().blockLast(); + } + + @Test + void outputParameter() { + List> l = + sharedConn + .createStatement("call basic_proc(?,?,?,?,?,?,?)") + .bind(0, 2) + .bind(1, Parameters.inOut(2)) + .bind(2, Parameters.out(R2dbcType.INTEGER)) + .bind(3, 10) + .bind(4, Parameters.out(R2dbcType.VARCHAR)) + .bind(5, Parameters.out(R2dbcType.TIMESTAMP)) + .bind(6, Parameters.out(R2dbcType.VARCHAR)) + .execute() + .flatMap( + r -> + ((MariadbResult) r.filter(Result.OutSegment.class::isInstance)) + .flatMap( + seg -> { + OutParametersMetadata metas = + ((Result.OutSegment) seg).outParameters().getMetadata(); + + assertThrows( + IllegalArgumentException.class, + () -> metas.getParameterMetadata(-1), + "Column index -1 is not in permit range[0,4]"); + assertThrows( + IllegalArgumentException.class, + () -> metas.getParameterMetadata(10), + "Column index 10 is not in permit range[0,4]"); + + Assertions.assertEquals( + metas.getParameterMetadata(0), metas.getParameterMetadata("t2")); + Assertions.assertEquals(5, metas.getParameterMetadatas().size()); + assertThrows( + IllegalArgumentException.class, + () -> metas.getParameterMetadata("wrong"), + "Column name 'wrong' does not exist in column names [t2, t3, t5, t6, t7]"); + return Flux.just( + ((Result.OutSegment) seg).outParameters().get(0), + ((Result.OutSegment) seg).outParameters().get(1), + ((Result.OutSegment) seg).outParameters().get(2), + ((Result.OutSegment) seg).outParameters().get(3), + ((Result.OutSegment) seg).outParameters().get(4)); + }) + .collectList()) + .collectList() + .block(); + + Assertions.assertEquals(2, l.size()); + Assertions.assertEquals(0, l.get(0).size()); + Assertions.assertEquals(5, l.get(1).size()); + Assertions.assertEquals(4L, l.get(1).get(0)); + Assertions.assertEquals(20, l.get(1).get(1)); + if (isMariaDBServer() && minVersion(10, 3, 0)) { + Assertions.assertEquals("http://test", l.get(1).get(2)); + Assertions.assertEquals("test", l.get(1).get(4)); + } + Assertions.assertEquals(LocalDateTime.parse("2003-12-31T12:00:00"), l.get(1).get(3)); + } +} diff --git a/src/test/java/org/mariadb/r2dbc/integration/ResultsetTest.java b/src/test/java/org/mariadb/r2dbc/integration/ResultsetTest.java index aa579b05..677d5f57 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/ResultsetTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/ResultsetTest.java @@ -1,13 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; import static org.junit.jupiter.api.Assertions.assertEquals; -import io.r2dbc.spi.R2dbcTransientResourceException; import java.math.BigInteger; -import java.sql.SQLException; +import java.util.NoSuchElementException; import java.util.Random; import java.util.concurrent.atomic.AtomicBoolean; import org.junit.jupiter.api.*; @@ -17,7 +16,7 @@ import reactor.test.StepVerifier; public class ResultsetTest extends BaseConnectionTest { - private static String vals = "azertyuiopqsdfghjklmwxcvbn"; + private static final String vals = "azertyuiopqsdfghjklmwxcvbn"; @BeforeAll public static void before2() { @@ -31,48 +30,44 @@ public static void before2() { @AfterAll public static void dropAll() { - sharedConn.createStatement("DROP TABLE prepare3").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS prepare3").execute().blockLast(); + sharedConn.createStatement("DROP PROCEDURE IF EXISTS multiResultSets").execute().blockLast(); } @Test void multipleResultSet() { sharedConn .createStatement( - "create procedure multiResultSets() BEGIN SELECT 'a', 'b'; SELECT 'c', 'd', 'e'; END") + "create procedure multiResultSets() BEGIN SELECT 'a', 'b'; SELECT 'c', 'd', 'e'; END") .execute() - .subscribe(); + .blockLast(); final AtomicBoolean first = new AtomicBoolean(true); sharedConn .createStatement("call multiResultSets()") .execute() - .subscribe( - res -> { - if (first.get()) { - first.set(false); - res.map( - (row, metadata) -> { - Assertions.assertEquals(row.get(0), "a"); - Assertions.assertEquals(row.get(1), "b"); - Assertions.assertEquals(row.get("a"), "a"); - Assertions.assertEquals(row.get("b"), "b"); - assertThrows( - IllegalArgumentException.class, - () -> row.get("unknown"), - "Column name 'unknown' does not exist in column names [a, b]"); - return "true"; - }) - .subscribe(); - } else { - res.map( - (row, metadata) -> { - Assertions.assertEquals(row.get(0), "c"); - Assertions.assertEquals(row.get(1), "d"); - Assertions.assertEquals(row.get(2), "e"); - return "true"; - }) - .subscribe(); - } - }); + .flatMap( + r -> + r.map( + (row, metadata) -> { + if (first.get()) { + first.set(false); + Assertions.assertEquals(row.get(0), "a"); + Assertions.assertEquals(row.get(1), "b"); + Assertions.assertEquals(row.get("a"), "a"); + Assertions.assertEquals(row.get("b"), "b"); + assertThrows( + NoSuchElementException.class, + () -> row.get("unknown"), + "Column name 'unknown' does not exist in column names [a, b]"); + return "true"; + } else { + Assertions.assertEquals(row.get(0), "c"); + Assertions.assertEquals(row.get(1), "d"); + Assertions.assertEquals(row.get(2), "e"); + return "true"; + } + })) + .blockLast(); } private String stLen(int len) { @@ -222,7 +217,7 @@ void getIndexToBig(MariadbConnection connection) { .as(StepVerifier::create) .expectErrorMatches( throwable -> - throwable instanceof R2dbcTransientResourceException + throwable instanceof IndexOutOfBoundsException && throwable.getMessage().equals("Column index 5 not in range [0-2]")) .verify(); } @@ -247,7 +242,7 @@ void getIndexToLow(MariadbConnection connection) { .as(StepVerifier::create) .expectErrorMatches( throwable -> - throwable instanceof R2dbcTransientResourceException + throwable instanceof IndexOutOfBoundsException && throwable.getMessage().equals("Column index -5 must be positive")) .verify(); } @@ -265,7 +260,7 @@ private String generateLongText(int len) { } @Test - public void skippingRes() throws SQLException { + public void skippingRes() throws Exception { BigInteger maxAllowedPacket = sharedConn .createStatement("select @@max_allowed_packet") diff --git a/src/test/java/org/mariadb/r2dbc/integration/RowMetadataTest.java b/src/test/java/org/mariadb/r2dbc/integration/RowMetadataTest.java index 09495d9c..850e8892 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/RowMetadataTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/RowMetadataTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -8,10 +8,7 @@ import io.r2dbc.spi.ColumnMetadata; import io.r2dbc.spi.Nullability; import java.math.BigDecimal; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; +import java.util.*; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -52,6 +49,7 @@ public static void afterAll2() { } @Test + @SuppressWarnings("deprecation") void rowMeta() { sharedConn .createStatement( @@ -65,13 +63,18 @@ void rowMeta() { List expected = Arrays.asList("t1Alias", "t2", "t3", "t4", "t5", "t6"); assertEquals(expected.size(), metadata.getColumnNames().size()); + assertTrue(metadata.contains("t1Alias")); + assertTrue(metadata.contains("T1ALIAS")); + assertTrue(metadata.contains("t1alias")); + assertFalse(metadata.contains("t1Aliass")); + assertArrayEquals(expected.toArray(), metadata.getColumnNames().toArray()); this.assertThrows( - IllegalArgumentException.class, + IndexOutOfBoundsException.class, () -> metadata.getColumnMetadata(-1), "Column index -1 is not in permit range[0,5]"); this.assertThrows( - IllegalArgumentException.class, + IndexOutOfBoundsException.class, () -> metadata.getColumnMetadata(6), "Column index 6 is not in permit range[0,5]"); ColumnMetadata colMeta = metadata.getColumnMetadata(0); @@ -87,7 +90,7 @@ void rowMeta() { assertEquals( System.getProperty("TEST_DATABASE", TestConfiguration.database), t1Meta.getSchema()); - assertEquals("t1Alias", t1Meta.getColumnAlias()); + assertEquals("t1Alias", t1Meta.getName()); assertEquals("t1", t1Meta.getColumn()); assertEquals("rowmeta", t1Meta.getTable()); assertEquals("rowMetaAlias", t1Meta.getTableAlias()); @@ -106,7 +109,7 @@ void rowMeta() { assertEquals("t2", colMeta.getName()); this.assertThrows( - IllegalArgumentException.class, + NoSuchElementException.class, () -> metadata.getColumnMetadata("wrongName"), "Column name 'wrongName' does not exist in column names [t1Alias, t2, t3, t4, t5, t6]"); @@ -123,7 +126,7 @@ void rowMeta() { assertEquals( System.getProperty("TEST_DATABASE", TestConfiguration.database), t2Meta.getSchema()); - assertEquals("t2", t2Meta.getColumnAlias()); + assertEquals("t2", t2Meta.getName()); assertEquals("t2", t2Meta.getColumn()); assertEquals("rowmeta", t2Meta.getTable()); assertEquals("rowMetaAlias", t2Meta.getTableAlias()); diff --git a/src/test/java/org/mariadb/r2dbc/integration/StatementBatchingTest.java b/src/test/java/org/mariadb/r2dbc/integration/StatementBatchingTest.java index 1276181d..97619258 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/StatementBatchingTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/StatementBatchingTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -37,25 +37,12 @@ void batchStatement(MariadbConnection connection) { .execute() .blockLast(); - // this is normally an error in specs (see https://github.com/r2dbc/r2dbc-spi/issues/229) - // but permitting this allowed for old behavior to be ok and following spec - connection - .createStatement("INSERT INTO batchStatement values (?, ?)") - .bind(0, 3) - .bind(1, "test") - .add() - .bind(1, "test2") - .bind(0, 4) - .add() - .execute() - .blockLast(); - connection .createStatement("SELECT * FROM batchStatement") .execute() .flatMap(r -> r.map((row, metadata) -> row.get(0, String.class) + row.get(1, String.class))) .as(StepVerifier::create) - .expectNext("1test", "2test2", "3test", "4test2") + .expectNext("1test", "2test2") .verifyComplete(); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/StatementTest.java b/src/test/java/org/mariadb/r2dbc/integration/StatementTest.java index d568fa7d..75d38347 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/StatementTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/StatementTest.java @@ -1,18 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; import io.r2dbc.spi.R2dbcDataIntegrityViolationException; import io.r2dbc.spi.R2dbcTransientResourceException; import io.r2dbc.spi.Statement; +import java.util.NoSuchElementException; import java.util.Optional; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; +import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; import org.mariadb.r2dbc.api.MariadbStatement; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; public class StatementTest extends BaseConnectionTest { @@ -20,44 +23,23 @@ public class StatementTest extends BaseConnectionTest { @Test void bindOnStatementWithoutParameter() { Statement stmt = sharedConn.createStatement("INSERT INTO someTable values (1,2)"); - try { - stmt = stmt.add(); // mean nothing there - stmt.bind(0, 1); - Assertions.fail("must have thrown exception"); - } catch (UnsupportedOperationException e) { - Assertions.assertTrue( - e.getMessage() - .contains( - "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'")); - } - - try { - stmt.bind("name", 1); - Assertions.fail("must have thrown exception"); - } catch (UnsupportedOperationException e) { - Assertions.assertTrue( - e.getMessage() - .contains( - "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'")); - } - try { - stmt.bindNull(0, String.class); - Assertions.fail("must have thrown exception"); - } catch (UnsupportedOperationException e) { - Assertions.assertTrue( - e.getMessage() - .contains( - "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'")); - } - try { - stmt.bindNull("name", String.class); - Assertions.fail("must have thrown exception"); - } catch (UnsupportedOperationException e) { - Assertions.assertTrue( - e.getMessage() - .contains( - "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'")); - } + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bind(1, 1), + "Binding index 1 when only 0 parameters are expected"); + + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bind("name", 1), + "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'"); + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bindNull(0, String.class), + "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'"); + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bindNull("name", String.class), + "Binding parameters is not supported for the statement 'INSERT INTO someTable values (1,2)'"); } @Test @@ -65,11 +47,11 @@ void bindOnNamedParameterStatement() { Statement stmt = sharedConn.createStatement("INSERT INTO someTable values (:1,:2)"); assertThrows( - IllegalArgumentException.class, + NoSuchElementException.class, () -> stmt.bind("bla", "nok"), "No parameter with name 'bla' found (possible values [1, 2])"); assertThrows( - IllegalArgumentException.class, + NoSuchElementException.class, () -> stmt.bindNull("bla", String.class), "No parameter with name 'bla' found (possible values [1, 2])"); stmt.bind("1", "ok").bindNull("2", String.class); @@ -78,31 +60,18 @@ void bindOnNamedParameterStatement() { @Test void bindOnPreparedStatementWrongParameter() { Statement stmt = sharedConn.createStatement("INSERT INTO someTable values (?, ?)"); - try { - stmt.bind(-1, 1); - Assertions.fail("must have thrown exception"); - } catch (IndexOutOfBoundsException e) { - Assertions.assertTrue(e.getMessage().contains("index must be in 0-1 range but value is -1")); - } - try { - stmt.bind(2, 1); - Assertions.fail("must have thrown exception"); - } catch (IndexOutOfBoundsException e) { - Assertions.assertTrue(e.getMessage().contains("index must be in 0-1 range but value is 2")); - } - - try { - stmt.bindNull(-1, String.class); - Assertions.fail("must have thrown exception"); - } catch (IndexOutOfBoundsException e) { - Assertions.assertTrue(e.getMessage().contains("index must be in 0-1 range but value is -1")); - } - try { - stmt.bindNull(2, String.class); - Assertions.fail("must have thrown exception"); - } catch (IndexOutOfBoundsException e) { - Assertions.assertTrue(e.getMessage().contains("index must be in 0-1 range but value is 2")); - } + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bind(-1, 1), + "wrong index value -1, index must be positive"); + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bindNull(-1, String.class), + "wrong index value -1, index must be positive"); + assertThrowsContains( + IndexOutOfBoundsException.class, + () -> stmt.bindNull(2, String.class), + "Cannot bind parameter 2, statement has 2 parameters"); } @Test @@ -117,7 +86,7 @@ void bindWrongName() { try { stmt.bind("other", 1); Assertions.fail("must have thrown exception"); - } catch (IllegalArgumentException e) { + } catch (NoSuchElementException e) { Assertions.assertTrue( e.getMessage() .contains("No parameter with name 'other' found (possible values [name1, name2])")); @@ -125,7 +94,7 @@ void bindWrongName() { try { stmt.bindNull("other", String.class); Assertions.fail("must have thrown exception"); - } catch (IllegalArgumentException e) { + } catch (NoSuchElementException e) { Assertions.assertTrue( e.getMessage() .contains("No parameter with name 'other' found (possible values [name1, name2])")); @@ -134,9 +103,9 @@ void bindWrongName() { @Test void bindUnknownClass() { - Statement stmt = sharedConn.createStatement("INSERT INTO someTable values (?)"); + MariadbStatement stmt = sharedConn.createStatement("INSERT INTO someTable values (?)"); try { - stmt.bind(0, sharedConn); + stmt.bind(0, sharedConn).execute().subscribe(); Assertions.fail("must have thrown exception"); } catch (IllegalArgumentException e) { Assertions.assertTrue( @@ -165,7 +134,7 @@ void bindOnPreparedStatementWithoutAllParameter() { try { stmt.execute(); - } catch (IllegalArgumentException e) { + } catch (IllegalStateException e) { Assertions.assertTrue(e.getMessage().contains("Parameter at position 0 is not set")); } } @@ -174,11 +143,12 @@ void bindOnPreparedStatementWithoutAllParameter() { void statementToString() { String st = sharedConn.createStatement("SELECT 1").toString(); Assertions.assertTrue( - st.contains("MariadbSimpleQueryStatement{") && st.contains("sql='SELECT 1'")); + st.contains("MariadbClientParameterizedQueryStatement{") && st.contains("sql='SELECT 1'"), + st); String st2 = sharedConn.createStatement("SELECT ?").toString(); Assertions.assertTrue( - st2.contains("MariadbClientParameterizedQueryStatement{") - && st2.contains("sql='SELECT ?'")); + st2.contains("MariadbClientParameterizedQueryStatement{") && st2.contains("sql='SELECT ?'"), + st2); } @Test @@ -206,6 +176,38 @@ void fetchSize() { .verifyComplete(); } + @Test + void metadataNotSkipped() { + String sql; + + StringBuilder sb = new StringBuilder("select ?"); + for (int i = 1; i < 1000; i++) { + sb.append(",?"); + } + sql = sb.toString(); + int[] rnds = randParams(); + io.r2dbc.spi.Statement statement = sharedConnPrepare.createStatement(sql); + for (int j = 0; j < 2; j++) { + for (int i = 0; i < 1000; i++) { + statement.bind(i, rnds[i]); + } + + Integer val = + Flux.from(statement.execute()) + .flatMap(it -> it.map((row, rowMetadata) -> row.get(0, Integer.class))) + .blockLast(); + if (rnds[0] != val) throw new IllegalStateException("ERROR"); + } + } + + private static int[] randParams() { + int[] rnds = new int[1000]; + for (int i = 0; i < 1000; i++) { + rnds[i] = (int) (Math.random() * 1000); + } + return rnds; + } + @Test public void dupplicate() { sharedConn @@ -238,6 +240,7 @@ public void dupplicate() { .expectNext(1) .verifyComplete(); } + sharedConn .createStatement("INSERT INTO dupplicate(id, test) VALUES (1, 'dupplicate')") .execute() @@ -246,7 +249,8 @@ public void dupplicate() { .expectErrorMatches( throwable -> throwable instanceof R2dbcDataIntegrityViolationException - && throwable.getMessage().contains("Duplicate entry '1' for key")) + && (throwable.getMessage().contains("Duplicate entry '1' for key") + || throwable.getMessage().contains("Duplicate key in container"))) .verify(); } @@ -286,6 +290,7 @@ public void getPosition() { @Test public void returning() { Assumptions.assumeTrue(isMariaDBServer()); + if (!minVersion(10, 5, 1)) { Assertions.assertThrows( IllegalArgumentException.class, @@ -480,7 +485,6 @@ public void returningBefore105WithParameter() { .bind(0, "b") .add() .bind(0, "c") - .add() .execute() .flatMap(r -> r.map((row, metadata) -> row.get("id", String.class))) .as(StepVerifier::create) @@ -537,26 +541,41 @@ public void prepareReturning() { .bind(0, "c") .add() .bind(0, "d") - .add() .execute() .flatMap(r -> r.map((row, metadata) -> row.get("id", String.class))) .as(StepVerifier::create) .expectNext("7", "8") .verifyComplete(); + + assertThrows( + IllegalStateException.class, + () -> + sharedConn + .createStatement("INSERT INTO prepareReturning(test) VALUES (?)") + .returnGeneratedValues() + .bind(0, "c") + .add() + .bind(0, "d") + .add() + .execute(), + "Parameter at position 0 is not set"); } @Test void parameterNull() { - sharedConn - .createStatement("CREATE TEMPORARY TABLE parameterNull(t varchar(10), t2 varchar(10))") + parameterNull(sharedConn); + parameterNull(sharedConnPrepare); + } + + void parameterNull(MariadbConnection conn) { + conn.createStatement("CREATE TEMPORARY TABLE parameterNull(t varchar(10), t2 varchar(10))") .execute() .blockLast(); - sharedConn - .createStatement("INSERT INTO parameterNull VALUES ('1', '1'), (null, '2'), (null, null)") + conn.createStatement("INSERT INTO parameterNull VALUES ('1', '1'), (null, '2'), (null, null)") .execute() .blockLast(); MariadbStatement stmt = - sharedConn.createStatement("SELECT t2 FROM parameterNull WHERE COALESCE(t,?) is null"); + conn.createStatement("SELECT t2 FROM parameterNull WHERE COALESCE(t,?) is null"); stmt.bindNull(0, Integer.class) .execute() .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) @@ -570,12 +589,10 @@ void parameterNull() { .expectNext(Optional.of("2"), Optional.empty()) .verifyComplete(); - stmt.bindNull(0, this.getClass()) - .execute() - .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) - .as(StepVerifier::create) - .expectNext(Optional.of("2"), Optional.empty()) - .verifyComplete(); + assertThrows( + IllegalArgumentException.class, + () -> stmt.bindNull(0, this.getClass()), + "No encoder for class org.mariadb.r2dbc.integration.StatementTest"); } @Test diff --git a/src/test/java/org/mariadb/r2dbc/integration/TlsTest.java b/src/test/java/org/mariadb/r2dbc/integration/TlsTest.java index bce6c0cd..31a57426 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/TlsTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/TlsTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; @@ -12,13 +12,14 @@ import java.nio.file.Paths; import java.util.Arrays; import java.util.stream.Stream; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.*; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; -import reactor.core.publisher.Mono; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; public class TlsTest extends BaseConnectionTest { @@ -40,12 +41,14 @@ public static void before2() { : Integer.valueOf(System.getenv("TEST_MAXSCALE_TLS_PORT")); // try default if not present if (serverSslCert == null) { - File sslDir = new File(System.getProperty("user.dir") + "/../../ssl"); + File sslDir = new File(System.getProperty("user.dir") + "/../ssl"); + if (!sslDir.exists() || !sslDir.isDirectory()) { + sslDir = new File(System.getProperty("user.dir") + "/../../ssl"); + } if (sslDir.exists() && sslDir.isDirectory()) { - - serverSslCert = System.getProperty("user.dir") + "/../../ssl/server.crt"; - clientSslCert = System.getProperty("user.dir") + "/../../ssl/client.crt"; - clientSslKey = System.getProperty("user.dir") + "/../../ssl/client.key"; + serverSslCert = sslDir.getPath() + "/server.crt"; + clientSslCert = sslDir.getPath() + "/client.crt"; + clientSslKey = sslDir.getPath() + "/client.key"; } } @@ -60,7 +63,7 @@ public static void before2() { .createStatement("DROP USER 'MUTUAL_AUTH'") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .subscribe(); String create_sql; String grant_sql; @@ -77,6 +80,39 @@ public static void before2() { sharedConn.createStatement("FLUSH PRIVILEGES").execute().blockLast(); } + @Test + public void testWithoutPassword() throws Throwable { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"mariadb-es".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + Assumptions.assumeTrue(haveSsl(sharedConn)); + sharedConn.createStatement("CREATE USER userWithoutPassword").execute().blockLast(); + sharedConn + .createStatement( + String.format( + "GRANT SELECT on `%s`.* to userWithoutPassword", TestConfiguration.database)) + .execute() + .blockLast(); + MariadbConnectionConfiguration conf = + TestConfiguration.defaultBuilder + .clone() + .username("userWithoutPassword") + .password("") + .port(sslPort) + .sslMode(SslMode.TRUST) + .build(); + MariadbConnection connection = new MariadbConnectionFactory(conf).create().block(); + connection.close(); + sharedConn + .createStatement("DROP USER IF EXISTS userWithoutPassword") + .execute() + .map(res -> res.getRowsUpdated()) + .onErrorReturn(Flux.empty()) + .blockLast(); + } + @Test void defaultHasNoSSL() throws Exception { Assumptions.assumeTrue( @@ -171,7 +207,10 @@ void trustForceProtocol() throws Exception { !"maxscale".equals(System.getenv("srv")) && !"skysql".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); - String trustProtocol = minVersion(8, 0, 0) ? "TLSv1.2" : "TLSv1.1"; + String trustProtocol = + (isMariaDBServer() && minVersion(10, 3, 0)) || (!isMariaDBServer() && minVersion(8, 0, 0)) + ? "TLSv1.2" + : "TLSv1.1"; Assumptions.assumeTrue(haveSsl(sharedConn)); MariadbConnectionConfiguration conf = TestConfiguration.defaultBuilder @@ -352,14 +391,15 @@ void fullMutualWithoutClientCerts() throws Exception { .serverSslCert(serverSslCert) .clientSslKey(clientSslKey) .build(); - new MariadbConnectionFactory(conf) - .create() - .as(StepVerifier::create) - .expectErrorMatches( - throwable -> - throwable instanceof R2dbcNonTransientException - && throwable.getMessage().contains("Access denied")) - .verify(); + try { + new MariadbConnectionFactory(conf).create().block(); + Assertions.fail(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + Assertions.assertTrue( + throwable instanceof R2dbcNonTransientException + && throwable.getMessage().contains("Access denied")); + } } @Test diff --git a/src/test/java/org/mariadb/r2dbc/integration/TransactionTest.java b/src/test/java/org/mariadb/r2dbc/integration/TransactionTest.java index 9ffbd413..0dab29ad 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/TransactionTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/TransactionTest.java @@ -1,11 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration; -import io.r2dbc.spi.*; import java.net.URL; -import java.util.*; import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; @@ -14,7 +12,7 @@ import reactor.test.StepVerifier; public class TransactionTest extends BaseConnectionTest { - private static String insertCmd = + private static final String insertCmd = "INSERT INTO `users` (`first_name`, `last_name`, `email`) VALUES ('MariaDB', 'Row', 'mariadb@test.com')"; @BeforeAll @@ -24,9 +22,9 @@ public static void before2() { .createStatement( "CREATE TABLE `users` (\n" + " `id` int(11) NOT NULL AUTO_INCREMENT,\n" - + " `first_name` varchar(255) COLLATE utf16_slovak_ci NOT NULL,\n" - + " `last_name` varchar(255) COLLATE utf16_slovak_ci NOT NULL,\n" - + " `email` varchar(255) COLLATE utf16_slovak_ci NOT NULL,\n" + + " `first_name` varchar(255) NOT NULL,\n" + + " `last_name` varchar(255) NOT NULL,\n" + + " `email` varchar(255) NOT NULL,\n" + " PRIMARY KEY (`id`)\n" + ")") .execute() @@ -175,6 +173,7 @@ void releaseSavepoint() throws Exception { .blockLast(); checkInserted(conn, 2); conn.rollbackTransaction().block(); + conn.setAutoCommit(true).block(); conn.close(); } @@ -191,6 +190,7 @@ void rollbackSavepoint() { conn.rollbackTransactionToSavepoint("mySavePoint").block(); checkInserted(conn, 1); conn.rollbackTransaction().block(); + conn.setAutoCommit(true).block(); conn.close(); } @@ -208,6 +208,7 @@ void rollbackSavepointPipelining() { .blockLast(); checkInserted(conn, 1); conn.rollbackTransaction().block(); + conn.setAutoCommit(true).block(); conn.close(); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/authentication/Ed25519PluginTest.java b/src/test/java/org/mariadb/r2dbc/integration/authentication/Ed25519PluginTest.java index 5dc4a44c..62abcaee 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/authentication/Ed25519PluginTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/authentication/Ed25519PluginTest.java @@ -1,45 +1,53 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.authentication; +import io.r2dbc.spi.R2dbcNonTransientResourceException; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; -import org.mariadb.r2dbc.BaseConnectionTest; -import org.mariadb.r2dbc.MariadbConnectionConfiguration; -import org.mariadb.r2dbc.MariadbConnectionFactory; -import org.mariadb.r2dbc.TestConfiguration; +import org.mariadb.r2dbc.*; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; -import reactor.core.publisher.Mono; +import reactor.core.publisher.Flux; public class Ed25519PluginTest extends BaseConnectionTest { + static AtomicBoolean ed25519PluginEnabled = new AtomicBoolean(true); @BeforeAll public static void before2() { MariadbConnectionMetadata meta = sharedConn.getMetadata(); if (meta.isMariaDBServer() && meta.minVersion(10, 2, 0)) { - sharedConn - .createStatement("INSTALL SONAME 'auth_ed25519'") - .execute() - .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) - .blockLast(); + sharedConn.createStatement("INSTALL SONAME 'auth_ed25519'").execute().blockLast(); if (meta.minVersion(10, 4, 0)) { sharedConn .createStatement( "CREATE USER verificationEd25519AuthPlugin IDENTIFIED " + "VIA ed25519 USING PASSWORD('MySup8%rPassw@ord')") .execute() + .flatMap(it -> it.getRowsUpdated()) + .onErrorResume( + e -> { + ed25519PluginEnabled.set(false); + return Flux.just(1); + }) .blockLast(); + } else { sharedConn .createStatement( "CREATE USER verificationEd25519AuthPlugin IDENTIFIED " + "VIA ed25519 USING '6aW9C7ENlasUfymtfMvMZZtnkCVlcb1ssxOLJ0kj/AA'") .execute() + .flatMap(it -> it.getRowsUpdated()) + .onErrorResume( + e -> { + ed25519PluginEnabled.set(false); + return Flux.just(1); + }) .blockLast(); } sharedConn @@ -48,6 +56,12 @@ public static void before2() { "GRANT SELECT on `%s`.* to verificationEd25519AuthPlugin", TestConfiguration.database)) .execute() + .flatMap(it -> it.getRowsUpdated()) + .onErrorResume( + e -> { + ed25519PluginEnabled.set(false); + return Flux.just(1); + }) .blockLast(); sharedConn.createStatement("FLUSH PRIVILEGES").execute().blockLast(); } @@ -59,14 +73,16 @@ public static void after2() { .createStatement("DROP USER IF EXISTS verificationEd25519AuthPlugin") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); } @Test public void verificationEd25519AuthPlugin() throws Throwable { Assumptions.assumeTrue( - !"maxscale".equals(System.getenv("srv")) && !"skysql-ha".equals(System.getenv("srv"))); + ed25519PluginEnabled.get() + && !"maxscale".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); MariadbConnectionMetadata meta = sharedConn.getMetadata(); Assumptions.assumeTrue(meta.isMariaDBServer() && meta.minVersion(10, 2, 0)); @@ -80,19 +96,45 @@ public void verificationEd25519AuthPlugin() throws Throwable { connection.close(); } + @Test + public void verificationEd25519AuthPluginRestricted() throws Throwable { + Assumptions.assumeTrue( + ed25519PluginEnabled.get() + && !"maxscale".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv"))); + MariadbConnectionMetadata meta = sharedConn.getMetadata(); + Assumptions.assumeTrue(meta.isMariaDBServer() && meta.minVersion(10, 2, 0)); + + MariadbConnectionConfiguration conf = + TestConfiguration.defaultBuilder + .clone() + .username("verificationEd25519AuthPlugin") + .password("MySup8%rPassw@ord") + .restrictedAuth("mysql_native_password") + .sslMode( + SslMode.from("1".equals(System.getenv("TEST_REQUIRE_TLS")) ? "trust" : "disabled")) + .build(); + assertThrows( + R2dbcNonTransientResourceException.class, + () -> new MariadbConnectionFactory(conf).create().block(), + "Unsupported authentication plugin client_ed25519. Authorized plugin: [mysql_native_password]"); + } + @Test public void multiAuthPlugin() throws Throwable { Assumptions.assumeTrue( !"maxscale".equals(System.getenv("srv")) && !"skysql".equals(System.getenv("srv")) - && !"skysql-ha".equals(System.getenv("srv"))); + && !"skysql-ha".equals(System.getenv("srv")) + && System.getenv("TEST_PAM_USER") != null); Assumptions.assumeTrue(isMariaDBServer() && minVersion(10, 4, 2)); - + sharedConn.createStatement("INSTALL PLUGIN pam SONAME 'auth_pam'").execute().blockLast(); sharedConn.createStatement("drop user IF EXISTS mysqltest1").execute().blockLast(); sharedConn .createStatement( "CREATE USER mysqltest1 IDENTIFIED " - + "VIA ed25519 as password('!Passw0rd3') " + + "VIA pam " + + " OR ed25519 as password('!Passw0rd3')" + " OR mysql_native_password as password('!Passw0rd3Works')") .execute() .blockLast(); @@ -117,4 +159,48 @@ public void multiAuthPlugin() throws Throwable { connection.close().block(); sharedConn.createStatement("drop user mysqltest1@'%'").execute().blockLast(); } + + @Test + public void multiAuthPluginRestricted() throws Throwable { + Assumptions.assumeTrue( + !"maxscale".equals(System.getenv("srv")) + && !"skysql".equals(System.getenv("srv")) + && !"skysql-ha".equals(System.getenv("srv")) + && System.getenv("TEST_PAM_USER") != null); + Assumptions.assumeTrue(isMariaDBServer() && minVersion(10, 4, 2)); + sharedConn.createStatement("INSTALL PLUGIN pam SONAME 'auth_pam'").execute().blockLast(); + sharedConn.createStatement("drop user IF EXISTS mysqltest1").execute().blockLast(); + sharedConn + .createStatement( + "CREATE USER mysqltest1 IDENTIFIED " + + "VIA pam " + + " OR ed25519 as password('!Passw0rd3')" + + " OR mysql_native_password as password('!Passw0rd3Works')") + .execute() + .blockLast(); + + sharedConn.createStatement("GRANT SELECT on *.* to mysqltest1").execute().blockLast(); + MariadbConnectionConfiguration conf = + TestConfiguration.defaultBuilder + .clone() + .username("mysqltest1") + .password("!Passw0rd3") + .restrictedAuth("mysql_native_password,dialog,mysql_clear_password") + .build(); + assertThrows( + R2dbcNonTransientResourceException.class, + () -> new MariadbConnectionFactory(conf).create().block(), + "Unsupported authentication plugin client_ed25519. Authorized plugin: [mysql_native_password, dialog, mysql_clear_password]"); + + MariadbConnectionConfiguration conf2 = + TestConfiguration.defaultBuilder + .clone() + .username("mysqltest1") + .restrictedAuth("mysql_native_password,ed25519") + .build(); + assertThrows( + R2dbcNonTransientResourceException.class, + () -> new MariadbConnectionFactory(conf2).create().block(), + "Unsupported authentication plugin"); + } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/authentication/PamPluginTest.java b/src/test/java/org/mariadb/r2dbc/integration/authentication/PamPluginTest.java index 2bef5ba5..fe16829e 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/authentication/PamPluginTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/authentication/PamPluginTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.authentication; @@ -39,11 +39,17 @@ public void pamAuthPlugin() throws Throwable { .blockLast(); sharedConn.createStatement("FLUSH PRIVILEGES").execute().blockLast(); + int testPort = TestConfiguration.port; + if (System.getenv("TEST_PAM_PORT") != null) { + testPort = Integer.parseInt(System.getenv("TEST_PAM_PORT")); + } + MariadbConnectionConfiguration conf = TestConfiguration.defaultBuilder .clone() .username(System.getenv("TEST_PAM_USER")) .password(System.getenv("TEST_PAM_PWD")) + .port(testPort) .build(); MariadbConnection connection = new MariadbConnectionFactory(conf).create().block(); connection.close().block(); diff --git a/src/test/java/org/mariadb/r2dbc/integration/authentication/Sha256PluginTest.java b/src/test/java/org/mariadb/r2dbc/integration/authentication/Sha256PluginTest.java index 0f4de851..3562f8e9 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/authentication/Sha256PluginTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/authentication/Sha256PluginTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.authentication; @@ -12,14 +12,15 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.*; import org.mariadb.r2dbc.api.MariadbConnection; -import reactor.core.publisher.Mono; +import reactor.core.publisher.Flux; import reactor.test.StepVerifier; public class Sha256PluginTest extends BaseConnectionTest { private static String rsaPublicKey; private static String cachingRsaPublicKey; - private static boolean isWindows = System.getProperty("os.name").toLowerCase().contains("win"); + private static final boolean isWindows = + System.getProperty("os.name").toLowerCase().contains("win"); private static boolean validPath(String path) { if (path == null) return false; @@ -164,43 +165,43 @@ public static void dropAll() { .createStatement("DROP USER sha256User") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); sharedConn .createStatement("DROP USER sha256User2") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); sharedConn .createStatement("DROP USER sha256User3") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); sharedConn .createStatement("DROP USER cachingSha256User") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); sharedConn .createStatement("DROP USER cachingSha256User2") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); sharedConn .createStatement("DROP USER cachingSha256User3") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); sharedConn .createStatement("DROP USER cachingSha256User4") .execute() .map(res -> res.getRowsUpdated()) - .onErrorReturn(Mono.empty()) + .onErrorReturn(Flux.empty()) .blockLast(); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/BigIntegerParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/BigIntegerParseTest.java index 8e7e4ed9..a5ff124b 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/BigIntegerParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/BigIntegerParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -13,6 +13,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class BigIntegerParseTest extends BaseConnectionTest { @@ -188,7 +189,10 @@ private void ByteValue(MariadbConnection connection) { .expectErrorMatches( throwable -> throwable instanceof R2dbcNonTransientResourceException - && throwable.getMessage().equals("byte overflow")) + && throwable.getMessage().equals("byte overflow") + && ((R2dbcNonTransientResourceException) throwable) + .getSql() + .equals("SELECT t1 FROM BigIntTable WHERE 1 = ? LIMIT 3")) .verify(); connection .createStatement("SELECT t1 FROM BigIntUnsignedTable WHERE 1 = ? LIMIT 3") @@ -610,5 +614,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(BigInteger.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM BigIntTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.BIGINT)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM BigIntUnsignedTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.UNSIGNED_BIGINT)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/BitParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/BitParseTest.java index 7a93d8a9..1bafd5f4 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/BitParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/BitParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -16,6 +16,7 @@ import org.mariadb.r2dbc.MariadbConnectionFactory; import org.mariadb.r2dbc.TestConfiguration; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class BitParseTest extends BaseConnectionTest { @@ -493,5 +494,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(BitSet.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM BitTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.BIT)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM BitTable2 WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.BOOLEAN)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/BlobParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/BlobParseTest.java index 9d1d63a1..1494fcd9 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/BlobParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/BlobParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class BlobParseTest extends BaseConnectionTest { @@ -81,8 +82,9 @@ private void defaultValue(MariadbConnection connection) { row.get(1); return row.get(0); })) - .cast(Blob.class) - .flatMap(Blob::stream) + // .cast(Blob.class) + // .flatMap(Blob::stream) + .cast(ByteBuffer.class) .as(StepVerifier::create) .consumeNextWith(consumer) .consumeNextWith(consumer) @@ -521,7 +523,15 @@ private void meta(MariadbConnection connection) { .execute() .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getJavaType())) .as(StepVerifier::create) - .expectNextMatches(c -> c.equals(Blob.class)) + .expectNextMatches(c -> c.equals(ByteBuffer.class)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM BlobTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.BLOB)) .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/DateParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/DateParseTest.java index 6f8f6724..2c3458dc 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/DateParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/DateParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -16,6 +16,7 @@ import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.BaseTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class DateParseTest extends BaseConnectionTest { @@ -494,5 +495,13 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(LocalDate.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM DateTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.DATE)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/DateTimeParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/DateTimeParseTest.java index 877b8825..2fa9ecd2 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/DateTimeParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/DateTimeParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -16,6 +16,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class DateTimeParseTest extends BaseConnectionTest { @@ -501,5 +502,14 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(LocalDateTime.class)) .verifyComplete(); + + connection + .createStatement("SELECT t1 FROM DateTimeTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.TIMESTAMP)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/DecimalParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/DecimalParseTest.java index 9fd71923..9d03ae9f 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/DecimalParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/DecimalParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -14,6 +14,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class DecimalParseTest extends BaseConnectionTest { @@ -441,5 +442,13 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(BigDecimal.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM DecimalTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.DECIMAL)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/DoubleParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/DoubleParseTest.java index 4ae9e1e4..15bd0cdd 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/DoubleParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/DoubleParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -14,6 +14,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class DoubleParseTest extends BaseConnectionTest { @@ -461,5 +462,13 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Double.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM DoubleTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.DOUBLE)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/FloatParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/FloatParseTest.java index d14f4252..e0757203 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/FloatParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/FloatParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -14,11 +14,13 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class FloatParseTest extends BaseConnectionTest { @BeforeAll public static void before2() { + afterAll2(); sharedConn.createStatement("CREATE TABLE FloatTable (t1 FLOAT)").execute().blockLast(); sharedConn .createStatement("INSERT INTO FloatTable VALUES (0.1),(1),(922.92233), (null)") @@ -466,5 +468,13 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Float.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM FloatTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.FLOAT)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/IntParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/IntParseTest.java index 601a71ea..e233335e 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/IntParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/IntParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -14,6 +14,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class IntParseTest extends BaseConnectionTest { @@ -632,5 +633,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Long.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM IntTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.INTEGER)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM IntUnsignedTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.UNSIGNED_INTEGER)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/MediumIntParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/MediumIntParseTest.java index 2d84a548..2e1023a2 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/MediumIntParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/MediumIntParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -13,11 +13,13 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class MediumIntParseTest extends BaseConnectionTest { @BeforeAll public static void before2() { + afterAll2(); sharedConn .createStatement("CREATE TABLE MediumIntTable (t1 MEDIUMINT, t2 MEDIUMINT ZEROFILL)") .execute() @@ -39,8 +41,8 @@ public static void before2() { @AfterAll public static void afterAll2() { - sharedConn.createStatement("DROP TABLE MediumIntTable").execute().blockLast(); - sharedConn.createStatement("DROP TABLE MediumIntUnsignedTable").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS MediumIntTable").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS MediumIntUnsignedTable").execute().blockLast(); } @Test @@ -514,5 +516,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Integer.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM MediumIntTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.INTEGER)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM MediumIntUnsignedTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.INTEGER)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/ShortParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/ShortParseTest.java index 42b6b033..2f337be5 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/ShortParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/ShortParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -13,6 +13,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class ShortParseTest extends BaseConnectionTest { @@ -558,5 +559,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Integer.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM ShortTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.SMALLINT)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM ShortUnsignedTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.UNSIGNED_SMALLINT)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/StringParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/StringParseTest.java index 20cb7f37..d9c2a9e3 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/StringParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/StringParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -17,10 +17,12 @@ import java.util.Arrays; import java.util.Optional; import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -30,11 +32,22 @@ public static void before2() { sharedConn.createStatement("DROP TABLE IF EXISTS StringTable").execute().blockLast(); sharedConn .createStatement( - "CREATE TABLE StringTable (t1 varchar(256)) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") + "CREATE TABLE StringTable (t1 varchar(256), t2 TEXT) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") .execute() .blockLast(); sharedConn - .createStatement("INSERT INTO StringTable VALUES ('some🌟'),('1'),('0'), (null)") + .createStatement( + "INSERT INTO StringTable VALUES ('some🌟', 'some🌟'),('1', '1'),('0', '0'), (null, null)") + .execute() + .blockLast(); + sharedConn + .createStatement( + "CREATE TABLE StringBinary (t1 varbinary(256), t2 varbinary(1024)) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") + .execute() + .blockLast(); + sharedConn + .createStatement( + "INSERT INTO StringBinary VALUES ('some🌟', 'some🌟'),('1', '1'),('0', '0'), (null, null)") .execute() .blockLast(); sharedConn.createStatement("FLUSH TABLES").execute().blockLast(); @@ -43,6 +56,7 @@ public static void before2() { @AfterAll public static void afterAll2() { sharedConn.createStatement("DROP TABLE IF EXISTS StringTable").execute().blockLast(); + sharedConn.createStatement("DROP TABLE IF EXISTS StringBinary").execute().blockLast(); sharedConn.createStatement("DROP TABLE IF EXISTS durationValue").execute().blockLast(); sharedConn.createStatement("DROP TABLE IF EXISTS localTimeValue").execute().blockLast(); sharedConn.createStatement("DROP TABLE IF EXISTS localDateValue").execute().blockLast(); @@ -88,6 +102,79 @@ private void defaultValue(MariadbConnection connection) { .as(StepVerifier::create) .expectNext(Optional.of("some🌟"), Optional.of("1"), Optional.of("0"), Optional.empty()) .verifyComplete(); + connection + .createStatement("SELECT t2 FROM StringTable WHERE 1 = ?") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) + .as(StepVerifier::create) + .expectNext(Optional.of("some🌟"), Optional.of("1"), Optional.of("0"), Optional.empty()) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM StringTable WHERE 1 = ?") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0, Object.class)))) + .as(StepVerifier::create) + .expectNext(Optional.of("some🌟"), Optional.of("1"), Optional.of("0"), Optional.empty()) + .verifyComplete(); + } + + @Test + void defaultValueBinary() { + defaultValueBinary(sharedConn); + } + + @Test + void defaultValuePrepareBinary() { + defaultValueBinary(sharedConnPrepare); + } + + private void defaultValueBinary(MariadbConnection connection) { + connection + .createStatement("SELECT t1 FROM StringBinary WHERE 1 = ?") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) + .as(StepVerifier::create) + .consumeNextWith( + c -> { + Assertions.assertTrue(c.get() instanceof byte[]); + Assertions.assertArrayEquals( + "some🌟".getBytes(StandardCharsets.UTF_8), (byte[]) c.get()); + }) + .consumeNextWith( + c -> + Assertions.assertArrayEquals( + "1".getBytes(StandardCharsets.UTF_8), (byte[]) c.get())) + .consumeNextWith( + c -> + Assertions.assertArrayEquals( + "0".getBytes(StandardCharsets.UTF_8), (byte[]) c.get())) + .expectNext(Optional.empty()) + .verifyComplete(); + connection + .createStatement("SELECT t2 FROM StringBinary WHERE 1 = ?") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) + .as(StepVerifier::create) + .consumeNextWith( + c -> { + Assertions.assertTrue(c.get() instanceof byte[]); + Assertions.assertArrayEquals( + "some🌟".getBytes(StandardCharsets.UTF_8), (byte[]) c.get()); + }) + .consumeNextWith( + c -> { + Assertions.assertArrayEquals("1".getBytes(StandardCharsets.UTF_8), (byte[]) c.get()); + }) + .consumeNextWith( + c -> { + Assertions.assertArrayEquals("0".getBytes(StandardCharsets.UTF_8), (byte[]) c.get()); + }) + .expectNext(Optional.empty()) + .verifyComplete(); connection .createStatement("SELECT t1 FROM StringTable WHERE 1 = ?") .bind(0, 1) @@ -741,5 +828,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(String.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM StringTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.VARCHAR)) + .verifyComplete(); + connection + .createStatement("SELECT t2 FROM StringTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.CLOB)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/TimeParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/TimeParseTest.java index ad85eb07..8be986fa 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/TimeParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/TimeParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -16,6 +16,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class TimeParseTest extends BaseConnectionTest { @@ -70,10 +71,10 @@ private void defaultValue(MariadbConnection connection) { .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) .as(StepVerifier::create) .expectNext( - Optional.of(Duration.parse("P3DT18H0.012340S")), - Optional.of(Duration.parse("P33DT8H0.123S")), - Optional.of(Duration.parse("PT8M")), - Optional.of(Duration.parse("PT22S")), + Optional.of(LocalTime.parse("18:00:00.012340")), + Optional.of(LocalTime.parse("08:00:00.123")), + Optional.of(LocalTime.parse("00:08:00")), + Optional.of(LocalTime.parse("00:00:22")), Optional.empty()) .verifyComplete(); connection @@ -83,10 +84,10 @@ private void defaultValue(MariadbConnection connection) { .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0)))) .as(StepVerifier::create) .expectNext( - Optional.of(Duration.parse("PT-10H-1M-2.01234S")), - Optional.of(Duration.parse("PT-10.123S")), - Optional.of(Duration.parse("PT0M")), - Optional.of(Duration.parse("PT-22S")), + Optional.of(LocalTime.parse("13:58:57.987660")), + Optional.of(LocalTime.parse("23:59:49.877")), + Optional.of(LocalTime.parse("00:00")), + Optional.of(LocalTime.parse("23:59:38")), Optional.empty()) .verifyComplete(); } @@ -109,12 +110,25 @@ private void durationValue(MariadbConnection connection) { .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0, Duration.class)))) .as(StepVerifier::create) .expectNext( - Optional.of(Duration.parse("PT90H0.012340S")), - Optional.of(Duration.parse("PT800H0.123S")), + Optional.of(Duration.parse("P3DT18H0.012340S")), + Optional.of(Duration.parse("P33DT8H0.123S")), Optional.of(Duration.parse("PT8M")), Optional.of(Duration.parse("PT22S")), Optional.empty()) .verifyComplete(); + connection + .createStatement("SELECT t2 FROM TimeParseTest WHERE 1 = ?") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> Optional.ofNullable(row.get(0, Duration.class)))) + .as(StepVerifier::create) + .expectNext( + Optional.of(Duration.parse("PT-10H-1M-2.01234S")), + Optional.of(Duration.parse("PT-10.123S")), + Optional.of(Duration.parse("PT0M")), + Optional.of(Duration.parse("PT-22S")), + Optional.empty()) + .verifyComplete(); } @Test @@ -541,7 +555,15 @@ private void meta(MariadbConnection connection) { .execute() .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getJavaType())) .as(StepVerifier::create) - .expectNextMatches(c -> c.equals(Duration.class)) + .expectNextMatches(c -> c.equals(LocalTime.class)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM TimeParseTest WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.TIME)) .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/TimestampParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/TimestampParseTest.java index 3467393b..63f12b66 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/TimestampParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/TimestampParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -16,6 +16,7 @@ import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class TimestampParseTest extends BaseConnectionTest { @@ -508,5 +509,13 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(LocalDateTime.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM TimestampTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.TIMESTAMP)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/TinyIntParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/TinyIntParseTest.java index 736763ee..0d034197 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/TinyIntParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/TinyIntParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -16,6 +16,7 @@ import org.mariadb.r2dbc.MariadbConnectionFactory; import org.mariadb.r2dbc.TestConfiguration; import org.mariadb.r2dbc.api.MariadbConnection; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class TinyIntParseTest extends BaseConnectionTest { @@ -559,5 +560,21 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Short.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM tinyIntTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.TINYINT)) + .verifyComplete(); + connection + .createStatement("SELECT t1 FROM tinyIntUnsignedTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.UNSIGNED_TINYINT)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/codec/YearParseTest.java b/src/test/java/org/mariadb/r2dbc/integration/codec/YearParseTest.java index a174f10d..07c76a6c 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/codec/YearParseTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/codec/YearParseTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.codec; @@ -15,10 +15,11 @@ import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; +import org.mariadb.r2dbc.util.MariadbType; import reactor.test.StepVerifier; public class YearParseTest extends BaseConnectionTest { - private static MariadbConnectionMetadata meta = sharedConn.getMetadata(); + private static final MariadbConnectionMetadata meta = sharedConn.getMetadata(); @BeforeAll public static void before2() { @@ -522,5 +523,13 @@ private void meta(MariadbConnection connection) { .as(StepVerifier::create) .expectNextMatches(c -> c.equals(Short.class)) .verifyComplete(); + connection + .createStatement("SELECT t1 FROM YearTable WHERE 1 = ? LIMIT 1") + .bind(0, 1) + .execute() + .flatMap(r -> r.map((row, metadata) -> metadata.getColumnMetadata(0).getType())) + .as(StepVerifier::create) + .expectNextMatches(c -> c.equals(MariadbType.SMALLINT)) + .verifyComplete(); } } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/BigIntegerParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/BigIntegerParameterTest.java index 13d156bc..bdb13569 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/BigIntegerParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/BigIntegerParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -10,11 +10,7 @@ import java.time.LocalDateTime; import java.time.LocalTime; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbStatement; @@ -80,11 +76,9 @@ private void bigIntValue(MariadbConnection connection) { .bind(2, new BigInteger("-9")); Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=BigIntegerCodec, value=1}, Parameter{codec=BigIntegerCodec, value=9223372036854775807}, Parameter{codec=BigIntegerCodec, value=-9}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=BigIntegerCodec, value=1}, 1=Parameter{codec=BigIntegerCodec, value=9223372036854775807}, 2=Parameter{codec=BigIntegerCodec, value=-9}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=BigIntegerCodec}, 1=BindValue{codec=BigIntegerCodec}, 2=BindValue{codec=BigIntegerCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("1"), Optional.of("9223372036854775807"), Optional.of("-9")); @@ -109,11 +103,9 @@ private void stringValue(MariadbConnection connection) { .bind(2, "-9"); Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=StringCodec, value=1}, Parameter{codec=StringCodec, value=9223372036854775807}, Parameter{codec=StringCodec, value=-9}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=StringCodec, value=1}, 1=Parameter{codec=StringCodec, value=9223372036854775807}, 2=Parameter{codec=StringCodec, value=-9}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=StringCodec}, 1=BindValue{codec=StringCodec}, 2=BindValue{codec=StringCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); @@ -140,11 +132,9 @@ private void decimalValue(MariadbConnection connection) { Assertions.assertTrue( stmt.toString() - .contains( - "Parameter{codec=BigDecimalCodec, value=9223372036854775807}, Parameter{codec=BigDecimalCodec, value=-9}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=BigDecimalCodec, value=1}, 1=Parameter{codec=BigDecimalCodec, value=9223372036854775807}, 2=Parameter{codec=BigDecimalCodec, value=-9}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=BigDecimalCodec}, 1=BindValue{codec=BigDecimalCodec}, 2=BindValue{codec=BigDecimalCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("1"), Optional.of("9223372036854775807"), Optional.of("-9")); @@ -169,11 +159,9 @@ private void intValue(MariadbConnection connection) { .bind(2, 0); Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=IntCodec, value=1}, Parameter{codec=IntCodec, value=-1}, Parameter{codec=IntCodec, value=0}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=IntCodec, value=1}, 1=Parameter{codec=IntCodec, value=-1}, 2=Parameter{codec=IntCodec, value=0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=IntCodec}, 1=BindValue{codec=IntCodec}, 2=BindValue{codec=IntCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("1"), Optional.of("-1"), Optional.of("0")); @@ -186,6 +174,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -196,13 +185,12 @@ private void byteValue(MariadbConnection connection) { .bind(0, (byte) 127) .bind(1, (byte) -128) .bind(2, (byte) 0); + Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=ByteCodec, value=127}, Parameter{codec=ByteCodec, value=-128}, Parameter{codec=ByteCodec, value=0}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=ByteCodec, value=127}, 1=Parameter{codec=ByteCodec, value=-128}, 2=Parameter{codec=ByteCodec, value=0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=ByteCodec}, 1=BindValue{codec=ByteCodec}, 2=BindValue{codec=ByteCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("127"), Optional.of("-128"), Optional.of("0")); @@ -227,11 +215,9 @@ private void floatValue(MariadbConnection connection) { .bind(2, 0f); Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=FloatCodec, value=127.0}, Parameter{codec=FloatCodec, value=-128.0}, Parameter{codec=FloatCodec, value=0.0}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=FloatCodec, value=127.0}, 1=Parameter{codec=FloatCodec, value=-128.0}, 2=Parameter{codec=FloatCodec, value=0.0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=FloatCodec}, 1=BindValue{codec=FloatCodec}, 2=BindValue{codec=FloatCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("127"), Optional.of("-128"), Optional.of("0")); @@ -244,6 +230,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); } @@ -257,11 +244,9 @@ private void doubleValue(MariadbConnection connection) { Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=DoubleCodec, value=127.0}, Parameter{codec=DoubleCodec, value=-128.0}, Parameter{codec=DoubleCodec, value=0.0}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=DoubleCodec, value=127.0}, 1=Parameter{codec=DoubleCodec, value=-128.0}, 2=Parameter{codec=DoubleCodec, value=0.0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=DoubleCodec}, 1=BindValue{codec=DoubleCodec}, 2=BindValue{codec=DoubleCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("127"), Optional.of("-128"), Optional.of("0")); } @@ -285,11 +270,9 @@ private void shortValue(MariadbConnection connection) { .bind(2, Short.valueOf("0")); Assertions.assertTrue( stmt.toString() - .contains( - "Parameter{codec=ShortCodec, value=-1}, Parameter{codec=ShortCodec, value=0}]") - || stmt.toString() - .contains( - "1=Parameter{codec=ShortCodec, value=-1}, 2=Parameter{codec=ShortCodec, value=0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=ShortCodec}, 1=BindValue{codec=ShortCodec}, 2=BindValue{codec=ShortCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("1"), Optional.of("-1"), Optional.of("0")); } @@ -313,11 +296,9 @@ private void longValue(MariadbConnection connection) { .bind(2, 0L); Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=LongCodec, value=1}, Parameter{codec=LongCodec, value=-1}, Parameter{codec=LongCodec, value=0}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=LongCodec, value=1}, 1=Parameter{codec=LongCodec, value=-1}, 2=Parameter{codec=LongCodec, value=0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=LongCodec}, 1=BindValue{codec=LongCodec}, 2=BindValue{codec=LongCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("1"), Optional.of("-1"), Optional.of("0")); } @@ -341,11 +322,9 @@ private void LongValue(MariadbConnection connection) { .bind(2, Long.valueOf("0")); Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=LongCodec, value=1}, Parameter{codec=LongCodec, value=-1}, Parameter{codec=LongCodec, value=0}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=LongCodec, value=1}, 1=Parameter{codec=LongCodec, value=-1}, 2=Parameter{codec=LongCodec, value=0}}")); + .contains( + "bindings=[Binding{binds={0=BindValue{codec=LongCodec}, 1=BindValue{codec=LongCodec}, 2=BindValue{codec=LongCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); validate(Optional.of("1"), Optional.of("-1"), Optional.of("0")); @@ -364,7 +343,10 @@ private void localDateTimeValue(MariadbConnection connection) { .bind(1, LocalDateTime.now()) .bind(2, LocalDateTime.now()); Assertions.assertTrue( - stmt.toString().contains("parameters=[Parameter{codec=LocalDateTimeCodec, value=")); + stmt.toString() + .contains( + "bindings=[Binding{binds={0=BindValue{codec=LocalDateTimeCodec}, 1=BindValue{codec=LocalDateTimeCodec}, 2=BindValue{codec=LocalDateTimeCodec}}}]"), + stmt.toString()); stmt.execute() .flatMap(r -> r.getRowsUpdated()) .as(StepVerifier::create) @@ -409,8 +391,12 @@ private void localDateValue(MariadbConnection connection) { .bind(0, LocalDate.now()) .bind(1, LocalDate.now()) .bind(2, LocalDate.now()); + Assertions.assertTrue( - stmt.toString().contains("parameters=[Parameter{codec=LocalDateCodec, value=")); + stmt.toString() + .contains( + "bindings=[Binding{binds={0=BindValue{codec=LocalDateCodec}, 1=BindValue{codec=LocalDateCodec}, 2=BindValue{codec=LocalDateCodec}}}]"), + stmt.toString()); stmt.execute() .flatMap(r -> r.getRowsUpdated()) .as(StepVerifier::create) @@ -435,8 +421,10 @@ void localTimeValuePrepare() { .bind(1, LocalTime.now()) .bind(2, LocalTime.now()); Assertions.assertTrue( - stmt.toString().contains("parameters=[Parameter{codec=LocalTimeCodec, value=") - || stmt.toString().contains("parameters={0=Parameter{codec=LocalTimeCodec, value=")); + stmt.toString() + .contains( + "bindings=[Binding{binds={0=BindValue{codec=LocalTimeCodec}, 1=BindValue{codec=LocalTimeCodec}, 2=BindValue{codec=LocalTimeCodec}}}]"), + stmt.toString()); stmt.execute().blockLast(); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/BitParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/BitParameterTest.java index f3377a63..fcc26925 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/BitParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/BitParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -13,11 +13,7 @@ import java.time.LocalTime; import java.util.BitSet; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbResult; @@ -73,6 +69,7 @@ void booleanValue() { @Test void booleanValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); booleanValue(sharedConnPrepare); } @@ -86,12 +83,8 @@ private void booleanValue(MariadbConnection connection) { Assertions.assertTrue( stmt.toString() - .contains( - "parameters=[Parameter{codec=BooleanCodec, value=true}, Parameter{codec=BooleanCodec, value=true}, Parameter{codec=BooleanCodec, value=false}]") - || stmt.toString() - .contains( - "parameters={0=Parameter{codec=BooleanCodec, value=true}, 1=Parameter{codec=BooleanCodec, " - + "value=true}, 2=Parameter{codec=BooleanCodec, value=false}}"), + .contains( + "bindings=[Binding{binds={0=BindValue{codec=BooleanCodec}, 1=BindValue{codec=BooleanCodec}, 2=BindValue{codec=BooleanCodec}}}]"), stmt.toString()); stmt.execute().blockLast(); validate( @@ -107,6 +100,7 @@ void bigIntValue() { @Test void bigIntValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); bigIntValue(sharedConnPrepare); } @@ -155,6 +149,7 @@ void decimalValue() { @Test void decimalValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); decimalValue(sharedConnPrepare); } @@ -179,6 +174,7 @@ void intValue() { @Test void intValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); intValue(sharedConnPrepare); } @@ -203,6 +199,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -234,13 +231,29 @@ private void blobValue(MariadbConnection connection) { connection .createStatement("INSERT INTO ByteParam VALUES (?,?,?)") .bind(0, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {(byte) 15})))) - .bind(1, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {(byte) 1, 0, (byte) 127})))) + .bind(1, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {(byte) 1, 2})))) .bind(2, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {0})))) .execute() .blockLast(); + + validate( + Optional.of(BitSet.valueOf(new byte[] {(byte) 15})), + Optional.of(BitSet.valueOf(new byte[] {(byte) 2, (byte) 1})), + Optional.of(BitSet.valueOf(new byte[] {(byte) 0}))); + + sharedConn.createStatement("TRUNCATE TABLE ByteParam").execute().blockLast(); + + connection + .createStatement("INSERT INTO ByteParam VALUES (?,?,?)") + .bind(0, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {(byte) 15})))) + .bind(1, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {(byte) 1, 2})))) + .bind(2, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {0})))) + .execute() + .blockLast(); + validate( Optional.of(BitSet.valueOf(new byte[] {(byte) 15})), - Optional.of(BitSet.valueOf(new byte[] {(byte) 127, 0, (byte) 1})), + Optional.of(BitSet.valueOf(new byte[] {(byte) 2, (byte) 1})), Optional.of(BitSet.valueOf(new byte[] {(byte) 0}))); } @@ -275,6 +288,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); } @@ -323,6 +337,7 @@ void longValue() { @Test void longValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); longValue(sharedConnPrepare); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/BlobParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/BlobParameterTest.java index 8488bbef..6118f537 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/BlobParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/BlobParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -15,11 +15,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.MariadbConnectionFactory; @@ -32,7 +28,7 @@ import reactor.test.StepVerifier; public class BlobParameterTest extends BaseConnectionTest { - private static MariadbConnectionMetadata meta = sharedConn.getMetadata(); + private static final MariadbConnectionMetadata meta = sharedConn.getMetadata(); @BeforeAll public static void before2() { @@ -81,6 +77,7 @@ void bigIntValue() { @Test void bigIntValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); bigIntValue(sharedConnPrepare); } @@ -129,6 +126,7 @@ void decimalValue() { @Test void decimalValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); decimalValue(sharedConnPrepare); } @@ -153,6 +151,7 @@ void intValue() { @Test void intValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); intValue(sharedConnPrepare); } @@ -177,6 +176,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -214,7 +214,6 @@ private void blobValue(MariadbConnection connection) { Blob.from( Mono.just(ByteBuffer.wrap(new byte[] {(byte) 1, 0, (byte) 127, (byte) 92})))) .bind(2, Blob.from(Mono.just(ByteBuffer.wrap(new byte[] {0})))); - Assertions.assertTrue(stmt.toString().contains("Parameter{codec=BlobCodec,")); stmt.execute().blockLast(); validateNotNull( ByteBuffer.wrap(new byte[] {(byte) 15}), @@ -239,7 +238,6 @@ private void streamValue(MariadbConnection connection) { .bind(0, new ByteArrayInputStream(new byte[] {(byte) 15})) .bind(1, new ByteArrayInputStream(new byte[] {(byte) 1, 0, (byte) 127})) .bind(2, new ByteArrayInputStream(new byte[] {0})); - Assertions.assertTrue(stmt.toString().contains("Parameter{codec=StreamCodec,")); stmt.execute().blockLast(); validateNotNull( ByteBuffer.wrap(new byte[] {(byte) 15}), @@ -290,7 +288,6 @@ private void inputStreamValue(MariadbConnection connection) { .bind(0, new ByteArrayInputStream(new byte[] {(byte) 15})) .bind(1, new ByteArrayInputStream((new byte[] {(byte) 1, 39, (byte) 127}))) .bind(2, new ByteArrayInputStream((new byte[] {0}))); - Assertions.assertTrue(stmt.toString().contains("Parameter{codec=StreamCodec,")); stmt.execute().blockLast(); validateNotNull( ByteBuffer.wrap(new byte[] {(byte) 15}), @@ -337,6 +334,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); validateNotNull( ByteBuffer.wrap("11".getBytes()), diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/DateParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/DateParameterTest.java index 534d10cf..a8459be3 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/DateParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/DateParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -10,10 +10,7 @@ import java.time.LocalDateTime; import java.time.LocalTime; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbResult; @@ -87,7 +84,10 @@ private void bigIntValue(MariadbConnection connection) { .expectErrorMatches( throwable -> throwable instanceof R2dbcBadGrammarException - && ((R2dbcBadGrammarException) throwable).getSqlState().equals("22007")) + && ((R2dbcBadGrammarException) throwable).getSqlState().equals("22007") + && ((R2dbcBadGrammarException) throwable) + .getSql() + .equals("INSERT INTO DateParam VALUES (?,?,?)")) .verify(); } } @@ -163,6 +163,7 @@ void intValue() { @Test void intValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); intValue(sharedConnPrepare); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/DateTimeParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/DateTimeParameterTest.java index 7ec2aed7..3066fa82 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/DateTimeParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/DateTimeParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -10,10 +10,7 @@ import java.time.LocalDateTime; import java.time.LocalTime; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbConnectionMetadata; @@ -168,6 +165,7 @@ void intValue() { @Test void intValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); intValue(sharedConnPrepare); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/DecimalParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/DecimalParameterTest.java index 14730359..3633d8bf 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/DecimalParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/DecimalParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -13,10 +13,7 @@ import java.time.LocalTime; import java.util.Arrays; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import reactor.core.publisher.Flux; @@ -166,6 +163,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -214,6 +212,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/FloatParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/FloatParameterTest.java index 186d7482..156f9bba 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/FloatParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/FloatParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/IntParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/IntParameterTest.java index 8c1ec477..bd61deae 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/IntParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/IntParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/MediumIntParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/MediumIntParameterTest.java index 2205e1be..906f5b00 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/MediumIntParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/MediumIntParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/ShortParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/ShortParameterTest.java index 900ec483..b4437d86 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/ShortParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/ShortParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/StringParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/StringParameterTest.java index b62fd884..ab27747a 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/StringParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/StringParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -15,11 +15,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.MariadbConnectionConfiguration; import org.mariadb.r2dbc.MariadbConnectionFactory; @@ -32,7 +28,7 @@ import reactor.test.StepVerifier; public class StringParameterTest extends BaseConnectionTest { - private static MariadbConnectionMetadata meta = sharedConn.getMetadata(); + private static final MariadbConnectionMetadata meta = sharedConn.getMetadata(); @BeforeAll public static void before2() { @@ -68,7 +64,7 @@ private void nullValue(MariadbConnection connection) { .createStatement("INSERT INTO StringParam VALUES (?,?,?)") .bindNull(0, BigInteger.class) .bindNull(1, BigInteger.class) - .bindNull(2, null) + .bindNull(2, Short.class) .execute() .blockLast(); validate(Optional.empty(), Optional.empty(), Optional.empty()); @@ -105,9 +101,6 @@ private void bitSetValue(MariadbConnection connection) { .bind(0, BitSet.valueOf(revertOrder("çà¤\\".getBytes(StandardCharsets.UTF_8)))) .bind(1, BitSet.valueOf(revertOrder("你好".getBytes(StandardCharsets.UTF_8)))) .bind(2, BitSet.valueOf(revertOrder("🌟hello\\".getBytes(StandardCharsets.UTF_8)))); - Assertions.assertTrue( - stmt.toString().contains("parameters=[Parameter{codec=BitSetCodec, value={") - || stmt.toString().contains("parameters={0=Parameter{codec=BitSetCodec, value={")); stmt.execute().blockLast(); validate(Optional.of("çà¤\\"), Optional.of("你好"), Optional.of("🌟hello\\")); } @@ -130,9 +123,6 @@ private void byteArrayValue(MariadbConnection connection) { .bind(0, "çà¤\\".getBytes(StandardCharsets.UTF_8)) .bind(1, "你好".getBytes(StandardCharsets.UTF_8)) .bind(2, "🌟hello\\".getBytes(StandardCharsets.UTF_8)); - Assertions.assertTrue( - stmt.toString().contains("parameters=[Parameter{codec=ByteArrayCodec, value=") - || stmt.toString().contains("parameters={0=Parameter{codec=ByteArrayCodec, value=")); stmt.execute().blockLast(); validate(Optional.of("çà¤\\"), Optional.of("你好"), Optional.of("🌟hello\\")); } @@ -165,6 +155,7 @@ void stringValue() { @Test void stringValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); stringValue(sharedConnPrepare); } @@ -218,7 +209,6 @@ private void clobValue(MariadbConnection connection) { .bind(0, Clob.from(Mono.just("123"))) .bind(1, Clob.from(Mono.just("你好"))) .bind(2, Clob.from(Mono.just("🌟hello\\"))); - Assertions.assertTrue(stmt.toString().contains("Parameter{codec=ClobCodec,")); stmt.execute().blockLast(); validate(Optional.of("123"), Optional.of("你好"), Optional.of("🌟hello\\")); } @@ -230,6 +220,7 @@ void decimalValue() { @Test void decimalValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); decimalValue(sharedConnPrepare); } @@ -251,6 +242,7 @@ void intValue() { @Test void intValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); intValue(sharedConnPrepare); } @@ -272,6 +264,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -316,6 +309,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); validate(Optional.of("127"), Optional.of("-128"), Optional.of("0")); } @@ -358,6 +352,7 @@ void longValue() { @Test void longValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); longValue(sharedConnPrepare); } @@ -477,7 +472,6 @@ private void durationValue(MariadbConnection connection) { .bind(0, Duration.parse("P3DT18H0.012340S")) .bind(1, Duration.parse("PT8M")) .bind(2, Duration.parse("PT22S")); - Assertions.assertTrue(stmt.toString().contains("Parameter{codec=DurationCodec,")); stmt.execute().blockLast(); } diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeParameterTest.java index b0b36513..711fb1c0 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -11,10 +11,7 @@ import java.time.LocalDateTime; import java.time.LocalTime; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbResult; @@ -176,9 +173,9 @@ private void intValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT1S")), - Optional.of(Duration.parse("PT-1S")), - Optional.of(Duration.parse("PT0M"))); + Optional.of(LocalTime.parse("00:00:01")), + Optional.of(LocalTime.parse("23:59:59")), + Optional.of(LocalTime.parse("00:00:00"))); } @Test @@ -188,6 +185,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -200,9 +198,9 @@ private void byteValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT1M27S")), - Optional.of(Duration.parse("PT-1M-28S")), - Optional.of(Duration.parse("PT0M"))); + Optional.of(LocalTime.parse("00:01:27")), + Optional.of(LocalTime.parse("23:58:32")), + Optional.of(LocalTime.parse("00:00:00"))); } @Test @@ -224,9 +222,9 @@ private void floatValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT1M27S")), - Optional.of(Duration.parse("PT-1M-28S")), - Optional.of(Duration.parse("PT0M"))); + Optional.of(LocalTime.parse("00:01:27")), + Optional.of(LocalTime.parse("23:58:32")), + Optional.of(LocalTime.parse("00:00:00"))); } @Test @@ -236,6 +234,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); } @@ -248,9 +247,9 @@ private void doubleValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT1M27S")), - Optional.of(Duration.parse("PT1M28S")), - Optional.of(Duration.parse("PT0M"))); + Optional.of(LocalTime.parse("00:01:27")), + Optional.of(LocalTime.parse("00:01:28")), + Optional.of(LocalTime.parse("00:00:00"))); } @Test @@ -272,9 +271,9 @@ private void shortValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT1S")), - Optional.of(Duration.parse("PT-1S")), - Optional.of(Duration.parse("PT0M"))); + Optional.of(LocalTime.parse("00:00:01")), + Optional.of(LocalTime.parse("23:59:59")), + Optional.of(LocalTime.parse("00:00:00"))); } @Test @@ -296,9 +295,9 @@ private void longValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT1S")), - Optional.of(Duration.parse("PT-1S")), - Optional.of(Duration.parse("PT0M"))); + Optional.of(LocalTime.parse("00:00:01")), + Optional.of(LocalTime.parse("23:59:59")), + Optional.of(LocalTime.parse("00:00:00"))); } @Test @@ -320,9 +319,9 @@ private void localDateTimeValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT5H8M9.0014S")), - Optional.of(Duration.parse("PT5H8M10.123456S")), - Optional.of(Duration.parse("PT5H8M11.123S"))); + Optional.of(LocalTime.parse("05:08:09.0014")), + Optional.of(LocalTime.parse("05:08:10.123456")), + Optional.of(LocalTime.parse("05:08:11.123"))); } @Test @@ -388,14 +387,14 @@ private void durationValue(MariadbConnection connection) { connection .createStatement("INSERT INTO TimeParam VALUES (?,?,?)") .bind(0, Duration.parse("PT5H8M9.0014S")) - .bind(1, Duration.parse("PT-5H8M10S")) + .bind(1, Duration.parse("PT-5H-8M-10S")) .bind(2, Duration.parse("PT5H8M11.123S")) .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT5H8M9.0014S")), - Optional.of(Duration.parse("PT-5H8M10S")), - Optional.of(Duration.parse("PT5H8M11.123S"))); + Optional.of(LocalTime.parse("05:08:09.0014")), + Optional.of(LocalTime.parse("18:51:50")), + Optional.of(LocalTime.parse("05:08:11.123"))); } private void durationValue2(MariadbConnection connection) { @@ -405,13 +404,13 @@ private void durationValue2(MariadbConnection connection) { .createStatement("INSERT INTO TimeParam VALUES (?,?,?)") .bind(0, Duration.parse("PT0S")) .bind(1, Duration.parse("PT-1.123S")) - .bind(2, Duration.parse("PT-5H8M11.123S")) + .bind(2, Duration.parse("PT-5H-8M-11.123S")) .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT0S")), - Optional.of(Duration.parse("PT-1.123S")), - Optional.of(Duration.parse("PT-5H8M11.123S"))); + Optional.of(LocalTime.parse("00:00:00")), + Optional.of(LocalTime.parse("23:59:58.877")), + Optional.of(LocalTime.parse("18:51:48.877"))); } @Test @@ -433,12 +432,12 @@ private void localTimeValue(MariadbConnection connection) { .execute() .blockLast(); validate( - Optional.of(Duration.parse("PT5H8M9.0014S")), - Optional.of(Duration.parse("PT5H8M10S")), - Optional.of(Duration.parse("PT5H8M11.123S"))); + Optional.of(LocalTime.parse("05:08:09.0014")), + Optional.of(LocalTime.parse("05:08:10")), + Optional.of(LocalTime.parse("05:08:11.123"))); } - private void validate(Optional t1, Optional t2, Optional t3) { + private void validate(Optional t1, Optional t2, Optional t3) { sharedConn .createStatement("SELECT * FROM TimeParam") .execute() @@ -447,7 +446,7 @@ private void validate(Optional t1, Optional t2, Optional Flux.just( - Optional.ofNullable((Duration) row.get(0)), + Optional.ofNullable((LocalTime) row.get(0)), Optional.ofNullable(row.get(1)), Optional.ofNullable(row.get(2))))) .blockLast() diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeStampParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeStampParameterTest.java index f8526e8d..3ca6d67b 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeStampParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/TimeStampParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -396,9 +396,9 @@ void localTimeValuePrepare() { .execute() .blockLast(); validate( - Optional.of(LocalDateTime.parse(LocalDate.now().toString() + "T05:08:10.123456")), - Optional.of(LocalDateTime.parse(LocalDate.now().toString() + "T06:08:15.045500")), - Optional.of(LocalDateTime.parse(LocalDate.now().toString() + "T07:08:10.123000"))); + Optional.of(LocalDateTime.parse(LocalDate.now() + "T05:08:10.123456")), + Optional.of(LocalDateTime.parse(LocalDate.now() + "T06:08:15.045500")), + Optional.of(LocalDateTime.parse(LocalDate.now() + "T07:08:10.123000"))); } private void localTimeValue(MariadbConnection connection) { diff --git a/src/test/java/org/mariadb/r2dbc/integration/parameter/TinyIntParameterTest.java b/src/test/java/org/mariadb/r2dbc/integration/parameter/TinyIntParameterTest.java index e0605609..0db37b25 100644 --- a/src/test/java/org/mariadb/r2dbc/integration/parameter/TinyIntParameterTest.java +++ b/src/test/java/org/mariadb/r2dbc/integration/parameter/TinyIntParameterTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.integration.parameter; @@ -11,10 +11,7 @@ import java.time.LocalDateTime; import java.time.LocalTime; import java.util.Optional; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.mariadb.r2dbc.BaseConnectionTest; import org.mariadb.r2dbc.api.MariadbConnection; import org.mariadb.r2dbc.api.MariadbResult; @@ -68,6 +65,7 @@ void bigIntValue() { @Test void bigIntValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); bigIntValue(sharedConnPrepare); } @@ -164,6 +162,7 @@ void byteValue() { @Test void byteValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); byteValue(sharedConnPrepare); } @@ -212,6 +211,7 @@ void doubleValue() { @Test void doubleValuePrepare() { + Assumptions.assumeFalse(!isMariaDBServer() && minVersion(8, 0, 0)); doubleValue(sharedConnPrepare); } diff --git a/src/test/java/org/mariadb/r2dbc/tools/TcpProxy.java b/src/test/java/org/mariadb/r2dbc/tools/TcpProxy.java index 14132818..7855dd8e 100644 --- a/src/test/java/org/mariadb/r2dbc/tools/TcpProxy.java +++ b/src/test/java/org/mariadb/r2dbc/tools/TcpProxy.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.tools; @@ -13,7 +13,7 @@ public class TcpProxy { private static final Logger logger = LoggerFactory.getLogger(TcpProxy.class); private final String host; - private TcpProxySocket socket; + private final TcpProxySocket socket; /** * Initialise proxy. @@ -53,13 +53,29 @@ public void restart(long sleepTime) { public void forceClose() { socket.sendRst(); + try { + Thread.sleep(5); + } catch (InterruptedException e) { + // eat Exception + } + socket.kill(); + } + + public void restartForce() { + socket.sendRst(); + Executors.newSingleThreadExecutor().execute(socket); + try { + Thread.sleep(5); + } catch (InterruptedException e) { + // eat Exception + } } /** Restart proxy. */ public void restart() { Executors.newSingleThreadExecutor().execute(socket); try { - Thread.sleep(10); + Thread.sleep(5); } catch (InterruptedException e) { // eat Exception } diff --git a/src/test/java/org/mariadb/r2dbc/tools/TcpProxySocket.java b/src/test/java/org/mariadb/r2dbc/tools/TcpProxySocket.java index 6f35ac3d..af451c7b 100644 --- a/src/test/java/org/mariadb/r2dbc/tools/TcpProxySocket.java +++ b/src/test/java/org/mariadb/r2dbc/tools/TcpProxySocket.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.tools; @@ -15,7 +15,7 @@ public class TcpProxySocket implements Runnable { private final String host; private final int remoteport; - private int localport; + private final int localport; private boolean stop = false; private Socket client = null; private Socket server = null; diff --git a/src/test/java/org/mariadb/r2dbc/unit/InitFinalClass.java b/src/test/java/org/mariadb/r2dbc/unit/InitFinalClass.java index 3f9654b9..e85cd798 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/InitFinalClass.java +++ b/src/test/java/org/mariadb/r2dbc/unit/InitFinalClass.java @@ -1,12 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit; import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.codec.Codecs; import org.mariadb.r2dbc.util.BufferUtils; -import org.mariadb.r2dbc.util.PidFactory; +import org.mariadb.r2dbc.util.constants.Capabilities; +import org.mariadb.r2dbc.util.constants.ColumnFlags; +import org.mariadb.r2dbc.util.constants.ServerStatus; +import org.mariadb.r2dbc.util.constants.StateChange; public class InitFinalClass { @@ -14,7 +17,10 @@ public class InitFinalClass { public void init() throws Exception { Codecs codecs = new Codecs(); BufferUtils buf = new BufferUtils(); - PidFactory pid = new PidFactory(); - System.out.println(codecs.hashCode() + buf.hashCode() + pid.hashCode()); + Capabilities c = new Capabilities(); + ColumnFlags c2 = new ColumnFlags(); + ServerStatus c3 = new ServerStatus(); + StateChange c4 = new StateChange(); + System.out.println(codecs.hashCode() + buf.hashCode()); } } diff --git a/src/test/java/org/mariadb/r2dbc/unit/SslModeTest.java b/src/test/java/org/mariadb/r2dbc/unit/SslModeTest.java index d9e9a173..b4d5d9d1 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/SslModeTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/SslModeTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit; diff --git a/src/test/java/org/mariadb/r2dbc/unit/client/HostnameVerifierTest.java b/src/test/java/org/mariadb/r2dbc/unit/client/HostnameVerifierTest.java index bbfdf7b3..710371f0 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/client/HostnameVerifierTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/client/HostnameVerifierTest.java @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2012-2014 Monty Program Ab -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit.client; diff --git a/src/test/java/org/mariadb/r2dbc/unit/client/ServerVersionTest.java b/src/test/java/org/mariadb/r2dbc/unit/client/ServerVersionTest.java index 09a85b9f..8b818841 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/client/ServerVersionTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/client/ServerVersionTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit.client; diff --git a/src/test/java/org/mariadb/r2dbc/unit/util/BufferUtilsTest.java b/src/test/java/org/mariadb/r2dbc/unit/util/BufferUtilsTest.java index 81b6e4d0..31fb72ce 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/util/BufferUtilsTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/util/BufferUtilsTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit.util; @@ -7,10 +7,12 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.r2dbc.spi.IsolationLevel; import java.nio.charset.StandardCharsets; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.mariadb.r2dbc.client.Context; +import org.mariadb.r2dbc.client.SimpleContext; +import org.mariadb.r2dbc.message.Context; import org.mariadb.r2dbc.util.BufferUtils; import org.mariadb.r2dbc.util.constants.ServerStatus; @@ -69,6 +71,7 @@ void skipLengthEncode() { BufferUtils.skipLengthEncode(buf); assertEquals(10, buf.readerIndex()); + buf.release(); } @Test @@ -128,6 +131,7 @@ void readLengthEncodedInt() { buf.writerIndex(1000); assertEquals(-1, BufferUtils.readLengthEncodedInt(buf)); + buf.release(); } @Test @@ -153,6 +157,7 @@ void readLengthEncodedString() { buf.setBytes(0, b); buf.writerIndex(1000); assertEquals("AB", BufferUtils.readLengthEncodedString(buf)); + buf.release(); } @Test @@ -181,23 +186,44 @@ void readLengthEncodedBuffer() { byte[] res = new byte[2]; bb.getBytes(0, res); assertArrayEquals("AB".getBytes(StandardCharsets.UTF_8), res); + buf.release(); } @Test void write() { Context ctxNoBackSlash = - new Context("10.5.5-mariadb", 1, 1, ServerStatus.NO_BACKSLASH_ESCAPES, true); - Context ctx = new Context("10.5.5-mariadb", 1, 1, (short) 0, true); + new SimpleContext( + "10.5.5-mariadb", + 1, + 1, + ServerStatus.NO_BACKSLASH_ESCAPES, + true, + 1, + "testr2", + null, + IsolationLevel.REPEATABLE_READ); + Context ctx = + new SimpleContext( + "10.5.5-mariadb", + 1, + 1, + (short) 0, + true, + 1, + "testr2", + null, + IsolationLevel.REPEATABLE_READ); ByteBuf buf = allocator.buffer(1000); buf.writerIndex(0); - BufferUtils.write(buf, "A'\"\0\\€'\"\0\\", false, ctxNoBackSlash); + byte[] val = "A'\"\0\\€'\"\0\\".getBytes(StandardCharsets.UTF_8); + BufferUtils.escapedBytes(buf, val, val.length, ctxNoBackSlash); byte[] res = new byte[buf.writerIndex()]; buf.getBytes(0, res); assertArrayEquals("A''\"\0\\€''\"\0\\".getBytes(StandardCharsets.UTF_8), res); buf.writerIndex(0); - BufferUtils.write(buf, "A'\"\0\\€'\"\0\\", false, ctx); + BufferUtils.escapedBytes(buf, val, val.length, ctx); res = new byte[buf.writerIndex()]; buf.getBytes(0, res); assertArrayEquals("A\\'\\\"\\\0\\\\€\\'\\\"\\\0\\\\".getBytes(StandardCharsets.UTF_8), res); @@ -210,28 +236,29 @@ void write() { final byte[] utf8Wrong4bytes2 = new byte[] {-16, (byte) -97, (byte) -103}; buf.writerIndex(0); - BufferUtils.write(buf, new String(utf8Wrong2bytes, StandardCharsets.UTF_8), false, ctx); + BufferUtils.escapedBytes(buf, utf8Wrong2bytes, utf8Wrong2bytes.length, ctx); res = new byte[buf.writerIndex()]; buf.getBytes(0, res); - assertArrayEquals(new byte[] {8, -17, -65, -67, 111, 111}, res); + assertArrayEquals(utf8Wrong2bytes, res); buf.writerIndex(0); - BufferUtils.write(buf, new String(utf8Wrong3bytes, StandardCharsets.UTF_8), false, ctx); + BufferUtils.escapedBytes(buf, utf8Wrong3bytes, utf8Wrong3bytes.length, ctx); res = new byte[buf.writerIndex()]; buf.getBytes(0, res); - assertArrayEquals(new byte[] {7, 10, -17, -65, -67, 111, 111}, res); + assertArrayEquals(utf8Wrong3bytes, res); buf.writerIndex(0); - BufferUtils.write(buf, new String(utf8Wrong4bytes, StandardCharsets.UTF_8), false, ctx); + BufferUtils.escapedBytes(buf, utf8Wrong4bytes, utf8Wrong4bytes.length, ctx); res = new byte[buf.writerIndex()]; buf.getBytes(0, res); - assertArrayEquals(new byte[] {16, 32, 10, -17, -65, -67, 111, 111}, res); + assertArrayEquals(utf8Wrong4bytes, res); buf.writerIndex(0); - BufferUtils.write(buf, new String(utf8Wrong4bytes2, StandardCharsets.UTF_8), false, ctx); + BufferUtils.escapedBytes(buf, utf8Wrong4bytes2, utf8Wrong4bytes2.length, ctx); res = new byte[buf.writerIndex()]; buf.getBytes(0, res); - assertArrayEquals(new byte[] {-17, -65, -67}, res); + assertArrayEquals(utf8Wrong4bytes2, res); + buf.release(); } @Test @@ -244,8 +271,8 @@ void toStringBuf() { 0x2e }); buf.readerIndex(0); - System.out.println(buf.readerIndex()); buf.writerIndex(16); Assertions.assertEquals("6D0000000A352E352E352D31302E362E", BufferUtils.toString(buf)); + buf.release(); } } diff --git a/src/test/java/org/mariadb/r2dbc/unit/util/ClientPrepareResultTest.java b/src/test/java/org/mariadb/r2dbc/unit/util/ClientPrepareResultTest.java index 9681da80..d8723233 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/util/ClientPrepareResultTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/util/ClientPrepareResultTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit.util; @@ -85,6 +85,19 @@ public void stringEscapeParsing() throws Exception { }); } + @Test + public void stringReturningParsing() throws Exception { + checkParsing( + "select * from t \t RETURNINGa()", + 0, + 0, + true, + false, + false, + new String[] {"select * from t \t RETURNINGa()"}, + new String[] {"select * from t \t RETURNINGa()"}); + } + @Test public void testRewritableWithConstantParameter() throws Exception { checkParsing( diff --git a/src/test/java/org/mariadb/r2dbc/unit/util/DefaultHostnameVerifierTest.java b/src/test/java/org/mariadb/r2dbc/unit/util/DefaultHostnameVerifierTest.java index ea09c589..96eec692 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/util/DefaultHostnameVerifierTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/util/DefaultHostnameVerifierTest.java @@ -1,5 +1,5 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit.util; diff --git a/src/test/java/org/mariadb/r2dbc/unit/util/HostAddressTest.java b/src/test/java/org/mariadb/r2dbc/unit/util/HostAddressTest.java new file mode 100644 index 00000000..641b38a5 --- /dev/null +++ b/src/test/java/org/mariadb/r2dbc/unit/util/HostAddressTest.java @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.unit.util; + +import io.r2dbc.spi.ConnectionFactoryOptions; +import java.util.List; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mariadb.r2dbc.MariadbConnectionConfiguration; +import org.mariadb.r2dbc.util.HostAddress; + +public class HostAddressTest { + @Test + void parseTest() { + List addresses = HostAddress.parse("host1:3303,host2:3305", 3306); + Assertions.assertEquals(2, addresses.size()); + Assertions.assertEquals(new HostAddress("host1", 3303), addresses.get(0)); + Assertions.assertEquals(new HostAddress("host2", 3305), addresses.get(1)); + + List addresses2 = HostAddress.parse(null, 3303); + Assertions.assertEquals(1, addresses2.size()); + Assertions.assertEquals(new HostAddress("localhost", 3303), addresses2.get(0)); + Assertions.assertNotEquals(addresses.hashCode(), addresses2.hashCode()); + } + + @Test + void parseTestSpiFromOption() { + final ConnectionFactoryOptions option1s = + ConnectionFactoryOptions.builder() + .option(ConnectionFactoryOptions.USER, "someUser") + .option(ConnectionFactoryOptions.HOST, "host1:3303,host2,host3:3305,host4") + .build(); + + MariadbConnectionConfiguration conf = + MariadbConnectionConfiguration.fromOptions(option1s).build(); + Assertions.assertEquals(4, conf.getHostAddresses().size()); + Assertions.assertEquals(new HostAddress("host1", 3303), conf.getHostAddresses().get(0)); + Assertions.assertEquals(new HostAddress("host2", 3306), conf.getHostAddresses().get(1)); + Assertions.assertEquals(new HostAddress("host3", 3305), conf.getHostAddresses().get(2)); + Assertions.assertEquals(new HostAddress("host4", 3306), conf.getHostAddresses().get(3)); + + final ConnectionFactoryOptions option2s = + ConnectionFactoryOptions.builder() + .option(ConnectionFactoryOptions.USER, "someUser") + .option(ConnectionFactoryOptions.HOST, "host1:3303,host2:3305") + .option(ConnectionFactoryOptions.PORT, 3307) + .build(); + + conf = MariadbConnectionConfiguration.fromOptions(option2s).build(); + Assertions.assertEquals(2, conf.getHostAddresses().size()); + Assertions.assertEquals(new HostAddress("host1", 3303), conf.getHostAddresses().get(0)); + Assertions.assertEquals(new HostAddress("host2", 3305), conf.getHostAddresses().get(1)); + + final ConnectionFactoryOptions option3s = + ConnectionFactoryOptions.builder() + .option(ConnectionFactoryOptions.USER, "someUser") + .option(ConnectionFactoryOptions.HOST, "host1:3303,host2,host3:3309") + .option(ConnectionFactoryOptions.PORT, 3307) + .build(); + + conf = MariadbConnectionConfiguration.fromOptions(option3s).build(); + Assertions.assertEquals(3, conf.getHostAddresses().size()); + Assertions.assertEquals(new HostAddress("host1", 3303), conf.getHostAddresses().get(0)); + Assertions.assertEquals(new HostAddress("host2", 3307), conf.getHostAddresses().get(1)); + Assertions.assertEquals(new HostAddress("host3", 3309), conf.getHostAddresses().get(2)); + Assertions.assertEquals("host3:3309", conf.getHostAddresses().get(2).toString()); + } + + @Test + void parseTestSpiFromString() { + + final ConnectionFactoryOptions option1s = + ConnectionFactoryOptions.parse( + "r2dbc:mariadb://someUser:pwd@host1:3303,host2,host3:3305,host4/"); + + MariadbConnectionConfiguration conf = + MariadbConnectionConfiguration.fromOptions(option1s).build(); + Assertions.assertEquals(4, conf.getHostAddresses().size()); + Assertions.assertEquals(new HostAddress("host1", 3303), conf.getHostAddresses().get(0)); + Assertions.assertEquals(new HostAddress("host2", 3306), conf.getHostAddresses().get(1)); + Assertions.assertEquals(new HostAddress("host3", 3305), conf.getHostAddresses().get(2)); + Assertions.assertEquals(new HostAddress("host4", 3306), conf.getHostAddresses().get(3)); + + final ConnectionFactoryOptions option2s = + ConnectionFactoryOptions.parse("r2dbc:mariadb://someUser:pwd@host1:3303,host2:3305/"); + + conf = MariadbConnectionConfiguration.fromOptions(option2s).build(); + Assertions.assertEquals(2, conf.getHostAddresses().size()); + Assertions.assertEquals(new HostAddress("host1", 3303), conf.getHostAddresses().get(0)); + Assertions.assertEquals(new HostAddress("host2", 3305), conf.getHostAddresses().get(1)); + + final ConnectionFactoryOptions option3s = + ConnectionFactoryOptions.parse("r2dbc:mariadb://someUser:pwd@host1:3303,host2,host3:3309/"); + conf = MariadbConnectionConfiguration.fromOptions(option3s).build(); + Assertions.assertEquals(3, conf.getHostAddresses().size()); + Assertions.assertEquals(new HostAddress("host1", 3303), conf.getHostAddresses().get(0)); + Assertions.assertEquals(new HostAddress("host2", 3306), conf.getHostAddresses().get(1)); + Assertions.assertEquals(new HostAddress("host3", 3309), conf.getHostAddresses().get(2)); + Assertions.assertEquals("host3:3309", conf.getHostAddresses().get(2).toString()); + } +} diff --git a/src/test/java/org/mariadb/r2dbc/unit/util/MariadbTypeTest.java b/src/test/java/org/mariadb/r2dbc/unit/util/MariadbTypeTest.java new file mode 100644 index 00000000..76d28eed --- /dev/null +++ b/src/test/java/org/mariadb/r2dbc/unit/util/MariadbTypeTest.java @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2020-2022 MariaDB Corporation Ab + +package org.mariadb.r2dbc.unit.util; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mariadb.r2dbc.util.MariadbType; + +public class MariadbTypeTest { + + @Test + void getName() { + Assertions.assertEquals("VARCHAR", MariadbType.VARCHAR.getName()); + Assertions.assertEquals("BIGINT", MariadbType.UNSIGNED_BIGINT.getName()); + Assertions.assertEquals("UNSIGNED_BIGINT", MariadbType.UNSIGNED_BIGINT.name()); + } +} diff --git a/src/test/java/org/mariadb/r2dbc/unit/util/ServerPrepareResultTest.java b/src/test/java/org/mariadb/r2dbc/unit/util/ServerPrepareResultTest.java index b8c01201..cb1c5e25 100644 --- a/src/test/java/org/mariadb/r2dbc/unit/util/ServerPrepareResultTest.java +++ b/src/test/java/org/mariadb/r2dbc/unit/util/ServerPrepareResultTest.java @@ -1,11 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 -// Copyright (c) 2020-2021 MariaDB Corporation Ab +// Copyright (c) 2020-2022 MariaDB Corporation Ab package org.mariadb.r2dbc.unit.util; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.mariadb.r2dbc.message.server.ColumnDefinitionPacket; +import org.mariadb.r2dbc.util.ServerNamedParamParser; import org.mariadb.r2dbc.util.ServerPrepareResult; public class ServerPrepareResultTest { @@ -20,4 +21,202 @@ public void equalsTest() throws Exception { Assertions.assertFalse(prepare.equals(null)); Assertions.assertFalse(prepare.equals("dd")); } + + @Test + public void stringReturningParsing() throws Exception { + checkParsing("select * from t \t RETURNINGa()", 0, 0); + } + + @Test + public void stringEscapeParsing() throws Exception { + checkParsing( + "select '\\'\"`/*#' as a, ? as \\b, \"\\\"'returningInsertDeleteUpdate\" as c, ? as d", + 2, + 1); + } + + @Test + public void testRewritableWithConstantParameter() throws Exception { + checkParsing( + "INSERT INTO TABLE_INSERT(col1,col2,col3,col4, col5) VALUES (9, ?, 5, ?, 8) ON DUPLICATE KEY UPDATE col2=col2+10", + 2, + 2); + } + + @Test + public void testNamedParam() throws Exception { + checkParsing( + "SELECT * FROM TABLE WHERE 1 = :firstParam AND 3 = ':para' and 2 = :secondParam", 2, 2); + } + + @Test + public void stringEscapeParsing2() throws Exception { + checkParsing("SELECT '\\\\test' /*test* #/ ;`*/", 0, 0); + } + + @Test + public void stringEscapeParsing3() throws Exception { + checkParsing("DO '\\\"', \"\\'\"", 0, 0); + } + + @Test + public void testComment() throws Exception { + checkParsing( + "/* insert Select INSERT INTO tt VALUES (?,?,?,?) insert update delete select returning */" + + " INSERT into " + + "/* insert Select INSERT INTO tt VALUES (?,?,?,?) */" + + " tt VALUES " + + "/* insert Select INSERT INTO tt VALUES (?,?,?,?) */" + + " (?) " + + "/* insert Select INSERT INTO tt VALUES (?,?,?,?) */", + 1, + 1); + } + + @Test + public void testRewritableWithConstantParameterAndParamAfterValue() throws Exception { + checkParsing( + "INSERT INTO TABLE(col1,col2,col3,col4, col5) VALUES (9, ?, 5, ?, 8) ON DUPLICATE KEY UPDATE col2=?", + 3, + 3); + } + + @Test + public void testRewritableMultipleInserts() throws Exception { + checkParsing("INSERT INTO TABLE(col1,col2) VALUES (?, ?), (?, ?)", 4, 4); + } + + @Test + public void testCall() throws Exception { + checkParsing("CALL dsdssd(?,?)", 2, 2); + } + + @Test + public void testUpdate() throws Exception { + checkParsing("UPDATE MultiTestt4 SET test = ? WHERE test = ?", 2, 2); + } + + @Test + public void testUpdate2() throws Exception { + checkParsing("UPDATE UpdateMultiTestt4UPDATE() SET test = ? WHERE test = ?", 2, 2); + } + + @Test + public void testDelete() throws Exception { + checkParsing("DELETE FROM MultiTestt4 WHERE test = ?", 1, 1); + } + + @Test + public void testDelete2() throws Exception { + checkParsing("DELETE FROM DELETEMultiTestt4DELETE WHERE test = ?", 1, 1); + } + + @Test + public void testInsertSelect() throws Exception { + checkParsing( + "insert into test_insert_select ( field1) (select TMP.field1 from " + + "(select CAST(? as binary) `field1` from dual) TMP)", + 1, + 1); + } + + @Test + public void testWithoutParameter() throws Exception { + checkParsing("SELECT testFunction()", 0, 0); + } + + @Test + public void testWithoutParameterAndParenthesis() throws Exception { + checkParsing("SELECT 1", 0, 0); + } + + @Test + public void testWithoutParameterAndValues() throws Exception { + checkParsing("INSERT INTO tt VALUES (1)", 0, 0); + } + + @Test + public void testSemiColon() throws Exception { + checkParsing("INSERT INTO tt (tt) VALUES (?); INSERT INTO tt (tt) VALUES ('multiple')", 1, 1); + } + + @Test + public void testSemicolonRewritableIfAtEnd() throws Exception { + checkParsing("INSERT INTO table (column1) VALUES (?); ", 1, 1); + } + + @Test + public void testSemicolonNotRewritableIfNotAtEnd() throws Exception { + checkParsing("INSERT INTO table (column1) VALUES (?); SELECT 1", 1, 1); + } + + @Test + public void testError() throws Exception { + checkParsing("INSERT INTO tt (tt) VALUES (?); INSERT INTO tt (tt) VALUES ('multiple')", 1, 1); + } + + @Test + public void testLineComment() throws Exception { + checkParsing("INSERT INTO tt (tt) VALUES (?) --fin", 1, 1); + } + + @Test + public void testEscapeInString() throws Exception { + checkParsing("INSERT INTO tt (tt) VALUES (?, '\\'?', \"\\\"?\") --fin", 1, 2); + } + + @Test + public void testEol() throws Exception { + checkParsing("INSERT INTO tt (tt) VALUES (?, //test \n ?)", 2, 2); + } + + @Test + public void testLineCommentFinished() throws Exception { + checkParsing("INSERT INTO tt (tt) VALUES --fin\n (?)", 1, 1); + } + + @Test + public void testSelect1() throws Exception { + checkParsing("SELECT 1", 0, 0); + } + + @Test + public void testLastInsertId() throws Exception { + checkParsing("INSERT INTO tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?)", 1, 1); + } + + @Test + public void testReturning() throws Exception { + checkParsing( + "INSERT INTO tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?) # test \n RETURNING ID", 1, 1); + checkParsing("INSERT INTO tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?, _RETURNING)", 1, 1); + checkParsing("INSERT INTO tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?, _RETURNING)", 1, 1); + checkParsing("DELETE tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?) RETURNING ID", 1, 1); + checkParsing("DELETE tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?, _RETURNING)", 1, 1); + checkParsing("DELETE tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?, _RETURNING)", 1, 1); + checkParsing("UPDATE tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?) RETURNING ID", 1, 1); + checkParsing("UPDATE tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?, _RETURNING)", 1, 1); + checkParsing("UPDATE tt (tt, tt2) VALUES (LAST_INSERT_ID(), ?, _RETURNING)", 1, 1); + } + + @Test + public void testValuesForPartition() throws Exception { + checkParsing( + "ALTER table test_partitioning PARTITION BY RANGE COLUMNS( created_at ) " + + "(PARTITION test_p201605 VALUES LESS THAN ('2016-06-01', '\"', \"'\"))", + 0, + 0); + } + + @Test + public void testEolskip() throws Exception { + checkParsing("CREATE TABLE tt \n # test \n(ID INT)", 0, 0); + } + + private void checkParsing(String sql, int paramNumber, int paramNumberBackSlash) { + ServerNamedParamParser res = ServerNamedParamParser.parameterParts(sql, false); + Assertions.assertEquals(paramNumber, res.getParamCount()); + res = ServerNamedParamParser.parameterParts(sql, true); + Assertions.assertEquals(paramNumberBackSlash, res.getParamCount()); + } } diff --git a/src/test/resources/conf.properties b/src/test/resources/conf.properties index 68f3b352..31d3e887 100644 --- a/src/test/resources/conf.properties +++ b/src/test/resources/conf.properties @@ -3,4 +3,4 @@ DB_PORT=3306 DB_DATABASE=testr2 DB_USER=root DB_PASSWORD= -DB_OTHER= \ No newline at end of file +DB_OTHER= diff --git a/src/test/resources/logback-test.xml b/src/test/resources/logback-test.xml index 1e2ab522..b15256ac 100644 --- a/src/test/resources/logback-test.xml +++ b/src/test/resources/logback-test.xml @@ -1,7 +1,7 @@