Skip to content

Commit 01f8631

Browse files
JoanFMjina-bot
andauthored
test: add test with shaped ndarray (jina-ai#5900)
Co-authored-by: Jina Dev Bot <[email protected]>
1 parent b5c8b64 commit 01f8631

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

jina_cli/autocomplete.py

-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,6 @@
247247
'cloud scale': ['--help', '--executor', '--replicas'],
248248
'cloud recreate': ['--help'],
249249
'cloud logs': ['--help', '--gateway', '--executor'],
250-
'cloud survey': ['--help'],
251250
'cloud': [
252251
'--help',
253252
'--version',
@@ -267,7 +266,6 @@
267266
'scale',
268267
'recreate',
269268
'logs',
270-
'survey',
271269
],
272270
'help': ['--help'],
273271
'pod': [

tests/integration/docarray_v2/test_v2.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from docarray import BaseDoc, DocList
77
from docarray.documents import ImageDoc
8-
from docarray.typing import AnyTensor, ImageUrl
8+
from docarray.typing import AnyTensor, ImageUrl, NdArray
99
from docarray.documents import TextDoc
1010
from docarray.documents.legacy import LegacyDocument
1111
from jina.helper import random_port
@@ -292,7 +292,6 @@ def foo(self, docs: DocList[ProcessingTestDocConditions], **kwargs) -> DocList[P
292292
fp.write(doc.text)
293293
doc.text += f' processed by {self.metas.name}'
294294

295-
296295
class FirstExec(Executor):
297296
@requests
298297
def foo(self, docs: DocList[LegacyDocument], **kwargs) -> DocList[ProcessingTestDocConditions]:
@@ -462,6 +461,32 @@ def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[TextDoc]:
462461
assert len(ret) == 0
463462

464463

464+
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
465+
@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow'])
466+
def test_input_output_with_shaped_tensor(protocol, ctxt_manager):
467+
if ctxt_manager == 'deployment' and protocol == 'websocket':
468+
return
469+
470+
class MyDoc(BaseDoc):
471+
text: str
472+
embedding: NdArray[128]
473+
474+
class Foo(Executor):
475+
@requests(on='/hello')
476+
def foo(self, docs: DocList[MyDoc], **kwargs) -> DocList[MyDoc]:
477+
for doc in docs:
478+
doc.text += 'Processed by foo'
479+
480+
if ctxt_manager == 'flow':
481+
ctxt_mgr = Flow(protocol=protocol).add(uses=Foo)
482+
else:
483+
ctxt_mgr = Deployment(protocol=protocol, uses=Foo)
484+
485+
with ctxt_mgr:
486+
ret = ctxt_mgr.post(on='/hello', inputs=DocList[MyDoc]([MyDoc(text='', embedding=np.random.rand(128))]))
487+
assert len(ret) == 1
488+
489+
465490
@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
466491
@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow'])
467492
def test_send_parameters(protocol, ctxt_manager):
@@ -651,14 +676,14 @@ class Previous(Executor):
651676
def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[TextDoc]:
652677
pass
653678

654-
f = Flow(protocol=protocol).add(uses=Previous, name='previous').add(uses=First, name='first', needs='previous').add(uses=Second, name='second', needs='previous').needs_all()
679+
f = Flow(protocol=protocol).add(uses=Previous, name='previous').add(uses=First, name='first', needs='previous').add(
680+
uses=Second, name='second', needs='previous').needs_all()
655681

656682
with pytest.raises(RuntimeFailToStart):
657683
with f:
658684
pass
659685

660686

661-
662687
class ExternalDeploymentDoc(BaseDoc):
663688
tags: Dict[str, str] = {}
664689

0 commit comments

Comments
 (0)