Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for custom middleware (eg. tracing) #14

Merged
merged 12 commits into from
Feb 21, 2024
5 changes: 1 addition & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[workspace]
resolver = "2"
members = [
"jwt",
"snowflake-api"
]
members = ["jwt", "snowflake-api", "snowflake-api/examples/tracing"]
2 changes: 2 additions & 0 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ version = "0.6.0"
[features]
cert-auth = ["dep:snowflake-jwt"]


[dependencies]
arrow = "50"
async-trait = "0.1"
Expand All @@ -37,6 +38,7 @@ thiserror = "1"
url = "2"
uuid = { version = "1", features = ["v4"] }


[dev-dependencies]
anyhow = "1"
arrow = { version = "50", features = ["prettyprint"] }
Expand Down
24 changes: 24 additions & 0 deletions snowflake-api/examples/tracing/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "snowflake-rust-tracing"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.79"
arrow = { version = "50.0.0", features = ["prettyprint"] }
dotenv = "0.15.0"
snowflake-api = { path = "../../../snowflake-api" }


tokio = { version = "1.35.1", features = ["full"] }
tracing = "0.1.40"
tracing-subscriber = "0.3"
# use the same version of opentelemetry as the one used by snowflake-api
tracing-opentelemetry = "0.22"
opentelemetry-otlp = "*"
opentelemetry = "0.21"
opentelemetry_sdk = { version = "0.21", features = ["rt-tokio"] }
reqwest-tracing = { version = "0.4", features = ["opentelemetry_0_21"] }
reqwest-middleware = { version = "*" }
78 changes: 78 additions & 0 deletions snowflake-api/examples/tracing/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use anyhow::Result;
use arrow::util::pretty::pretty_format_batches;
use opentelemetry::global;
use opentelemetry_otlp::WithExportConfig;

use snowflake_api::connection::Connection;
use snowflake_api::{AuthArgs, AuthType, PasswordArgs, QueryResult, SnowflakeApiBuilder};
use tracing_subscriber::layer::SubscriberExt;

use reqwest_middleware::Extension;
use reqwest_tracing::{OtelName, SpanBackendWithUrl};

#[tokio::main]
async fn main() -> Result<()> {
std::env::set_var("OTEL_SERVICE_NAME", "snowflake-rust-client-demo");

let exporter = opentelemetry_otlp::new_exporter()
.tonic()
.with_endpoint("http://localhost:4319");

let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(exporter)
.install_batch(opentelemetry_sdk::runtime::Tokio)?;

let telemetry = tracing_opentelemetry::layer().with_tracer(tracer.clone());
let subscriber = tracing_subscriber::Registry::default().with(telemetry);
tracing::subscriber::set_global_default(subscriber)?;

dotenv::dotenv().ok();

let auth_args = AuthArgs {
account_identifier: std::env::var("SNOWFLAKE_ACCOUNT").expect("SNOWFLAKE_ACCOUNT not set"),
warehouse: std::env::var("SNOWLFLAKE_WAREHOUSE").ok(),
database: std::env::var("SNOWFLAKE_DATABASE").ok(),
schema: std::env::var("SNOWFLAKE_SCHEMA").ok(),
username: std::env::var("SNOWFLAKE_USER").expect("SNOWFLAKE_USER not set"),
role: std::env::var("SNOWFLAKE_ROLE").ok(),
auth_type: AuthType::Password(PasswordArgs {
password: std::env::var("SNOWFLAKE_PASSWORD").expect("SNOWFLAKE_PASSWORD not set"),
}),
};

let mut client = Connection::default_client_builder()?;
client = client
.with_init(Extension(OtelName(std::borrow::Cow::Borrowed(
"snowflake-api",
))))
.with(reqwest_tracing::TracingMiddleware::<SpanBackendWithUrl>::new());

let builder = SnowflakeApiBuilder::new(auth_args).with_client(client.build());
let api = builder.build()?;

run_in_span(&api).await?;

global::shutdown_tracer_provider();

Ok(())
}

