
Updates to LLM function and ReadMe.md (#170)
* WIP

* WIP
gsheni authored Oct 28, 2023
1 parent c2ec913 commit bea7470
Showing 7 changed files with 94 additions and 30 deletions.
69 changes: 50 additions & 19 deletions README.md
@@ -30,40 +30,71 @@ Install Trane using pip:
 python -m pip install trane
 ```
 
 ## Usage
 
 Here's a quick demonstration of Trane in action:
 
 ```python
 import trane
 
 data, metadata = trane.load_airbnb()
-entity_columns = ["location"]
-window_size = "2d"
 
 problem_generator = trane.ProblemGenerator(
-    metadata=metadata,
-    window_size=window_size,
-    entity_columns=entity_columns
+    metadata=metadata,
+    entity_columns=["location"]
 )
 problems = problem_generator.generate()
 
-print(f'Generated {len(problems)} problems.')
-print(problems[108])
-print(problems[108].create_target_values(data).head(5))
+for problem in problems[:5]:
+    print(problem)
 ```
 
-Output:
-```
-Generated 168 problems.
-For each <location> predict the majority <rating> in all related records in the next 2 days.
-  location       time  target
-0   London 2021-01-01       5
-1   London 2021-01-03       4
-2   London 2021-01-05       5
-3   London 2021-01-07       4
-4   London 2021-01-09       5
-```
+A few of the generated problems:
+```
+==================================================
+Generated 40 total problems
+--------------------------------------------------
+Classification problems: 5
+Regression problems: 35
+==================================================
+For each <location> predict if there exists a record
+For each <location> predict if there exists a record with <location> equal to <str>
+For each <location> predict if there exists a record with <location> not equal to <str>
+For each <location> predict if there exists a record with <rating> equal to <str>
+For each <location> predict if there exists a record with <rating> not equal to <str>
+```
+
+With Trane's LLM add-on (`pip install trane[llm]`), we can determine the relevant problems with OpenAI:
+```python
+from trane.llm import analyze
+
+instructions = "determine 5 most relevant problems about user's booking preferences. Do not include 'predict the first/last X' problems"
+context = "Airbnb data listings in major cities, including information about hosts, pricing, location, and room type, along with over 5 million historical reviews."
+relevant_problems = analyze(
+    problems=problems,
+    instructions=instructions,
+    context=context,
+    model="gpt-3.5-turbo-16k"
+)
+for problem in relevant_problems:
+    print(problem)
+    print(f'Reasoning: {problem.get_reasoning()}\n')
+```
+Output
+```text
+For each <location> predict if there exists a record
+Reasoning: This problem can help identify locations with missing data or locations that have not been booked at all.
+For each <location> predict the first <location> in all related records
+Reasoning: Predicting the first location in all related records can provide insights into the most frequently booked locations for each city.
+For each <location> predict the first <rating> in all related records
+Reasoning: Predicting the first rating in all related records can provide insights into the average satisfaction level of guests for each location.
+For each <location> predict the last <location> in all related records
+Reasoning: Predicting the last location in all related records can provide insights into the most recent bookings for each city.
+For each <location> predict the last <rating> in all related records
+Reasoning: Predicting the last rating in all related records can provide insights into the recent satisfaction level of guests for each location.
+```
 
 ## Community
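The updated README stops at printing problems; to train a model you still need labels. A minimal sketch of that next step, based on the `create_target_values` call the old README used (the problem index and the exact output columns are illustrative, not fixed):

```python
import trane

# Demo dataset bundled with Trane: a dataframe plus column metadata.
data, metadata = trane.load_airbnb()

problem_generator = trane.ProblemGenerator(
    metadata=metadata,
    entity_columns=["location"],
)
problems = problem_generator.generate()

# Compute target values (labels) for one generated problem. As of this
# commit, unset filter parameters are filled with a recommended threshold
# instead of raising a ValueError (see trane/core/problem.py below).
labels = problems[0].create_target_values(data)
print(labels.head())
```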
4 changes: 2 additions & 2 deletions tests/ops/test_threshold_functions.py
@@ -27,8 +27,8 @@ def test_get_k_most_frequent(dtype):
 @pytest.mark.parametrize(
     "dtype",
     [
-        ("int64"),
-        ("int64[pyarrow]"),
+        ("float64"),
+        ("float64[pyarrow]"),
     ],
 )
 def test_get_k_most_frequent_raises(dtype):
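The swap from `int64` to `float64` tracks the `get_k_most_frequent` change further down: integer dtypes are now accepted, so the raising path needs a dtype that is still rejected. A hedged sketch of the behavior the updated test pins down (import path inferred from this repo's layout):

```python
import pandas as pd
import pytest

from trane.ops.threshold_functions import get_k_most_frequent

# Integer series no longer raise...
assert get_k_most_frequent(pd.Series([1, 1, 2], dtype="int64"), k=1) == [1]

# ...while float series still do.
with pytest.raises(ValueError):
    get_k_most_frequent(pd.Series([1.0, 2.0], dtype="float64"))
```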
17 changes: 15 additions & 2 deletions trane/core/problem.py
@@ -21,11 +21,13 @@ def __init__(
         operations,
         entity_column=None,
         window_size=None,
+        reasoning=None,
     ):
         self.operations = operations
         self.metadata = metadata
         self.entity_column = entity_column
         self.window_size = window_size
+        self.reasoning = reasoning
 
     def __lt__(self, other):
         return self.__str__() < (other.__str__())
@@ -57,7 +59,16 @@ def get_required_parameters(self):
         return self.operations[0].required_parameters
 
     def set_parameters(self, threshold):
-        return self.operations[0].set_parameters(threshold)
+        self.operations[0].set_parameters(threshold)
+
+    def set_reasoning(self, reasoning):
+        self.reasoning = reasoning
+
+    def get_reasoning(self):
+        return self.reasoning
+
+    def reset_reasoning(self):
+        self.reasoning = None
 
     def is_classification(self):
         return isinstance(self.operations[2], ExistsAggregationOp)
@@ -118,7 +129,9 @@ def create_target_values(self, dataframes):
         # Won't this always be normalized?
         normalized_dataframe = self.get_normalized_dataframe(dataframes)
         if self.has_parameters_set() is False:
-            raise ValueError("Filter operation's parameters are not set")
+            print("Filter operation's parameters are not set, setting them now")
+            thresholds = self.get_recommended_thresholds(dataframes)
+            self.set_parameters(thresholds[-1])
 
         target_dataframe_index = self.entity_column
         if self.entity_column is None:
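Taken together, the `Problem` changes add an LLM-facing reasoning slot and make label generation self-healing. A small sketch of the new surface (reusing `problems` and `data` from the README example above):

```python
problem = problems[0]

# New reasoning accessors, populated by trane.llm.analyze below.
problem.set_reasoning("Targets booking behavior per location.")
print(problem.get_reasoning())
problem.reset_reasoning()  # back to None

# create_target_values no longer raises on unset filter parameters;
# it now applies the last recommended threshold itself.
labels = problem.create_target_values(data)
```

Note that `set_parameters` also stops returning the operation's return value, making the in-place mutation explicit.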
2 changes: 2 additions & 0 deletions trane/core/utils.py
@@ -14,6 +14,8 @@ def determine_gap_size(gap):
         return pd.Timedelta(gap)
     elif isinstance(gap, int) or isinstance(gap, pd.Timedelta):
         return gap
+    elif not gap:
+        return 1
     return int(gap)


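The new `elif not gap` branch gives falsy gaps (`None`, for example) a default of 1 before the final `int(gap)` cast can fail on them. A quick sketch of the resulting behavior, assuming the elided first branch parses strings via `pd.Timedelta`:

```python
import pandas as pd

from trane.core.utils import determine_gap_size

determine_gap_size("2d")    # Timedelta('2 days'), via the string branch (assumed)
determine_gap_size(5)       # 5, ints pass through unchanged
determine_gap_size(None)    # 1, the new falsy default from this commit
```

A `0` gap is caught by the earlier `isinstance(gap, int)` check, so only non-int falsy values reach the new branch.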
1 change: 0 additions & 1 deletion trane/llm/__init__.py
@@ -1,2 +1 @@
-from trane.llm.chat import chat
 from trane.llm.helpers import *
20 changes: 18 additions & 2 deletions trane/llm/helpers.py
@@ -1,4 +1,5 @@
 import json
+import re
 
 from IPython.display import Markdown, display
@@ -45,6 +46,7 @@ def analyze(
     instructions,
     context,
     model="gpt-3.5-turbo-16k",
+    jupyter=False,
 ):
     prompt_context = f" The context is: {context}"
     constraints = (
@@ -81,14 +83,20 @@ def analyze(
         problems_formatted,
     )
     response = openai_gpt(prompt, model)
-    display(Markdown(response))
+    if jupyter:
+        display(Markdown(response))
+    else:
+        print(response)
 
     relevant_ids = extract_problems_from_response(response, model)
-    print(relevant_ids)
+    relevant_ids = list(set(relevant_ids))
     relevant_problems = []
     for id_ in relevant_ids:
         relevant_problems.append(problems[int(id_) - 1])
 
+    reasonings = extract_reasonings_from_response(response)
+    for idx, reason in enumerate(reasonings):
+        relevant_problems[idx].set_reasoning(reason)
+    relevant_problems = sorted(relevant_problems, key=lambda p: str(p))
     return relevant_problems
@@ -108,6 +116,14 @@ def extract_problems_from_response(response, model):
     return response
 
 
+def extract_reasonings_from_response(text):
+    reasonings = []
+    matches = re.findall(r"Reasoning: (.+?)(?:\n\n|$)", text)
+    for match in matches:
+        reasonings.append(match)
+    return reasonings
+
+
 def format_problems(problems: list) -> str:
     formatted = ""
     for idx, problem in enumerate(problems):
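End to end, the `analyze` changes mean a plain terminal gets `print` output instead of an IPython Markdown display, duplicate problem ids are dropped, and each returned problem carries the model's reasoning. A hedged sketch of a non-notebook call (requires `pip install trane[llm]` and an OpenAI API key; `problems` comes from `ProblemGenerator.generate()` as in the README example):

```python
from trane.llm import analyze

relevant_problems = analyze(
    problems=problems,
    instructions="determine the 5 most relevant problems",
    context="Airbnb listings with historical reviews",
    model="gpt-3.5-turbo-16k",
    jupyter=False,  # new flag: plain print() instead of IPython Markdown
)
for problem in relevant_problems:
    print(problem, "->", problem.get_reasoning())
```

One caveat worth noting: `list(set(relevant_ids))` does not preserve order, so the index-based pairing in the reasoning loop assumes the model lists problems and reasonings in a stable, matching order.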
11 changes: 7 additions & 4 deletions trane/ops/threshold_functions.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pandas as pd
 from pandas.api.types import (
+    is_integer_dtype,
     is_object_dtype,
     is_string_dtype,
 )
@@ -106,13 +107,15 @@ def find_threshold_to_maximize_uncertainty(
 
 def get_k_most_frequent(series, k=3):
     # get the top k most frequent values
+    dtype = series.dtype
     if (
-        is_object_dtype(series.dtype)
-        or isinstance(series.dtype, pd.CategoricalDtype)
-        or is_string_dtype(series.dtype)
+        is_object_dtype(dtype)
+        or isinstance(dtype, pd.CategoricalDtype)
+        or is_string_dtype(dtype)
+        or is_integer_dtype(dtype)
     ):
         return series.value_counts()[:k].index.tolist()
-    raise ValueError("Series must be categorical, string or object dtype")
+    raise ValueError("Series must be categorical, string, object or int dtype")
 
 
 def sample_unique_values(series, max_num_unique_values=10, random_state=None):
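With `is_integer_dtype` added to the accepted set, categorical-style thresholds can now be drawn from integer columns such as ratings. A short usage sketch of the changed function (pandas only):

```python
import pandas as pd

from trane.ops.threshold_functions import get_k_most_frequent

ratings = pd.Series([5, 4, 5, 3, 5, 4], dtype="int64")
print(get_k_most_frequent(ratings, k=2))  # [5, 4]; int dtype newly accepted

get_k_most_frequent(pd.Series([0.5, 1.5]))  # still raises ValueError
```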
