Commit 1dbd78b: init DEA-SQL
FlyingFeather committed Feb 20, 2024 (0 parents)
Showing 312 changed files with 1,307,342 additions and 0 deletions.
51 changes: 51 additions & 0 deletions .gitignore
@@ -0,0 +1,51 @@
# ide settings
.idea
.vscode
.codecc

# mac temp file
.DS_Store

# env
venv
.venv
.noxenv

# build
*.py[cod]
*.egg-info
*.egg
.eggs
eggs
build
dist
__pycache__

# logs
logs
*.log
*.log*
*.out
nohup*

# pytest
.pytest_cache
# pytest report
htmlcov

# coverage
*.coverage
*.coverage.*
pytest.cov.tmp

# data
datasets
dataset
vector_cache
third_party
data

# discard
discard

tests
66 changes: 66 additions & 0 deletions README.md
@@ -0,0 +1,66 @@
# Decomposition for Enhancing Attention: Improving LLM-based Text-to-SQL through Workflow Paradigm

Based on the idea of **D**ecomposition for **E**nhancing **A**ttention, we propose a workflow paradigm method named DEA-SQL with five major steps, as shown in the figure below. Check out our [paper](https://arxiv.org/abs/2402.10671) for more information.


![model](./docs/model.png)


## Requirements
```
nltk==3.8.1
sqlparse==0.4.2
openai==0.28.0
langchain==0.0.281
backoff==2.2.1
termcolor==2.3.0
pandas==2.0.3
scikit-learn==1.3.0
timeout_decorator==0.5.0
sql_metadata==2.9.0
transformers==4.32.0
torch==1.12.1
```
## Environment
1. Install the dependencies listed under Requirements with `pip` (see the commands below).
2. Run `python nltk_downloader.py` to download the required NLTK data.
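
A minimal setup sketch; the `requirements.txt` file name is an assumption (if the repository does not ship one, install the pinned packages from the Requirements list directly):
```
pip install -r requirements.txt   # assumed file name; alternatively install the pinned packages above
python nltk_downloader.py
```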


## Data Preparation
Download the dataset from the [Spider official website](https://yale-lily.github.io/spider), unzip it, and put it into the `data` folder, as sketched below. An example path is `data/spider/database`.
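
An assumed layout, inferred from the paths referenced elsewhere in this README (file names inside the Spider archive may differ):
```
data/
└── spider/
    ├── database/      # unzipped Spider databases
    ├── dev.json
    └── tables.json
```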

## Usage
Please modify the OpenAI configuration in `common/config/static_config.py` and configure the relevant environment variables for the Azure OpenAI API (an example is sketched below).
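
A hypothetical example of these environment variables (resource name, key, and API version are placeholders that depend on your Azure deployment):
```
export OPENAI_API_TYPE=azure
export OPENAI_API_BASE=https://<your-resource>.openai.azure.com/
export OPENAI_API_KEY=<your-azure-openai-key>
export OPENAI_API_VERSION=2023-07-01-preview
```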

Several important parameters:
- **dataset**: The name of the dataset.
- **few_shot_mode**: The few-shot retrieval method; one of [random, ques_sim, masked_ques_sim].
- **few_shot_data**: The few-shot example pool; one of [train_merge_v1, train_merge_v5].
- **insert_value**: The number of example rows inserted into the database schema prompt.
- **embedding_base_model**: The base embedding model used in the few-shot retrieval step.
- **sc_filter_nums**: The number of self-consistency samples used in the information filter step.

## Quick Start

### Prediction on the Spider dev set
```
python main.py --save_file_name "dea-sql.txt" --dataset "spider" --mode "dev" --sample "False" --few_shot_mode "masked_ques_sim" --insert_value 3 --embedding_base_model "openai" --sc_filter_nums 3 --few_shot_data "train_merge_v5"
```

### Evaluation on the Spider dev set
Before the first evaluation, run `python nltk_downloader.py`.

```
python evaluation/test-suite-sql-eval/evaluation.py --gold "evaluation/gold_files/spider_dev_gold.sql" --pred "outputs/spider/dea-sql.txt" --db ./data/spider/database --print_file_name "outputs/spider/spider-dea-sql.txt" --table './data/spider/tables.json' --etype exec
```

## Citing DEA-SQL

```
@article{xie2024decomposition,
title={Decomposition for Enhancing Attention: Improving LLM-based Text-to-SQL through Workflow Paradigm},
author={Yuanzhen Xie and Xinzhou Jin and Tao Xie and MingXiong Lin and Liang Chen and Chenyun Yu and Lei Cheng and ChengXiang Zhuo and Bo Hu and Zang Li},
journal={arXiv preprint arXiv:2402.10671},
year={2024}
}
```
6 changes: 6 additions & 0 deletions __init__.py
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
# Project : DEASQL
# File : __init__.py
# Author :
# Email :
# Time : 2023/12/26 15:19
83 changes: 83 additions & 0 deletions argsparser.py
@@ -0,0 +1,83 @@
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--question", type=str)
parser.add_argument("--openai_api_key", type=str)
parser.add_argument("--openai_group_id", type=str, default="")
parser.add_argument("--model", type=str)
parser.add_argument("--start_index", type=int, default=0)
parser.add_argument("--end_index", type=int, default=1000000)
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--mini_index_path", type=str, default="")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--n", type=int, default=5, help="Size of self-consistent set")
parser.add_argument("--db_dir", type=str, default="dataset/spider/database")

# our configuration
parser.add_argument("--key_config", type=str, default='api_key1', help="api_key1, api_key2, api_key3")
parser.add_argument("--key_version", type=str, default='gpt-4', help="gpt-35-turbo, gpt-4")
parser.add_argument("--save_file_name", type=str, default="test.txt", help="sql save file")
parser.add_argument("--dataset", type=str, default="spider", help="spider or dev (bird)")
parser.add_argument("--sample", type=str, default="False", help="True or False")
parser.add_argument("--mode", type=str, default="debug", help="debug or dev")
parser.add_argument("--lang_mode", type=str, default="en", help="just en")
parser.add_argument("--filter_mode", type=str, default="complex", help="simple, complex, simple_v2, none")
parser.add_argument("--prompt_mode", type=str, default="v2", help="v1, v2, v3, v4")
parser.add_argument("--data_fold", type=str, default="1", help="1, 2, 3")
parser.add_argument('--train', default=False, action='store_true', help='train or dev')
parser.add_argument("--dataset_file", type=str, default="dev.json")
parser.add_argument("--test_id", type=int, default=46, help="1, 2, 3")

parser.add_argument("--re_run", default=False, action='store_true')
parser.add_argument('--re_run_idx', type=int, default=0)

parser.add_argument("--sc_nums_question_label", type=int, default=1, help="self-consistency numbers")
parser.add_argument("--sc_nums", type=int, default=1, help="self-consistency numbers")
parser.add_argument("--sc_filter_nums", type=int, default=2, help="self-consistency numbers")
parser.add_argument("--sc_filter_temp", type=float, default=0, help="self-consistency temperature for filter")
parser.add_argument("--sc_ques_temp", type=float, default=0, help="self-consistency temperature for question type")
parser.add_argument("--sc_sql_temp", type=float, default=0, help="self-consistency temperature for sql generation")
parser.add_argument("--insert_value", type=int, default=0, help="insert value of table schema")
parser.add_argument('--step_name', type=str, default="all",
                    help='Which step to execute? one of ["all", "ner_results", "filter_infos", "qc", "sql"]')
parser.add_argument('--step', default=False, action='store_true', help='whether open the mode step debug')
parser.add_argument('--step1', default=False, action='store_true', help='rerun step1')
parser.add_argument('--step2', default=False, action='store_true', help='skip step1')
parser.add_argument('--step3', default=False, action='store_true', help='skip step1, 2')
parser.add_argument('--step4', default=False, action='store_true', help='skip step1, 2, 3')
parser.add_argument('--step5', default=False, action='store_true', help='skip step1, 2, 3, 4')
parser.add_argument('--step6', default=False, action='store_true', help='skip step1, 2, 3, 4, 5')
parser.add_argument('--save_version', type=int, default=1, help='the step version')
parser.add_argument('--n_shots', type=int, default=3, help='the number of shots')
parser.add_argument('--few_shot_data', type=str, default='train_merge_v1',
                    help='one of ["train_merge_v1", "train_merge_v5"]')
parser.add_argument('--few_shot_mode', type=str, default='ques_sim1',
                    help='one of ["random", "ques_sim", "masked_ques_sim", "query_sim"]')
parser.add_argument('--embedding_base_model', type=str, default='openai', help='one of ["transformer", "openai"]')
parser.add_argument('--schema_mode', type=str, default='CreateTableInsertRowFK',
                    help='one of ["CreateTableInsertRow", "CreateTableInsertRowFK"]')
# Table(Columns), Columns=[], Columns=[]+FK, CreateTable, CreateTableInsertRow, CreateTableSelectRow, CreateTableSelectCol
parser.add_argument('--fk_mode', type=str, default="newupperfk", help="newupperfk means keep internal fk, newupper means not keep internal fk")

parser.add_argument('--has_error_case', default=False, action='store_true', help='has error case in generating sql')

### ablation experiment
parser.add_argument('--reduce_ql', default=False, action='store_true', help='reduce the step of question label')


################## evaluate single sql ##########################
parser.add_argument('--db', dest='db', type=str, default="./data/spider/database",
                    help="the directory that contains all the databases and test suites")
parser.add_argument('--table', dest='table', type=str, default="./data/spider/tables.json",
                    help="the tables.json schema file")
parser.add_argument('--etype', dest='etype', type=str, default='exec',
                    help="evaluation type, exec for test suite accuracy, match for the original "
                         "exact set match accuracy",
                    choices=('all', 'exec', 'match'))
parser.add_argument('--plug_value', default=False, action='store_true',
                    help='whether to plug in the gold value into the predicted query; suitable if your model '
                         'does not predict values.')
parser.add_argument('--keep_distinct', default=False, action='store_true',
                    help='whether to keep distinct keyword during evaluation. default is false.')
parser.add_argument('--progress_bar_for_each_datapoint', default=False, action='store_true',
                    help='whether to print progress bar of running test inputs for each datapoint')
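
As a reading aid (not part of this commit), a minimal sketch of how the shared parser above is assumed to be consumed; the attribute names follow the arguments defined above:
```
# Hypothetical usage sketch of the module-level parser.
from argsparser import parser

args = parser.parse_args()
print(args.dataset, args.mode, args.few_shot_mode, args.sc_filter_nums, args.insert_value)
```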
139 changes: 139 additions & 0 deletions common/common.py
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Project : DEASQL
# File : common.py
# Author :
# Email :
# Time : 2023/10/16 19:58
import json
import os
import pickle
import re

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate


def run_chain(llm, prompt, param_dicts):
    chain = LLMChain(llm=llm, prompt=prompt)
    output = chain.run(**param_dicts)
    return output


def get_prompt(inputs, template):
    prompt = PromptTemplate(
        input_variables=inputs,
        template=template
    )
    return prompt


def get_prompt_content(prompt, param_dicts):
    formatted_text = prompt.format(**param_dicts)
    return formatted_text


def save_obj(obj, name):
    """
    save pickle
    :param obj: The data that needs to be stored.
    :param name: file path
    """
    with open(name, "wb") as file:
        pickle.dump(obj, file)


def load_obj(name):
    """load pickle"""
    with open(name, "rb") as file:
        return pickle.load(file)


def get_lower_list(temp_list):
    target_list = [word.lower() for word in temp_list]
    return target_list


def extract_references(sql):
    pattern = r"REFERENCES\s+(\w+)\s+\((\w+)\)"
    matches = re.findall(pattern, sql)
    return matches


def extract_label(text):

    pattern = r"Label: ([\w-]+), ([\w-]+)"

    result = re.search(pattern, text)
    return_label = []

    if result:
        labels = result.groups()
        if "NON-JOIN" in labels and "NON-NESTED" in labels:
            return_label.append("EASY")
        elif "JOIN" in labels and "NON-NESTED" in labels:
            return_label.append("JOIN")
        elif "NESTED" in labels and "NON-JOIN" in labels:
            return_label.append("NESTED")
        elif "NESTED" in labels and "JOIN" in labels:
            return_label.append("JOIN-NESTED")
        else:
            return_label.append("EASY")

        if "MAX" in labels:
            return_label.append("MAX")
        elif "MIN" in labels:
            return_label.append("MIN")
        elif "SUM" in labels:
            return_label.append("SUM")
        elif "AVG" in labels:
            return_label.append("AVG")
        elif "COUNT" in labels:
            return_label.append("COUNT")
        else:
            return_label.append("NON")
        return return_label[0], return_label[1]
    return "EASY", "NON"


def extract_sql(text, init_sql):
    pattern = r'The modified SQL: (.*)'

    result = re.search(pattern, text)
    if "not an extremum problem" in text:
        return init_sql

    if result:
        modified_sql = result.group(1)
        return modified_sql
    else:
        print("No match found.")
        return init_sql


def get_dict_from_str(content):
    content = content.replace("\n", " ")
    # normalize full-width commas to ASCII commas before parsing JSON
    content = content.replace("，", ",")
    result = json.loads(content)
    return result


def load_or_save(filename, data=None):
    """Load JSON from filename if it exists; otherwise save `data` to it."""
    if os.path.exists(filename):
        with open(filename, 'r') as f:
            data = json.load(f)
        return data
    else:
        # `data` must be supplied by the caller when the file does not exist yet
        with open(filename, 'w') as f:
            json.dump(data, f)


def ensure_dir(dir_path):
    r"""Make sure the directory exists, if it does not exist, create it
    Args:
        dir_path (str): directory path
    """
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
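
A minimal usage sketch of the helpers above (not part of this commit; it assumes an OpenAI-backed chat model is available through LangChain, and the prompt text is a placeholder):
```
# Hypothetical example; model name and prompt text are placeholders.
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(model_name="gpt-4", temperature=0)
prompt = get_prompt(["question"], "Write a SQL query for: {question}")
sql = run_chain(llm, prompt, {"question": "How many singers do we have?"})
print(sql)
```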
39 changes: 39 additions & 0 deletions common/config/static_config.py
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Project : DEASQL
# File : static_config.py
# Author :
# Email :
# Time : 2023/10/16 20:24
CONFIG35 = {
    "api_key": "your openai_api_key",
    "max_tokens": 4096,
    "methods": {
        "model_name": "gpt-3.5-turbo",
        "temperature": 0,
        "engine": "gpt-35-turbo",
    },
    "emb_api_key": "your emb_api_key",
}

CONFIG = {
    "api_key1": {
        "api_key": "your openai_api_key",
        "api_base": "your openai_api_base",
    },
    "api_key2": {
        "api_key": "your openai_api_key",
        "api_base": "your openai_api_base",
    },
    "api_key3": {
        "api_key": "your openai_api_key",
        "api_base": "your openai_api_base",
    },
    "max_tokens": 4096,
    "methods": {
        "model_name": "gpt-4",
        "temperature": 0,
        "engine": "gpt-4",
    },
    "emb_api_key": "your emb_api_key",
}
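
A sketch of how this configuration is assumed to be applied with the pinned `openai==0.28.0` client; the key name `api_key1` mirrors the `--key_config` default, and all values are placeholders:
```
# Hypothetical example; not part of this file.
import openai
from common.config.static_config import CONFIG

openai.api_type = "azure"
openai.api_key = CONFIG["api_key1"]["api_key"]
openai.api_base = CONFIG["api_key1"]["api_base"]
```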