#[tracing::instrument(name = "snowflake_api", skip(api))]
async fn run_in_span(api: &snowflake_api::SnowflakeApi) -> anyhow::Result<()> {
let res = api.exec("select 'hello from snowflake' as col1;").await?;

match res {
QueryResult::Arrow(a) => {
println!("{}", pretty_format_batches(&a).unwrap());
}
QueryResult::Json(j) => {
println!("{}", j);
}
QueryResult::Empty => {
println!("Query finished successfully")
}
}

Ok(())
}
28 changes: 22 additions & 6 deletions snowflake-api/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,28 @@ pub struct Connection {

impl Connection {
pub fn new() -> Result<Self, ConnectionError> {
let client = Self::default_client_builder()?;

Ok(Self::new_with_middware(client.build()))
}

/// Allow a user to provide their own middleware
///
/// Users can provide their own middleware to the connection like this:
/// ```rust
/// use snowflake_api::connection::Connection;
/// let mut client = Connection::default_client_builder();
/// // modify the client builder here
/// let connection = Connection::new_with_middware(client.unwrap().build());
/// ```
/// This is not intended to be called directly, but is used by `SnowflakeApiBuilder::with_client`
pub fn new_with_middware(client: ClientWithMiddleware) -> Self {
Self { client }
}

pub fn default_client_builder() -> Result<reqwest_middleware::ClientBuilder, ConnectionError> {
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);

// use builder to fail safely, unlike client new
let client = reqwest::ClientBuilder::new()
.user_agent("Rust/0.0.1")
.gzip(true)
Expand All @@ -90,11 +109,8 @@ impl Connection {

let client = client.build()?;

let client = reqwest_middleware::ClientBuilder::new(client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();

Ok(Self { client })
Ok(reqwest_middleware::ClientBuilder::new(client)
.with(RetryTransientMiddleware::new_with_policy(retry_policy)))
}

/// Perform request of given query type with extra body or parameters
Expand Down
81 changes: 80 additions & 1 deletion snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use object_store::aws::AmazonS3Builder;
use object_store::local::LocalFileSystem;
use object_store::ObjectStore;
use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use crate::connection::{Connection, ConnectionError};
Expand All @@ -36,7 +37,7 @@ use crate::connection::QueryType;
use crate::requests::ExecRequest;
use crate::responses::{AwsPutGetStageInfo, PutGetExecResponse, PutGetStageInfo};

mod connection;
pub mod connection;
mod requests;
mod responses;
mod session;
Expand Down Expand Up @@ -95,6 +96,84 @@ pub enum QueryResult {
Empty,
}

pub struct AuthArgs {
pub account_identifier: String,
pub warehouse: Option<String>,
pub database: Option<String>,
pub schema: Option<String>,
pub username: String,
pub role: Option<String>,
pub auth_type: AuthType,
}

pub enum AuthType {
Password(PasswordArgs),
Certificate(CertificateArgs),
}

pub struct PasswordArgs {
pub password: String,
}

pub struct CertificateArgs {
pub private_key_pem: String,
}

#[must_use]
pub struct SnowflakeApiBuilder {
pub auth: AuthArgs,
client: Option<ClientWithMiddleware>,
}

impl SnowflakeApiBuilder {
pub fn new(auth: AuthArgs) -> Self {
Self { auth, client: None }
}

pub fn with_client(mut self, client: ClientWithMiddleware) -> Self {
self.client = Some(client);
self
}

pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
let connection = match self.client {
Some(client) => Arc::new(Connection::new_with_middware(client)),
None => Arc::new(Connection::new()?),
};

let session = match self.auth.auth_type {
AuthType::Password(args) => Session::password_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
&self.auth.username,
self.auth.role.as_deref(),
&args.password,
),
AuthType::Certificate(args) => Session::cert_auth(
Arc::clone(&connection),
&self.auth.account_identifier,
self.auth.warehouse.as_deref(),
self.auth.database.as_deref(),
self.auth.schema.as_deref(),
&self.auth.username,
self.auth.role.as_deref(),
&args.private_key_pem,
),
};

let account_identifier = self.auth.account_identifier.to_uppercase();

Ok(SnowflakeApi {
connection: Arc::clone(&connection),
session,
account_identifier,
})
}
}

/// Snowflake API, keeps connection pool and manages session for you
pub struct SnowflakeApi {
connection: Arc<Connection>,
Expand Down
Loading