Skip to content

Commit

Permalink
Merge branch 'main' into uuid-regeneration-on-retry
Browse files Browse the repository at this point in the history
  • Loading branch information
wseaton authored Feb 27, 2024
2 parents 05e8a06 + 56dc842 commit 74a60a3
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 9 deletions.
10 changes: 8 additions & 2 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ version = "0.7.0"

[features]
default = ["cert-auth"]
all = ["cert-auth", "polars"]
cert-auth = ["dep:snowflake-jwt"]

# support for conversion of arrow and json payloads to dataframes
polars = ["dep:polars-core", "dep:polars-io"]

[dependencies]
arrow = "50"
Expand All @@ -32,14 +34,18 @@ reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
reqwest-middleware = "0.2"
task-local-extensions = "0.1"
reqwest-retry = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
snowflake-jwt = { version = "0.3.0", optional = true }
thiserror = "1"
url = "2"
uuid = { version = "1", features = ["v4"] }
task-local-extensions = "0.1"

polars-io = { version = ">=0.32", features = ["json", "ipc_streaming"], optional = true}
polars-core = { version = ">=0.32", optional = true}


[dev-dependencies]
anyhow = "1"
Expand Down
2 changes: 1 addition & 1 deletion snowflake-api/examples/tracing/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ tracing = "0.1.40"
tracing-subscriber = "0.3"
# use the same version of opentelemetry as the one used by snowflake-api
tracing-opentelemetry = "0.22"
opentelemetry-otlp = "*"
opentelemetry-otlp = "0.14"
opentelemetry = "0.21"
opentelemetry_sdk = { version = "0.21", features = ["rt-tokio"] }
reqwest-tracing = { version = "0.4", features = ["opentelemetry_0_21"] }
Expand Down
58 changes: 52 additions & 6 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ clippy::future_not_send, // This one seems like something we should eventually f
clippy::missing_panics_doc
)]

use std::fmt::{Display, Formatter};
use std::io;
use std::path::Path;
use std::sync::Arc;
Expand All @@ -37,13 +38,18 @@ use session::{AuthError, Session};

use crate::connection::QueryType;
use crate::requests::ExecRequest;
use crate::responses::{AwsPutGetStageInfo, PutGetExecResponse, PutGetStageInfo};
use crate::responses::{
AwsPutGetStageInfo, ExecResponseRowType, PutGetExecResponse, PutGetStageInfo, SnowflakeType,
};

pub mod connection;
mod middleware;

#[cfg(feature = "polars")]
mod polars;
mod requests;
mod responses;
mod session;
mod middleware;

#[derive(Error, Debug)]
pub enum SnowflakeApiError {
Expand Down Expand Up @@ -90,12 +96,49 @@ pub enum SnowflakeApiError {
UnexpectedResponse,
}

/// Even if Arrow is specified as a return type non-select queries
/// will return Json array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
pub struct JsonResult {
    // todo: can it _only_ be a json array of arrays or something else too?
    /// Raw JSON rowset exactly as it appeared in the REST response.
    pub value: serde_json::Value,
    /// Field ordering matches the array ordering
    pub schema: Vec<FieldSchema>,
}

impl Display for JsonResult {
    /// Renders the raw JSON value; the schema is not included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Delegate directly to serde_json::Value's Display implementation.
        Display::fmt(&self.value, f)
    }
}

/// Based on the [`ExecResponseRowType`]
pub struct FieldSchema {
    /// Column name as reported by the server.
    pub name: String,
    // todo: is it a good idea to expose internal response struct to the user?
    /// Server-reported logical type of the column.
    pub type_: SnowflakeType,
    /// Numeric scale when reported — NOTE(review): exact semantics not visible here, confirm against the Snowflake response docs.
    pub scale: Option<i64>,
    /// Numeric precision when reported — NOTE(review): exact semantics not visible here, confirm against the Snowflake response docs.
    pub precision: Option<i64>,
    /// Whether the column may contain NULL values.
    pub nullable: bool,
}

impl From<ExecResponseRowType> for FieldSchema {
fn from(value: ExecResponseRowType) -> Self {
FieldSchema {
name: value.name,
type_: value.type_,
scale: value.scale,
precision: value.precision,
nullable: value.nullable,
}
}
}

/// Container for query result.
/// Arrow is returned by-default for all SELECT statements,
/// unless there is session configuration issue or it's a different statement type.
pub enum QueryResult {
    /// Decoded Arrow record batches.
    Arrow(Vec<RecordBatch>),
    /// JSON rowset together with the schema describing its columns.
    Json(JsonResult),
    /// The statement produced no rows.
    Empty,
}

Expand All @@ -107,7 +150,7 @@ pub enum RawQueryResult {
Bytes(Vec<Bytes>),
/// Json payload is deserialized,
/// as it's already a part of REST response
Json(serde_json::Value),
Json(JsonResult),
Empty,
}

