Commit 1dbd78b (initial commit, 0 parents)
Showing 312 changed files with 1,307,342 additions and 0 deletions.
@@ -0,0 +1,51 @@
# ide settings
.idea
.vscode
.codecc

# mac temp file
.DS_Store

# env
venv
.venv
.noxenv

# build
*.py[cod]
*.egg-info
*.egg
.eggs
eggs
build
dist
__pycache__

# logs
logs
*.log
*.log*
*.out
nohup*

# pytest
.pytest_cache
# pytest report
htmlcov

# coverage
*.coverage
*.coverage.*
pytest.cov.tmp

# data
datasets
dataset
vector_cache
third_party
data

# discard
discard

tests
@@ -0,0 +1,66 @@
# Decomposition for Enhancing Attention: Improving LLM-based Text-to-SQL through Workflow Paradigm

Based on the idea of **D**ecomposition for **E**nhancing **A**ttention, we propose the workflow paradigm method named DEA-SQL, with five major steps as shown in the figure below. Check out our [paper](https://arxiv.org/abs/2402.10671) for more information.

![The overall workflow of DEA-SQL](fig/figure1.png)
## Requirements
```
nltk==3.8.1
sqlparse==0.4.2
openai==0.28.0
langchain==0.0.281
backoff==2.2.1
termcolor==2.3.0
pandas==2.0.3
scikit-learn==1.3.0
timeout_decorator==0.5.0
sql_metadata==2.9.0
transformers==4.32.0
torch==1.12.1
```
## Environment
1. Install the requirements: `pip install -r requirements.txt`
2. Download the NLTK resources: `python nltk_downloader.py` (a sketch of what this script likely fetches follows).
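A minimal sketch of what `nltk_downloader.py` presumably downloads; the exact corpus list is an assumption, so treat the bundled script as authoritative:

```
# Hypothetical sketch of nltk_downloader.py; the actual corpora may differ.
import nltk

for resource in ("punkt", "stopwords"):
    nltk.download(resource)
```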
## Data Preparation
Download the dataset from the [Spider official website](https://yale-lily.github.io/spider), unzip it, and put it into the `data` folder, so that the databases end up at `data/spider/database`. A quick sanity check is sketched below.
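A minimal check, assuming the layout above, that the Spider files landed where the scripts expect them:

```
# Assumes the directory layout described in Data Preparation.
import os

assert os.path.isdir("data/spider/database"), "unzip Spider into data/spider/"
assert os.path.isfile("data/spider/tables.json"), "tables.json is missing"
```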
## Usage
Modify the OpenAI configuration in `common/static_config.py` and configure the relevant environment for the Azure OpenAI API; a hedged sketch of such a configuration follows.
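As a sketch only (endpoint, key, and API version are placeholders for your own Azure deployment), configuring `openai==0.28` for Azure typically looks like:

```
# Sketch: fill in your own Azure OpenAI endpoint, key, and API version.
import openai

openai.api_type = "azure"
openai.api_base = "https://<your-resource>.openai.azure.com/"  # placeholder endpoint
openai.api_version = "2023-05-15"  # example version; use your deployment's
openai.api_key = "<your-azure-openai-key>"  # placeholder key
```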
Several important parameters:
- **dataset**: The name of the dataset.
- **few_shot_mode**: The method for retrieving few-shot examples; one of [random, ques_sim, masked_ques_sim].
- **few_shot_data**: The data pool for retrieving few-shot examples; one of [train_merge_v1, train_merge_v5].
- **insert_value**: The number of example rows inserted into the database prompt.
- **embedding_base_model**: The base embedding model used in the few-shot retrieval step.
- **sc_filter_nums**: The number of self-consistency samples for the information filter layer.
## Quick Start

### Prediction on the Spider dev set
```
python main.py --save_file_name "dea-sql.txt" --dataset "spider" --mode "dev" --sample "False" --few_shot_mode "masked_ques_sim" --insert_value 3 --embedding_base_model "openai" --sc_filter_nums 3 --few_shot_data "train_merge_v5"
```

### Evaluation on the Spider dev set
Before the first evaluation, run `python nltk_downloader.py`.

```
python evaluation/test-suite-sql-eval/evaluation.py --gold "evaluation/gold_files/spider_dev_gold.sql" --pred "outputs/spider/dea-sql.txt" --db ./data/spider/database --print_file_name "outputs/spider/spider-dea-sql.txt" --table './data/spider/tables.json' --etype exec
```

## Citing DEA-SQL

```
@article{xie2024decomposition,
  title={Decomposition for Enhancing Attention: Improving LLM-based Text-to-SQL through Workflow Paradigm},
  author={Yuanzhen Xie and Xinzhou Jin and Tao Xie and MingXiong Lin and Liang Chen and Chenyun Yu and Lei Cheng and ChengXiang Zhuo and Bo Hu and Zang Li},
  journal={arXiv preprint arXiv:2402.10671},
  year={2024}
}
```
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
# Project : DEASQL
# File    : __init__.py
# Author  :
# Email   :
# Time    : 2023/12/26 15:19
@@ -0,0 +1,83 @@
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--question", type=str)
parser.add_argument("--openai_api_key", type=str)
parser.add_argument("--openai_group_id", type=str, default="")
parser.add_argument("--model", type=str)
parser.add_argument("--start_index", type=int, default=0)
parser.add_argument("--end_index", type=int, default=1000000)
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--mini_index_path", type=str, default="")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--n", type=int, default=5, help="size of the self-consistency set")
parser.add_argument("--db_dir", type=str, default="dataset/spider/database")

# our configuration
parser.add_argument("--key_config", type=str, default='api_key1', help="api_key1, api_key2, api_key3")
parser.add_argument("--key_version", type=str, default='gpt-4', help="gpt-35-turbo, gpt-4")
parser.add_argument("--save_file_name", type=str, default="test.txt", help="file the generated SQL is saved to")
parser.add_argument("--dataset", type=str, default="spider", help="spider or dev (bird)")
parser.add_argument("--sample", type=str, default="False", help="True or False")
parser.add_argument("--mode", type=str, default="debug", help="debug or dev")
parser.add_argument("--lang_mode", type=str, default="en", help="just en")
parser.add_argument("--filter_mode", type=str, default="complex", help="simple, complex, simple_v2, none")
parser.add_argument("--prompt_mode", type=str, default="v2", help="v1, v2, v3, v4")
parser.add_argument("--data_fold", type=str, default="1", help="1, 2, 3")
parser.add_argument('--train', default=False, action='store_true', help='train or dev')
parser.add_argument("--dataset_file", type=str, default="dev.json")
parser.add_argument("--test_id", type=int, default=46, help="1, 2, 3")

parser.add_argument("--re_run", default=False, action='store_true')
parser.add_argument('--re_run_idx', type=int, default=0)

parser.add_argument("--sc_nums_question_label", type=int, default=1, help="self-consistency samples for question labeling")
parser.add_argument("--sc_nums", type=int, default=1, help="self-consistency samples")
parser.add_argument("--sc_filter_nums", type=int, default=2, help="self-consistency samples for the information filter")
parser.add_argument("--sc_filter_temp", type=float, default=0, help="self-consistency temperature for the filter step")
parser.add_argument("--sc_ques_temp", type=float, default=0, help="self-consistency temperature for the question-type step")
parser.add_argument("--sc_sql_temp", type=float, default=0, help="self-consistency temperature for SQL generation")
parser.add_argument("--insert_value", type=int, default=0, help="number of rows inserted into the table schema prompt")
parser.add_argument('--step_name', type=str, default="all",
                    help='which step to execute; one of ["all", "ner_results", "filter_infos", "qc", "sql"]')
parser.add_argument('--step', default=False, action='store_true', help='whether to enable step-by-step debug mode')
parser.add_argument('--step1', default=False, action='store_true', help='rerun step 1')
parser.add_argument('--step2', default=False, action='store_true', help='skip step 1')
parser.add_argument('--step3', default=False, action='store_true', help='skip steps 1-2')
parser.add_argument('--step4', default=False, action='store_true', help='skip steps 1-3')
parser.add_argument('--step5', default=False, action='store_true', help='skip steps 1-4')
parser.add_argument('--step6', default=False, action='store_true', help='skip steps 1-5')
parser.add_argument('--save_version', type=int, default=1, help='the step version')
parser.add_argument('--n_shots', type=int, default=3, help='the number of few-shot examples')
parser.add_argument('--few_shot_data', type=str, default='train_merge_v1',
                    help='one of ["train_merge_v1", "train_merge_v5"]')
parser.add_argument('--few_shot_mode', type=str, default='ques_sim1',
                    help='one of ["random", "ques_sim", "masked_ques_sim", "query_sim"]')
parser.add_argument('--embedding_base_model', type=str, default='openai', help='one of ["transformer", "openai"]')
parser.add_argument('--schema_mode', type=str, default='CreateTableInsertRowFK',
                    help='one of ["CreateTableInsertRow", "CreateTableInsertRowFK"]')
# other schema modes: Table(Columns), Columns=[], Columns=[]+FK, CreateTable,
# CreateTableInsertRow, CreateTableSelectRow, CreateTableSelectCol
parser.add_argument('--fk_mode', type=str, default="newupperfk",
                    help="newupperfk keeps internal foreign keys; newupper does not")

parser.add_argument('--has_error_case', default=False, action='store_true',
                    help='whether error cases occurred during SQL generation')

### ablation experiment
parser.add_argument('--reduce_ql', default=False, action='store_true', help='remove the question-labeling step')

################## evaluate single sql ##########################
parser.add_argument('--db', dest='db', type=str, default="./data/spider/database",
                    help="the directory that contains all the databases and test suites")
parser.add_argument('--table', dest='table', type=str, default="./data/spider/tables.json",
                    help="the tables.json schema file")
parser.add_argument('--etype', dest='etype', type=str, default='exec',
                    help="evaluation type: exec for test-suite accuracy, match for the original "
                         "exact set match accuracy",
                    choices=('all', 'exec', 'match'))
parser.add_argument('--plug_value', default=False, action='store_true',
                    help='whether to plug the gold value into the predicted query; suitable if your model '
                         'does not predict values')
parser.add_argument('--keep_distinct', default=False, action='store_true',
                    help='whether to keep the DISTINCT keyword during evaluation; default is false')
parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
                    help='whether to print a progress bar of running test inputs for each datapoint')
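
# Hedged usage sketch (not in the original file): main.py presumably consumes
# this parser roughly as follows; the exact call site is an assumption.
#
#   args = parser.parse_args()
#   print(args.dataset, args.few_shot_mode, args.sc_filter_nums)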
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Project : DEASQL
# File    : common.py
# Author  :
# Email   :
# Time    : 2023/10/16 19:58
import json
import os
import pickle
import re

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


def run_chain(llm, prompt, param_dicts):
    """Run a single LLMChain call and return the raw model output."""
    chain = LLMChain(llm=llm, prompt=prompt)
    output = chain.run(**param_dicts)
    return output


def get_prompt(inputs, template):
    """Build a PromptTemplate from a list of input variable names and a template string."""
    prompt = PromptTemplate(
        input_variables=inputs,
        template=template
    )
    return prompt


def get_prompt_content(prompt, param_dicts):
    """Render the prompt template with concrete values, without calling the LLM."""
    formatted_text = prompt.format(**param_dicts)
    return formatted_text
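
# Hedged usage sketch (not in the original file): the helpers above compose like
# this, assuming `llm` is a LangChain chat model configured elsewhere:
#
#   prompt = get_prompt(["question"], "Translate to SQL: {question}")
#   print(get_prompt_content(prompt, {"question": "How many singers are there?"}))
#   sql = run_chain(llm, prompt, {"question": "How many singers are there?"})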


def save_obj(obj, name):
    """
    save pickle
    :param obj: The data that needs to be stored.
    :param name: file path
    """
    with open(name, "wb") as file:
        pickle.dump(obj, file)


def load_obj(name):
    """load pickle"""
    with open(name, "rb") as file:
        return pickle.load(file)


def get_lower_list(temp_list):
    """Lower-case every word in a list."""
    target_list = [word.lower() for word in temp_list]
    return target_list


def extract_references(sql):
    """Extract (table, column) pairs from REFERENCES clauses in a CREATE TABLE statement."""
    pattern = r"REFERENCES\s+(\w+)\s+\((\w+)\)"
    matches = re.findall(pattern, sql)
    return matches
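
# Example: extract_references("FOREIGN KEY (singer_id) REFERENCES singer (singer_id)")
# returns [("singer", "singer_id")].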


def extract_label(text):
    """Parse a 'Label: X, Y' line into a (question type, aggregate type) pair."""
    pattern = r"Label: ([\w-]+), ([\w-]+)"

    result = re.search(pattern, text)
    return_label = []

    if result:
        labels = result.groups()
        # question type: EASY / JOIN / NESTED / JOIN-NESTED
        if "NON-JOIN" in labels and "NON-NESTED" in labels:
            return_label.append("EASY")
        elif "JOIN" in labels and "NON-NESTED" in labels:
            return_label.append("JOIN")
        elif "NESTED" in labels and "NON-JOIN" in labels:
            return_label.append("NESTED")
        elif "NESTED" in labels and "JOIN" in labels:
            return_label.append("JOIN-NESTED")
        else:
            return_label.append("EASY")

        # aggregate type: MAX / MIN / SUM / AVG / COUNT / NON
        if "MAX" in labels:
            return_label.append("MAX")
        elif "MIN" in labels:
            return_label.append("MIN")
        elif "SUM" in labels:
            return_label.append("SUM")
        elif "AVG" in labels:
            return_label.append("AVG")
        elif "COUNT" in labels:
            return_label.append("COUNT")
        else:
            return_label.append("NON")
        return return_label[0], return_label[1]
    return "EASY", "NON"


def extract_sql(text, init_sql):
    """Pull the rewritten SQL out of the model response, falling back to init_sql."""
    pattern = r'The modified SQL: (.*)'

    result = re.search(pattern, text)
    if "not an extremum problem" in text:
        return init_sql

    if result:
        modified_sql = result.group(1)
        return modified_sql
    else:
        print("No match found.")
        return init_sql
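
# Example: extract_sql("The modified SQL: SELECT name FROM singer", old_sql)
# returns "SELECT name FROM singer".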


def get_dict_from_str(content):
    """Parse a JSON object from model output, normalizing newlines and full-width commas."""
    content = content.replace("\n", " ")
    content = content.replace("，", ",")  # full-width (CJK) comma -> ASCII comma
    result = json.loads(content)
    return result


def load_or_save(filename, data=None):
    """Load JSON from filename if it exists; otherwise dump `data` (assumed to be
    supplied by the caller) to that file and return it."""
    if os.path.exists(filename):
        with open(filename, 'r') as f:
            data = json.load(f)
        return data
    else:
        with open(filename, 'w') as f:
            json.dump(data, f)
        return data


def ensure_dir(dir_path):
    r"""Make sure the directory exists; if it does not exist, create it.

    Args:
        dir_path (str): directory path
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Project : DEASQL
# File    : static_config.py
# Author  :
# Email   :
# Time    : 2023/10/16 20:24
CONFIG35 = {
    "api_key": "your openai_api_key",
    "max_tokens": 4096,
    "methods": {
        "model_name": "gpt-3.5-turbo",
        "temperature": 0,
        "engine": "gpt-35-turbo",
    },
    "emb_api_key": "your emb_api_key",
}

CONFIG = {
    "api_key1": {
        "api_key": "your openai_api_key",
        "api_base": "your openai_api_base",
    },
    "api_key2": {
        "api_key": "your openai_api_key",
        "api_base": "your openai_api_base",
    },
    "api_key3": {
        "api_key": "your openai_api_key",
        "api_base": "your openai_api_base",
    },
    "max_tokens": 4096,
    "methods": {
        "model_name": "gpt-4",
        "temperature": 0,
        "engine": "gpt-4",
    },
    "emb_api_key": "your emb_api_key",
}
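
# Hedged sketch (assumption, not in the original file): main.py presumably
# selects a key block via --key_config and wires it into the openai client,
# roughly:
#
#   cfg = CONFIG[args.key_config]
#   openai.api_key = cfg["api_key"]
#   openai.api_base = cfg["api_base"]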