Skip to content

Commit 91bfa76

Browse files
authored
fix: fix isse in head missmatch endpoint (jina-ai#5904)
1 parent 5574aa8 commit 91bfa76

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

jina/serve/runtimes/head/request_handling.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import grpc
1010

1111
from jina.enums import PollingType
12+
from jina.constants import __default_endpoint__
1213
from jina.excepts import InternalNetworkError
1314
from jina.helper import get_full_version
1415
from jina.proto import jina_pb2
@@ -289,7 +290,10 @@ async def _handle_data_request(
289290
response_request = worker_results[0]
290291
found = False
291292
if docarray_v2:
292-
model = self._pydantic_models_by_endpoint[endpoint]['output']
293+
check_endpoint = endpoint
294+
if endpoint not in self._pydantic_models_by_endpoint:
295+
check_endpoint = __default_endpoint__
296+
model = self._pydantic_models_by_endpoint[check_endpoint]['output']
293297
for i, worker_result in enumerate(worker_results):
294298
if docarray_v2:
295299
worker_result.document_array_cls = DocList[model]

tests/integration/docarray_v2/test_v2.py

+27
Original file line numberDiff line numberDiff line change
@@ -987,3 +987,30 @@ def search(self, docs: DocList[TextDocWithId], **kwargs) -> DocList[ResultTestDo
987987
assert len(r.matches) == 6
988988
for match in r.matches:
989989
assert 'ID' in match.text
990+
991+
992+
def test_issue_shards_missmatch_endpoint():
993+
994+
class MyDoc(BaseDoc):
995+
text: str
996+
embedding: NdArray[128]
997+
998+
class MyDocWithMatchesAndScores(MyDoc):
999+
matches: DocList[MyDoc]
1000+
scores: List[float]
1001+
1002+
class MyExec(Executor):
1003+
1004+
@requests
1005+
def foo(self, docs: DocList[MyDoc], **kwargs) -> DocList[MyDocWithMatchesAndScores]:
1006+
res = DocList[MyDocWithMatchesAndScores]()
1007+
for doc in docs:
1008+
new_doc = MyDocWithMatchesAndScores(text=doc.text, embedding=doc.embedding, matches=docs,
1009+
scores=[1.0 for _ in docs])
1010+
res.append(new_doc)
1011+
return res
1012+
1013+
d = Deployment(uses=MyExec, shards=2)
1014+
with d:
1015+
res = d.post(on='/', inputs=DocList[MyDoc]([MyDoc(text='hey ha', embedding=np.random.rand(128))]))
1016+
assert len(res) == 1

0 commit comments

Comments
 (0)