Expand Down Expand Up @@ -423,12 +466,15 @@ impl SnowflakeApi {
if resp.data.returned == 0 {
log::debug!("Got response with 0 rows");
Ok(RawQueryResult::Empty)
} else if let Some(json) = resp.data.rowset {
} else if let Some(value) = resp.data.rowset {
log::debug!("Got JSON response");
// NOTE: json response could be chunked too. however, go clients should receive arrow by-default,
// unless user sets session variable to return json. This case was added for debugging and status
// information being passed through those fields.
Ok(RawQueryResult::Json(json))
Ok(RawQueryResult::Json(JsonResult {
value,
schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
}))
} else if let Some(base64) = resp.data.rowset_base64 {
// fixme: is it possible to give streaming interface?
let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
Expand Down
90 changes: 90 additions & 0 deletions snowflake-api/src/polars.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use std::convert::TryFrom;

use crate::{JsonResult, RawQueryResult};
use bytes::{Buf, Bytes};
use polars_core::frame::DataFrame;
use polars_io::ipc::IpcStreamReader;
use polars_io::json::{JsonFormat, JsonReader};
use polars_io::SerReader;
use serde::de::Error;
use serde_json::{Map, Value};
use thiserror::Error;

/// Errors produced while converting a query result into a polars `DataFrame`.
#[derive(Error, Debug)]
pub enum PolarsCastError {
    /// JSON (de)serialization failed, or the rowset did not have the expected
    /// array-of-arrays shape.
    #[error(transparent)]
    SerdeError(#[from] serde_json::Error),

    /// polars failed while reading or stacking the data.
    #[error(transparent)]
    PolarsError(#[from] polars_core::error::PolarsError),
}

impl RawQueryResult {
    /// Convert this raw query result into a polars `DataFrame`.
    ///
    /// # Errors
    /// Returns [`PolarsCastError`] when the payload cannot be parsed by polars
    /// or the JSON rowset has an unexpected shape.
    pub fn to_polars(self) -> Result<DataFrame, PolarsCastError> {
        match self {
            Self::Bytes(chunks) => dataframe_from_bytes(chunks),
            Self::Json(json) => dataframe_from_json(&json),
            Self::Empty => Ok(DataFrame::empty()),
        }
    }
}

/// Build a `DataFrame` from a JSON rowset by first reshaping it into the
/// array-of-objects layout that the polars JSON reader expects.
fn dataframe_from_json(json_result: &JsonResult) -> Result<DataFrame, PolarsCastError> {
    let rows = arrays_to_objects(json_result)?;
    // fixme: serializing json again, is it possible to keep bytes? or implement casting?
    let serialized = serde_json::to_string(&rows)?;
    let cursor = std::io::Cursor::new(serialized.as_bytes());
    let frame = JsonReader::new(cursor)
        .with_json_format(JsonFormat::Json)
        .infer_schema_len(Some(5))
        .finish()?;
    Ok(frame)
}

/// This is required because the polars json reader expects an array of objects, and
/// the snowflake json response is an array of arrays (without real column names).
///
/// This is apparent if you run a system query (not a select) like `SHOW DATABASES;`.
fn arrays_to_objects(json_result: &JsonResult) -> Result<Value, PolarsCastError> {
let arrays: &Vec<Value> = json_result
.value
.as_array()
.ok_or(serde_json::Error::custom("Input must be array an array"))?;
let names: Vec<String> = json_result.schema.iter().map(|s| s.name.clone()).collect();

let objects: Result<Vec<Value>, PolarsCastError> = arrays
.iter()
.map(|array| {
array
.as_array()
.ok_or(serde_json::Error::custom("Input must be array of array"))
.map(|array| {
// fixme: lots of copying
let map: Map<String, Value> =
names.clone().into_iter().zip(array.clone()).collect();
Value::Object(map)
})
.map_err(PolarsCastError::SerdeError)
})
.collect();

objects.map(Value::Array)
}

/// Decode a sequence of Arrow IPC stream chunks and stack them vertically
/// into one `DataFrame`.
fn dataframe_from_bytes(bytes: Vec<Bytes>) -> Result<DataFrame, PolarsCastError> {
    let mut combined = DataFrame::empty();
    for chunk in bytes {
        let partial = IpcStreamReader::new(chunk.reader()).finish()?;
        combined.vstack_mut(&partial)?;
    }
    // Consolidate the stacked chunks after all appends.
    combined.align_chunks();
    Ok(combined)
}

impl TryFrom<RawQueryResult> for DataFrame {
type Error = PolarsCastError;

fn try_from(value: RawQueryResult) -> Result<Self, Self::Error> {
value.to_polars()
}
}

0 comments on commit 74a60a3

Please sign in to comment.