Skip to content

Commit f8c63c1

Browse files
committed
feat(ui): model relationship management
Adds full support for managing model-to-model relationships in the UI and backend. Introduces RelatedModels subpanel for linking and unlinking models in model management. - Adds REST API routes for adding, removing, and retrieving model relationships. - New database migration: creates model_relationships table for bidirectional links. - New service layer (model_relationships) for relationship management. - Updated frontend: Related models float to top of LoRA/Main grouped model comboboxes for quick access. - Added 'Show Only Related' toggle badge to MainModelPicker filter bar
1 parent 19a63ab commit f8c63c1

File tree

23 files changed

+1285
-9
lines changed

23 files changed

+1285
-9
lines changed

invokeai/app/api/dependencies.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
2424
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
2525
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
26+
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
27+
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import SqliteModelRelationshipRecordStorage
2628
from invokeai.app.services.names.names_default import SimpleNameService
2729
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
2830
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@@ -136,6 +138,8 @@ def initialize(
136138
download_queue=download_queue_service,
137139
events=events,
138140
)
141+
model_relationships = ModelRelationshipsService()
142+
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
139143
names = SimpleNameService()
140144
performance_statistics = InvocationStatsService()
141145
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
@@ -161,6 +165,8 @@ def initialize(
161165
logger=logger,
162166
model_images=model_images_service,
163167
model_manager=model_manager,
168+
model_relationships=model_relationships,
169+
model_relationship_records=model_relationship_records,
164170
download_queue=download_queue_service,
165171
names=names,
166172
performance_statistics=performance_statistics,
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""FastAPI route for model relationship records."""
2+
3+
from fastapi import HTTPException, APIRouter, Path, Body, status
4+
from pydantic import BaseModel, Field
5+
from typing import List
6+
from invokeai.app.api.dependencies import ApiDependencies
7+
8+
model_relationships_router = APIRouter(
9+
prefix="/v1/model_relationships",
10+
tags=["model_relationships"]
11+
)
12+
13+
# === Schemas ===
14+
15+
class ModelRelationshipCreateRequest(BaseModel):
16+
model_key_1: str = Field(..., description="The key of the first model in the relationship", examples=[
17+
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
18+
"ac32b914-10ab-496e-a24a-3068724b9c35",
19+
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
20+
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
21+
"12345678-90ab-cdef-1234-567890abcdef",
22+
"fedcba98-7654-3210-fedc-ba9876543210"
23+
])
24+
model_key_2: str = Field(..., description="The key of the second model in the relationship", examples=[
25+
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
26+
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
27+
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
28+
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
29+
"10f7807b-ab54-46a9-ab03-600e88c630a1",
30+
"f6c1d267-cf87-4ee0-bee0-37e791eacab7"
31+
])
32+
33+
class ModelRelationshipBatchRequest(BaseModel):
34+
model_keys: List[str] = Field(..., description="List of model keys to fetch related models for", examples=
35+
[[
36+
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
37+
"ac32b914-10ab-496e-a24a-3068724b9c35",
38+
],[
39+
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
40+
"12345678-90ab-cdef-1234-567890abcdef",
41+
"fedcba98-7654-3210-fedc-ba9876543210"
42+
],[
43+
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
44+
]])
45+
46+
# === Routes ===
47+
48+
@model_relationships_router.get(
49+
"/i/{model_key}",
50+
operation_id="get_related_models",
51+
response_model=list[str],
52+
responses={
53+
200: {
54+
"description": "A list of related model keys was retrieved successfully",
55+
"content": {
56+
"application/json": {
57+
"example": [
58+
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
59+
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
60+
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2"
61+
]
62+
}
63+
},
64+
},
65+
404: {"description": "The specified model could not be found"},
66+
422: {"description": "Validation error"},
67+
},
68+
)
69+
async def get_related_models(
70+
model_key: str = Path(..., description="The key of the model to get relationships for")
71+
) -> list[str]:
72+
"""
73+
Get a list of model keys related to a given model.
74+
"""
75+
try:
76+
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
77+
except Exception as e:
78+
raise HTTPException(status_code=500, detail=str(e))
79+
80+
81+
@model_relationships_router.post(
82+
"/",
83+
status_code=status.HTTP_204_NO_CONTENT,
84+
responses={
85+
204: {"description": "The relationship was successfully created"},
86+
400: {"description": "Invalid model keys or self-referential relationship"},
87+
409: {"description": "The relationship already exists"},
88+
422: {"description": "Validation error"},
89+
500: {"description": "Internal server error"},
90+
},
91+
summary="Add Model Relationship",
92+
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
93+
)
94+
async def add_model_relationship(
95+
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate")
96+
) -> None:
97+
"""
98+
Add a relationship between two models.
99+
100+
Relationships are bidirectional and will be accessible from both models.
101+
102+
- Raises 400 if keys are invalid or identical.
103+
- Raises 409 if the relationship already exists.
104+
"""
105+
try:
106+
if req.model_key_1 == req.model_key_2:
107+
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
108+
109+
ApiDependencies.invoker.services.model_relationships.add_model_relationship(
110+
req.model_key_1,
111+
req.model_key_2,
112+
)
113+
except ValueError as e:
114+
raise HTTPException(status_code=409, detail=str(e))
115+
except Exception as e:
116+
raise HTTPException(status_code=500, detail=str(e))
117+
118+
119+
@model_relationships_router.delete(
120+
"/",
121+
status_code=status.HTTP_204_NO_CONTENT,
122+
responses={
123+
204: {"description": "The relationship was successfully removed"},
124+
400: {"description": "Invalid model keys or self-referential relationship"},
125+
404: {"description": "The relationship does not exist"},
126+
422: {"description": "Validation error"},
127+
500: {"description": "Internal server error"},
128+
},
129+
summary="Remove Model Relationship",
130+
description="Removes a **bidirectional** relationship between two models. The relationship must already exist."
131+
)
132+
async def remove_model_relationship(
133+
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect")
134+
) -> None:
135+
"""
136+
Removes a bidirectional relationship between two model keys.
137+
138+
- Raises 400 if attempting to unlink a model from itself.
139+
- Raises 404 if the relationship was not found.
140+
"""
141+
try:
142+
if req.model_key_1 == req.model_key_2:
143+
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")
144+
145+
ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
146+
req.model_key_1,
147+
req.model_key_2,
148+
)
149+
except ValueError as e:
150+
raise HTTPException(status_code=404, detail=str(e))
151+
except Exception as e:
152+
raise HTTPException(status_code=500, detail=str(e))
153+
154+
@model_relationships_router.post(
155+
"/batch",
156+
operation_id="get_related_models_batch",
157+
response_model=List[str],
158+
responses={
159+
200: {
160+
"description": "Related model keys retrieved successfully",
161+
"content": {
162+
"application/json": {
163+
"example": [
164+
"ca562b14-995e-4a42-90c1-9528f1a5921d",
165+
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
166+
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
167+
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
168+
"c382eaa3-0e28-4ab0-9446-408667699aeb",
169+
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
170+
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2"
171+
]
172+
}
173+
}
174+
},
175+
422: {"description": "Validation error"},
176+
500: {"description": "Internal server error"},
177+
},
178+
summary="Get Related Model Keys (Batch)",
179+
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering."
180+
)
181+
async def get_related_models_batch(
182+
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections")
183+
) -> list[str]:
184+
"""
185+
Accepts multiple model keys and returns a flat list of all unique related keys.
186+
187+
Useful when working with multiple selections in the UI or cross-model comparisons.
188+
"""
189+
try:
190+
all_related: set[str] = set()
191+
for key in req.model_keys:
192+
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
193+
all_related.update(related)
194+
return list(all_related)
195+
except Exception as e:
196+
raise HTTPException(status_code=500, detail=str(e))

invokeai/app/api_app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
download_queue,
2323
images,
2424
model_manager,
25+
model_relationships,
2526
session_queue,
2627
style_presets,
2728
utilities,
@@ -125,6 +126,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
125126
app.include_router(images.images_router, prefix="/api")
126127
app.include_router(boards.boards_router, prefix="/api")
127128
app.include_router(board_images.board_images_router, prefix="/api")
129+
app.include_router(model_relationships.model_relationships_router, prefix="/api")
128130
app.include_router(app_info.app_router, prefix="/api")
129131
app.include_router(session_queue.session_queue_router, prefix="/api")
130132
app.include_router(workflows.workflows_router, prefix="/api")

invokeai/app/services/invocation_services.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
2828
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
2929
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
30+
from invokeai.app.services.model_relationship_records.model_relationship_records_base import ModelRelationshipRecordStorageBase
31+
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
3032
from invokeai.app.services.names.names_base import NameServiceBase
3133
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
3234
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@@ -54,6 +56,8 @@ def __init__(
5456
logger: "Logger",
5557
model_images: "ModelImageFileStorageBase",
5658
model_manager: "ModelManagerServiceBase",
59+
model_relationships: "ModelRelationshipsServiceABC",
60+
model_relationship_records: "ModelRelationshipRecordStorageBase",
5761
download_queue: "DownloadQueueServiceBase",
5862
performance_statistics: "InvocationStatsServiceBase",
5963
session_queue: "SessionQueueBase",
@@ -81,6 +85,8 @@ def __init__(
8185
self.logger = logger
8286
self.model_images = model_images
8387
self.model_manager = model_manager
88+
self.model_relationships = model_relationships
89+
self.model_relationship_records = model_relationship_records
8490
self.download_queue = download_queue
8591
self.performance_statistics = performance_statistics
8692
self.session_queue = session_queue
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING
3+
4+
if TYPE_CHECKING:
5+
from invokeai.backend.model_manager.config import AnyModelConfig
6+
7+
class ModelRelationshipRecordStorageBase(ABC):
8+
"""Abstract base class for model-to-model relationship record storage."""
9+
10+
@abstractmethod
11+
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
12+
"""Creates a relationship between two models by keys."""
13+
pass
14+
15+
@abstractmethod
16+
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
17+
"""Removes a relationship between two models by keys."""
18+
pass
19+
20+
@abstractmethod
21+
def get_related_model_keys(self, model_key: str) -> list[str]:
22+
"""Gets all models keys related to a given model key."""
23+
pass
24+
25+
@abstractmethod
26+
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
27+
"""Get related model keys for multiple models given a list of keys."""
28+
pass
29+
30+
@abstractmethod
31+
def get_related_model_key_count(self, model_key: str) -> int:
32+
"""Gets the number of relations for a given model key."""
33+
pass
34+
35+
""" Below are methods that use ModelConfigs instead of model keys, as convenience methods.
36+
These methods are not required to be implemented, but they are potentially useful for later development.
37+
They are not used in the current codebase."""
38+
39+
@abstractmethod
40+
def add_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None:
41+
"""Creates a relationship between two models using ModelConfigs."""
42+
pass
43+
44+
@abstractmethod
45+
def remove_relationship_from_models(self, model_1: "AnyModelConfig", model_2: "AnyModelConfig") -> None:
46+
"""Removes a relationship between two models using ModelConfigs."""
47+
pass
48+
49+
@abstractmethod
50+
def get_related_keys_from_model(self, model: "AnyModelConfig") -> list[str]:
51+
"""Gets all model keys related to a given model using it's config."""
52+
pass
53+
54+
@abstractmethod
55+
def get_related_model_key_count_from_model(self, model: "AnyModelConfig") -> int:
56+
"""Gets the number of relations for a given model config."""
57+
pass

0 commit comments

Comments
 (0)