diff --git a/quackpipe.go b/quackpipe.go index b4f6b10..9492f24 100644 --- a/quackpipe.go +++ b/quackpipe.go @@ -1,361 +1,336 @@ package main import ( - "bufio" - "database/sql" - _ "embed" - "encoding/json" - "flag" - "fmt" - "io/ioutil" - "log" - "net/http" - "os" - "regexp" - "strings" - "time" - - _ "github.com/marcboeker/go-duckdb" // load duckdb driver + "context" + "crypto/sha256" + "database/sql" + "encoding/base64" + "encoding/csv" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "time" + + _ "github.com/mattn/go-duckdb" ) -//go:embed play.html -var staticPlay string +type Session struct { + DB *sql.DB +} + +var sessionCache sync.Map -//go:embed aliases.sql -var staticAliases string +func main() { + http.HandleFunc("/query", basicAuth(queryHandler)) -// params for Flags -type CommandLineFlags struct { - Host *string `json:"host"` - Port *string `json:"port"` - Stdin *bool `json:"stdin"` - Format *string `json:"format"` - Params *string `json:"params"` + // Existing endpoints for backwards compatibility + http.HandleFunc("/start_session", startSessionHandler) + http.HandleFunc("/close_session", closeSessionHandler) + + fmt.Println("Starting server on :8080") + http.ListenAndServe(":8080", nil) } -var appFlags CommandLineFlags +func basicAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" { + http.Error(w, "authorization required", http.StatusUnauthorized) + return + } + + parts := strings.SplitN(auth, " ", 2) + if len(parts) != 2 || parts[0] != "Basic" { + http.Error(w, "invalid authorization header", http.StatusUnauthorized) + return + } + + payload, _ := base64.StdEncoding.DecodeString(parts[1]) + authStr := string(payload) + sessionID := hashCredentials(authStr) + + session, err := getSession(sessionID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + r.Header.Set("X-Session-ID", sessionID) + r = r.WithContext(context.WithValue(r.Context(), "session", session)) + + next.ServeHTTP(w, r) + } +} -var db *sql.DB +func hashCredentials(auth string) string { + hash := sha256.Sum256([]byte(auth)) + return hex.EncodeToString(hash[:]) +} -func check(args ...interface{}) { - err := args[len(args)-1] - if err != nil { - panic(err) - } +func getSession(sessionID string) (*Session, error) { + if session, ok := sessionCache.Load(sessionID); ok { + return session.(*Session), nil + } + + db, err := sql.Open("duckdb", "") + if err != nil { + return nil, err + } + + session := &Session{DB: db} + sessionCache.Store(sessionID, session) + + return session, nil } -func quack(query string, stdin bool, format string, params string) (string, error) { - var err error +func queryHandler(w http.ResponseWriter, r *http.Request) { + sessionID := r.Header.Get("X-Session-ID") + query := r.URL.Query().Get("query") + format := r.URL.Query().Get("format") - db, err = sql.Open("duckdb", params) - if err != nil { - log.Fatal(err) - } - defer db.Close() + if format == "" { + format = "json" + } - if !stdin { - check(db.Exec("LOAD httpfs; LOAD json; LOAD parquet;")) - } - - if staticAliases != "" { - check(db.Exec(staticAliases)) - } - - startTime := time.Now() - rows, err := db.Query(query) - if err != nil { - return "", err - } - elapsedTime := time.Since(startTime) - - switch format { - case "JSONCompact", "JSON": - return rowsToJSON(rows, elapsedTime) - case "CSVWithNames": - return rowsToCSV(rows, true) - case "TSVWithNames", "TabSeparatedWithNames": - return rowsToTSV(rows, true) - case "TSV", "TabSeparated": - return rowsToTSV(rows, false) - default: - return rowsToTSV(rows, false) - } + session := r.Context().Value("session").(*Session) + result, err := quackWithDB(session.DB, query, format) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Write([]byte(result)) } -// initFlags initializes the command line flags -func initFlags() { - appFlags.Host = flag.String("host", "0.0.0.0", "API host. Default 0.0.0.0") - appFlags.Port = flag.String("port", "8123", "API port. Default 8123") - appFlags.Format = flag.String("format", "JSONCompact", "API port. Default JSONCompact") - appFlags.Params = flag.String("params", "", "DuckDB optional parameters. Default to none.") - appFlags.Stdin = flag.Bool("stdin", false, "STDIN query. Default false") - flag.Parse() +func quackWithDB(db *sql.DB, query string, format string) (string, error) { + startTime := time.Now() + rows, err := db.Query(query) + if err != nil { + return "", err + } + elapsedTime := time.Since(startTime) + + switch format { + case "json": + return rowsToJSON(rows, elapsedTime) + case "csv": + return rowsToCSV(rows, elapsedTime) + case "tsv": + return rowsToTSV(rows, elapsedTime) + default: + return "", fmt.Errorf("unsupported format: %s", format) + } } -// extractAndRemoveFormat extracts the FORMAT clause from the query and returns the query without the FORMAT clause -func extractAndRemoveFormat(input string) (string, string) { - re := regexp.MustCompile(`(?i)\bFORMAT\s+(\w+)\b`) - match := re.FindStringSubmatch(input) - if len(match) != 2 { - return input, "" - } - format := match[1] - return re.ReplaceAllString(input, ""), format +func rowsToJSON(rows *sql.Rows, elapsedTime time.Duration) (string, error) { + columns, err := rows.Columns() + if err != nil { + return "", err + } + + count := len(columns) + tableData := make([]map[string]interface{}, 0) + values := make([]interface{}, count) + valuePtrs := make([]interface{}, count) + + for rows.Next() { + for i := range columns { + valuePtrs[i] = &values[i] + } + rows.Scan(valuePtrs...) + + entry := make(map[string]interface{}) + for i, col := range columns { + var v interface{} + val := values[i] + b, ok := val.([]byte) + if ok { + v = string(b) + } else { + v = val + } + entry[col] = v + } + + tableData = append(tableData, entry) + } + + result := map[string]interface{}{ + "data": tableData, + "elapsed_time": elapsedTime.String(), + } + + jsonData, err := json.Marshal(result) + if err != nil { + return "", err + } + + return string(jsonData), nil } -// Metadata is the metadata for a column -type Metadata struct { - Name string `json:"name"` - Type string `json:"type"` +func rowsToCSV(rows *sql.Rows, elapsedTime time.Duration) (string, error) { + var sb strings.Builder + writer := csv.NewWriter(&sb) + + columns, err := rows.Columns() + if err != nil { + return "", err + } + + writer.Write(columns) + + count := len(columns) + values := make([]interface{}, count) + valuePtrs := make([]interface{}, count) + + for rows.Next() { + for i := range columns { + valuePtrs[i] = &values[i] + } + rows.Scan(valuePtrs...) + + row := make([]string, count) + for i, col := range columns { + var v interface{} + val := values[i] + b, ok := val.([]byte) + if ok { + v = string(b) + } else { + v = val + } + row[i] = fmt.Sprintf("%v", v) + } + + writer.Write(row) + } + + writer.Flush() + if err := writer.Error(); err != nil { + return "", err + } + + sb.WriteString(fmt.Sprintf("\nElapsed Time: %s", elapsedTime)) + return sb.String(), nil } -// Statistics is the statistics for a query -type Statistics struct { - Elapsed float64 `json:"elapsed"` - RowsRead int `json:"rows_read"` - BytesRead int `json:"bytes_read"` +func rowsToTSV(rows *sql.Rows, elapsedTime time.Duration) (string, error) { + var sb strings.Builder + writer := csv.NewWriter(&sb) + writer.Comma = '\t' + + columns, err := rows.Columns() + if err != nil { + return "", err + } + + writer.Write(columns) + + count := len(columns) + values := make([]interface{}, count) + valuePtrs := make([]interface{}, count) + + for rows.Next() { + for i := range columns { + valuePtrs[i] = &values[i] + } + rows.Scan(valuePtrs...) + + row := make([]string, count) + for i, col := range columns { + var v interface{} + val := values[i] + b, ok := val.([]byte) + if ok { + v = string(b) + } else { + v = val + } + row[i] = fmt.Sprintf("%v", v) + } + + writer.Write(row) + } + + writer.Flush() + if err := writer.Error(); err != nil { + return "", err + } + + sb.WriteString(fmt.Sprintf("\nElapsed Time: %s", elapsedTime)) + return sb.String(), nil } -// OutputJSON is the JSON output for a query -type OutputJSON struct { - Meta []Metadata `json:"meta"` - Data [][]interface{} `json:"data"` - Rows int `json:"rows"` - RowsBeforeLimitAtLeast int `json:"rows_before_limit_at_least"` - Statistics Statistics `json:"statistics"` +// Existing handlers for backwards compatibility +type SessionManager struct { + sessions map[string]*Session + mu sync.Mutex } -// rowsToJSON converts the rows to JSON string -func rowsToJSON(rows *sql.Rows, elapsedTime time.Duration) (string, error) { - defer rows.Close() - - // Get column names - columns, err := rows.Columns() - if err != nil { - return "", err - } - - // Create a slice to store maps of column names and their corresponding values - var results OutputJSON - results.Meta = make([]Metadata, len(columns)) - results.Data = make([][]interface{}, 0) - - for i, column := range columns { - results.Meta[i].Name = column - } - - for rows.Next() { - // Create a slice to hold pointers to the values of the columns - values := make([]interface{}, len(columns)) - for i := range columns { - values[i] = new(interface{}) - } - - // Scan the values from the row into the pointers - err := rows.Scan(values...) - if err != nil { - return "", err - } - - // Create a slice to hold the row data - rowData := make([]interface{}, len(columns)) - for i, value := range values { - // Convert the value to the appropriate Go type - switch v := (*(value.(*interface{}))).(type) { - case []byte: - rowData[i] = string(v) - default: - rowData[i] = v - } - } - results.Data = append(results.Data, rowData) - } - - err = rows.Err() - if err != nil { - return "", err - } - - results.Rows = len(results.Data) - results.RowsBeforeLimitAtLeast = len(results.Data) - - // Populate the statistics object with number of rows, bytes, and elapsed time - results.Statistics.Elapsed = elapsedTime.Seconds() - results.Statistics.RowsRead = results.Rows - // Note: bytes_read is an approximation, it's just the number of rows * number of columns - // results.Statistics.BytesRead = results.Rows * len(columns) * 8 // Assuming each value takes 8 bytes - jsonData, err := json.Marshal(results) - if err != nil { - return "", err - } - - return string(jsonData), nil +var manager = &SessionManager{sessions: make(map[string]*Session)} + +func startSessionHandler(w http.ResponseWriter, r *http.Request) { + params := r.URL.Query().Get("params") + session, err := manager.NewSession(params) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(map[string]string{"session_id": session.ID}) } -// rowsToTSV converts the rows to TSV string -func rowsToTSV(rows *sql.Rows, cols bool) (string, error) { - var result []string - columns, err := rows.Columns() - if err != nil { - return "", err - } - - if cols { - // Append column names as the first row - result = append(result, strings.Join(columns, "\t")) - } - - // Fetch rows and append their values as tab-delimited lines - values := make([]interface{}, len(columns)) - scanArgs := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - err := rows.Scan(scanArgs...) - if err != nil { - return "", err - } - - var lineParts []string - for _, v := range values { - lineParts = append(lineParts, fmt.Sprintf("%v", v)) - } - result = append(result, strings.Join(lineParts, "\t")) - } - - if err := rows.Err(); err != nil { - return "", err - } - - return strings.Join(result, "\n"), nil +func closeSessionHandler(w http.ResponseWriter, r *http.Request) { + sessionID := r.URL.Query().Get("session_id") + + err := manager.CloseSession(sessionID) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + + w.WriteHeader(http.StatusOK) } -// rowsToCSV converts the rows to CSV string -func rowsToCSV(rows *sql.Rows, cols bool) (string, error) { - var result []string - columns, err := rows.Columns() - if err != nil { - return "", err - } - - if cols { - // Append column names as the first row - result = append(result, strings.Join(columns, ",")) - } - - // Fetch rows and append their values as CSV rows - values := make([]interface{}, len(columns)) - scanArgs := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - err := rows.Scan(scanArgs...) - if err != nil { - return "", err - } - - var lineParts []string - for _, v := range values { - lineParts = append(lineParts, fmt.Sprintf("%v", v)) - } - result = append(result, strings.Join(lineParts, ",")) - } - - if err := rows.Err(); err != nil { - return "", err - } - - return strings.Join(result, "\n"), nil +func (sm *SessionManager) NewSession(params string) (*Session, error) { + sm.mu.Lock() + defer sm.mu.Unlock() + + db, err := sql.Open("duckdb", params) + if err != nil { + return nil, err + } + + session := &Session{ + ID: generateSessionID(), + DB: db, + } + sm.sessions[session.ID] = session + return session, nil } -func main() { - initFlags() - default_format := *appFlags.Format - default_params := *appFlags.Params - if *appFlags.Stdin { - scanner := bufio.NewScanner((os.Stdin)) - query := "" - for scanner.Scan() { - query = query + "\n" + scanner.Text() - } - if err := scanner.Err(); err != nil { - fmt.Fprintln(os.Stderr, "reading standard input:", err) - } - cleanquery, format := extractAndRemoveFormat(query) - if len(format) > 0 { - query = cleanquery - default_format = format - } - result, err := quack(query, true, default_format, default_params) - if err != nil { - fmt.Println(err) - os.Exit(1) - } else { - fmt.Println(result) - } - } else { - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - var bodyBytes []byte - var query string - var err error - - // handle query parameter - if r.URL.Query().Get("query") != "" { - // query = r.FormValue("query") - query = r.URL.Query().Get("query") - } else if r.Body != nil { - bodyBytes, err = ioutil.ReadAll(r.Body) - if err != nil { - fmt.Printf("Body reading error: %v", err) - return - } - defer r.Body.Close() - query = string(bodyBytes) - } - - switch r.Header.Get("Accept") { - case "application/json": - w.Header().Set("Content-Type", "application/json; charset=utf-8") - case "application/xml": - w.Header().Set("Content-Type", "application/xml; charset=utf-8") - case "text/css": - w.Header().Set("Content-Type", "text/css; charset=utf-8") - default: - w.Header().Set("Content-Type", "text/html; charset=utf-8") - } - - // format handling - if r.URL.Query().Get("default_format") != "" { - default_format = r.URL.Query().Get("default_format") - } - // param handling - if r.URL.Query().Get("default_params") != "" { - default_params = r.URL.Query().Get("default_params") - } - // extract FORMAT from query and override the current `default_format` - cleanquery, format := extractAndRemoveFormat(query) - if len(format) > 0 { - query = cleanquery - default_format = format - } - - if len(query) == 0 { - _, _ = w.Write([]byte(staticPlay)) - } else { - result, err := quack(query, false, default_format, default_params) - if err != nil { - _, _ = w.Write([]byte(err.Error())) - } else { - _, _ = w.Write([]byte(result)) - } - } - }) - - fmt.Printf("QuackPipe API Running: %s:%s\n", *appFlags.Host, *appFlags.Port) - if err := http.ListenAndServe(*appFlags.Host+":"+*appFlags.Port, nil); err != nil { - panic(err) - } - } +func (sm *SessionManager) CloseSession(id string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, exists := sm.sessions[id] + if !exists { + return fmt.Errorf("session not found") + } + + err := session.DB.Close() + if err != nil { + return err + } + + delete(sm.sessions, id) + return nil } + +func generateSessionID() string { + return fmt.Sprintf("%d", time.Now().UnixNano()) +} +