Skip to content

Commit 3fd301c

Browse files
committed
Add rollback functionality to setup wizard
Fixes AgentOps-AI#9 Add error handling and rollback functionality to the setup wizard. * Add `rollback_actions` function in `agentstack/cli/cli.py` to undo actions like directory creation and file writing. * Update `init_project_builder` in `agentstack/cli/cli.py` to catch exceptions and call `rollback_actions`. * Track created files and directories in `insert_template` for potential rollback in `agentstack/cli/cli.py`. * Add `rollback_actions` function in `agentstack/generation/agent_generation.py` and `agentstack/generation/task_generation.py` to undo actions. * Update `generate_agent` in `agentstack/generation/agent_generation.py` and `generate_task` in `agentstack/generation/task_generation.py` to catch exceptions and call `rollback_actions`. * Add tests in `tests/test_cli_loads.py` to verify rollback functionality. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/AgentOps-AI/AgentStack/issues/9?shareId=XXXX-XXXX-XXXX-XXXX).
1 parent c2725af commit 3fd301c

File tree

4 files changed

+162
-98
lines changed

4 files changed

+162
-98
lines changed

agentstack/cli/cli.py

Lines changed: 94 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -20,86 +20,100 @@
2020
from .. import generation
2121
from ..utils import open_json_file, term_color, is_snake_case
2222

23+
created_files = []
24+
created_dirs = []
2325

24-
def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False):
25-
if slug_name and not is_snake_case(slug_name):
26-
print(term_color("Project name must be snake case", 'red'))
27-
return
26+
def rollback_actions():
27+
for file in created_files:
28+
if os.path.exists(file):
29+
os.remove(file)
30+
for dir in created_dirs:
31+
if os.path.exists(dir):
32+
shutil.rmtree(dir)
2833

29-
if template is not None and use_wizard:
30-
print(term_color("Template and wizard flags cannot be used together", 'red'))
31-
return
32-
33-
template_data = None
34-
if template is not None:
35-
url_start = "https://"
36-
if template[:len(url_start)] == url_start:
37-
# template is a url
38-
response = requests.get(template)
39-
if response.status_code == 200:
40-
template_data = response.json()
41-
else:
42-
print(term_color(f"Failed to fetch template data from {template}. Status code: {response.status_code}", 'red'))
43-
sys.exit(1)
44-
else:
45-
with importlib.resources.path('agentstack.templates.proj_templates', f'{template}.json') as template_path:
46-
if template_path is None:
47-
print(term_color(f"No such template {template} found", 'red'))
34+
def init_project_builder(slug_name: Optional[str] = None, template: Optional[str] = None, use_wizard: bool = False):
35+
try:
36+
if slug_name and not is_snake_case(slug_name):
37+
print(term_color("Project name must be snake case", 'red'))
38+
return
39+
40+
if template is not None and use_wizard:
41+
print(term_color("Template and wizard flags cannot be used together", 'red'))
42+
return
43+
44+
template_data = None
45+
if template is not None:
46+
url_start = "https://"
47+
if template[:len(url_start)] == url_start:
48+
# template is a url
49+
response = requests.get(template)
50+
if response.status_code == 200:
51+
template_data = response.json()
52+
else:
53+
print(term_color(f"Failed to fetch template data from {template}. Status code: {response.status_code}", 'red'))
4854
sys.exit(1)
49-
template_data = open_json_file(template_path)
50-
51-
if template_data:
52-
project_details = {
53-
"name": slug_name or template_data['name'],
54-
"version": "0.0.1",
55-
"description": template_data['description'],
56-
"author": "Name <Email>",
57-
"license": "MIT"
58-
}
59-
framework = template_data['framework']
60-
design = {
61-
'agents': template_data['agents'],
62-
'tasks': template_data['tasks']
63-
}
64-
65-
tools = template_data['tools']
66-
67-
elif use_wizard:
68-
welcome_message()
69-
project_details = ask_project_details(slug_name)
70-
welcome_message()
71-
framework = ask_framework()
72-
design = ask_design()
73-
tools = ask_tools()
74-
75-
else:
76-
welcome_message()
77-
project_details = {
78-
"name": slug_name or "agentstack_project",
79-
"version": "0.0.1",
80-
"description": "New agentstack project",
81-
"author": "Name <Email>",
82-
"license": "MIT"
83-
}
84-
85-
framework = "CrewAI" # TODO: if --no-wizard, require a framework flag
86-
87-
design = {
88-
'agents': [],
89-
'tasks': []
90-
}
91-
92-
tools = []
93-
94-
log.debug(
95-
f"project_details: {project_details}"
96-
f"framework: {framework}"
97-
f"design: {design}"
98-
)
99-
insert_template(project_details, framework, design, template_data)
100-
for tool_data in tools:
101-
generation.add_tool(tool_data['name'], agents=tool_data['agents'], path=project_details['name'])
55+
else:
56+
with importlib.resources.path('agentstack.templates.proj_templates', f'{template}.json') as template_path:
57+
if template_path is None:
58+
print(term_color(f"No such template {template} found", 'red'))
59+
sys.exit(1)
60+
template_data = open_json_file(template_path)
61+
62+
if template_data:
63+
project_details = {
64+
"name": slug_name or template_data['name'],
65+
"version": "0.0.1",
66+
"description": template_data['description'],
67+
"author": "Name <Email>",
68+
"license": "MIT"
69+
}
70+
framework = template_data['framework']
71+
design = {
72+
'agents': template_data['agents'],
73+
'tasks': template_data['tasks']
74+
}
75+
76+
tools = template_data['tools']
77+
78+
elif use_wizard:
79+
welcome_message()
80+
project_details = ask_project_details(slug_name)
81+
welcome_message()
82+
framework = ask_framework()
83+
design = ask_design()
84+
tools = ask_tools()
10285

86+
else:
87+
welcome_message()
88+
project_details = {
89+
"name": slug_name or "agentstack_project",
90+
"version": "0.0.1",
91+
"description": "New agentstack project",
92+
"author": "Name <Email>",
93+
"license": "MIT"
94+
}
95+
96+
framework = "CrewAI" # TODO: if --no-wizard, require a framework flag
97+
98+
design = {
99+
'agents': [],
100+
'tasks': []
101+
}
102+
103+
tools = []
104+
105+
log.debug(
106+
f"project_details: {project_details}"
107+
f"framework: {framework}"
108+
f"design: {design}"
109+
)
110+
insert_template(project_details, framework, design, template_data)
111+
for tool_data in tools:
112+
generation.add_tool(tool_data['name'], agents=tool_data['agents'], path=project_details['name'])
113+
except Exception as e:
114+
print(term_color(f"An error occurred: {e}", 'red'))
115+
rollback_actions()
116+
sys.exit(1)
103117

104118
def welcome_message():
105119
os.system("cls" if os.name == "nt" else "clear")
@@ -323,17 +337,20 @@ def insert_template(project_details: dict, framework_name: str, design: dict, te
323337
template_path = get_package_path() / f'templates/{framework.name}'
324338
with open(f"{template_path}/cookiecutter.json", "w") as json_file:
325339
json.dump(cookiecutter_data.to_dict(), json_file)
340+
created_files.append(f"{template_path}/cookiecutter.json")
326341

327342
# copy .env.example to .env
328343
shutil.copy(
329344
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env.example',
330345
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env')
346+
created_files.append(f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env')
331347

332348
if os.path.isdir(project_details['name']):
333349
print(term_color(f"Directory {template_path} already exists. Please check this and try again", "red"))
334350
return
335351

336352
cookiecutter(str(template_path), no_input=True, extra_context=None)
353+
created_dirs.append(project_details['name'])
337354

338355
# TODO: inits a git repo in the directory the command was run in
339356
# TODO: not where the project is generated. Fix this
@@ -378,4 +395,4 @@ def list_tools():
378395
print(f": {tool.url if tool.url else 'AgentStack default tool'}")
379396

380397
print("\n\n✨ Add a tool with: agentstack tools add <tool_name>")
381-
print(" https://docs.agentstack.sh/tools/core")
398+
print(" https://docs.agentstack.sh/tools/core")

agentstack/generation/agent_generation.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66
from ruamel.yaml import YAML
77
from ruamel.yaml.scalarstring import FoldedScalarString
88

9+
created_files = []
10+
created_dirs = []
11+
12+
def rollback_actions():
13+
for file in created_files:
14+
if os.path.exists(file):
15+
os.remove(file)
16+
for dir in created_dirs:
17+
if os.path.exists(dir):
18+
shutil.rmtree(dir)
919

1020
def generate_agent(
1121
name,
@@ -27,18 +37,19 @@ def generate_agent(
2737

2838
framework = get_framework()
2939

30-
if framework == 'crewai':
31-
generate_crew_agent(name, role, goal, backstory, llm)
32-
print(" > Added to src/config/agents.yaml")
33-
else:
34-
print(f"This function is not yet implemented for {framework}")
35-
return
36-
37-
print(f"Added agent \"{name}\" to your AgentStack project successfully!")
38-
39-
40-
40+
try:
41+
if framework == 'crewai':
42+
generate_crew_agent(name, role, goal, backstory, llm)
43+
print(" > Added to src/config/agents.yaml")
44+
else:
45+
print(f"This function is not yet implemented for {framework}")
46+
return
4147

48+
print(f"Added agent \"{name}\" to your AgentStack project successfully!")
49+
except Exception as e:
50+
print(f"An error occurred: {e}")
51+
rollback_actions()
52+
sys.exit(1)
4253

4354
def generate_crew_agent(
4455
name,
@@ -83,6 +94,7 @@ def generate_crew_agent(
8394
# Write back to the file without altering existing content
8495
with open(config_path, 'w') as file:
8596
yaml.dump(data, file)
97+
created_files.append(config_path)
8698

8799
# Now lets add the agent to crew.py
88100
file_path = 'src/crew.py'
@@ -103,4 +115,4 @@ def generate_crew_agent(
103115

104116
def get_agent_names(framework: str = 'crewai', path: str = '') -> List[str]:
105117
"""Get only agent names from the crew file"""
106-
return get_crew_components(framework, CrewComponent.AGENT, path)['agents']
118+
return get_crew_components(framework, CrewComponent.AGENT, path)['agents']

agentstack/generation/task_generation.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66
from ruamel.yaml import YAML
77
from ruamel.yaml.scalarstring import FoldedScalarString
88

9+
created_files = []
10+
created_dirs = []
11+
12+
def rollback_actions():
13+
for file in created_files:
14+
if os.path.exists(file):
15+
os.remove(file)
16+
for dir in created_dirs:
17+
if os.path.exists(dir):
18+
shutil.rmtree(dir)
919

1020
def generate_task(
1121
name,
@@ -24,15 +34,19 @@ def generate_task(
2434

2535
framework = get_framework()
2636

27-
if framework == 'crewai':
28-
generate_crew_task(name, description, expected_output, agent)
29-
print(" > Added to src/config/tasks.yaml")
30-
else:
31-
print(f"This function is not yet implemented for {framework}")
32-
return
33-
34-
print(f"Added task \"{name}\" to your AgentStack project successfully!")
37+
try:
38+
if framework == 'crewai':
39+
generate_crew_task(name, description, expected_output, agent)
40+
print(" > Added to src/config/tasks.yaml")
41+
else:
42+
print(f"This function is not yet implemented for {framework}")
43+
return
3544

45+
print(f"Added task \"{name}\" to your AgentStack project successfully!")
46+
except Exception as e:
47+
print(f"An error occurred: {e}")
48+
rollback_actions()
49+
sys.exit(1)
3650

3751
def generate_crew_task(
3852
name,
@@ -74,6 +88,7 @@ def generate_crew_task(
7488
# Write back to the file without altering existing content
7589
with open(config_path, 'w') as file:
7690
yaml.dump(data, file)
91+
created_files.append(config_path)
7792

7893
# Add task to crew.py
7994
file_path = 'src/crew.py'
@@ -91,4 +106,4 @@ def generate_crew_task(
91106

92107
def get_task_names(framework: str, path: str = '') -> List[str]:
93108
"""Get only task names from the crew file"""
94-
return get_crew_components(framework, CrewComponent.TASK, path)['tasks']
109+
return get_crew_components(framework, CrewComponent.TASK, path)['tasks']

tests/test_cli_loads.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest
44
from pathlib import Path
55
import shutil
6+
import os
67

78

89
class TestAgentStackCLI(unittest.TestCase):
@@ -44,6 +45,25 @@ def test_init_command(self):
4445
# Clean up
4546
shutil.rmtree(test_dir)
4647

48+
def test_rollback_on_error(self):
49+
"""Test rollback functionality when an error occurs during project initialization."""
50+
test_dir = Path("test_project_with_error")
51+
52+
# Ensure the directory doesn't exist from previous runs
53+
if test_dir.exists():
54+
shutil.rmtree(test_dir)
55+
56+
# Simulate an error by creating a directory that will cause a failure
57+
os.makedirs(test_dir / "src")
58+
59+
result = self.run_cli("init", str(test_dir))
60+
self.assertNotEqual(result.returncode, 0)
61+
self.assertFalse(test_dir.exists()) # Directory should be removed on rollback
62+
63+
# Clean up
64+
if test_dir.exists():
65+
shutil.rmtree(test_dir)
66+
4767

4868
if __name__ == "__main__":
4969
unittest.main()

0 commit comments

Comments
 (0)