|
5 | 5 | import numpy as np
|
6 | 6 | from docarray import BaseDoc, DocList
|
7 | 7 | from docarray.documents import ImageDoc
|
8 |
| -from docarray.typing import AnyTensor, ImageUrl |
| 8 | +from docarray.typing import AnyTensor, ImageUrl, NdArray |
9 | 9 | from docarray.documents import TextDoc
|
10 | 10 | from docarray.documents.legacy import LegacyDocument
|
11 | 11 | from jina.helper import random_port
|
@@ -292,7 +292,6 @@ def foo(self, docs: DocList[ProcessingTestDocConditions], **kwargs) -> DocList[P
|
292 | 292 | fp.write(doc.text)
|
293 | 293 | doc.text += f' processed by {self.metas.name}'
|
294 | 294 |
|
295 |
| - |
296 | 295 | class FirstExec(Executor):
|
297 | 296 | @requests
|
298 | 297 | def foo(self, docs: DocList[LegacyDocument], **kwargs) -> DocList[ProcessingTestDocConditions]:
|
@@ -462,6 +461,32 @@ def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[TextDoc]:
|
462 | 461 | assert len(ret) == 0
|
463 | 462 |
|
464 | 463 |
|
| 464 | +@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) |
| 465 | +@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow']) |
| 466 | +def test_input_output_with_shaped_tensor(protocol, ctxt_manager): |
| 467 | + if ctxt_manager == 'deployment' and protocol == 'websocket': |
| 468 | + return |
| 469 | + |
| 470 | + class MyDoc(BaseDoc): |
| 471 | + text: str |
| 472 | + embedding: NdArray[128] |
| 473 | + |
| 474 | + class Foo(Executor): |
| 475 | + @requests(on='/hello') |
| 476 | + def foo(self, docs: DocList[MyDoc], **kwargs) -> DocList[MyDoc]: |
| 477 | + for doc in docs: |
| 478 | + doc.text += 'Processed by foo' |
| 479 | + |
| 480 | + if ctxt_manager == 'flow': |
| 481 | + ctxt_mgr = Flow(protocol=protocol).add(uses=Foo) |
| 482 | + else: |
| 483 | + ctxt_mgr = Deployment(protocol=protocol, uses=Foo) |
| 484 | + |
| 485 | + with ctxt_mgr: |
| 486 | + ret = ctxt_mgr.post(on='/hello', inputs=DocList[MyDoc]([MyDoc(text='', embedding=np.random.rand(128))])) |
| 487 | + assert len(ret) == 1 |
| 488 | + |
| 489 | + |
465 | 490 | @pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket'])
|
466 | 491 | @pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow'])
|
467 | 492 | def test_send_parameters(protocol, ctxt_manager):
|
@@ -651,14 +676,14 @@ class Previous(Executor):
|
651 | 676 | def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[TextDoc]:
|
652 | 677 | pass
|
653 | 678 |
|
654 |
| - f = Flow(protocol=protocol).add(uses=Previous, name='previous').add(uses=First, name='first', needs='previous').add(uses=Second, name='second', needs='previous').needs_all() |
| 679 | + f = Flow(protocol=protocol).add(uses=Previous, name='previous').add(uses=First, name='first', needs='previous').add( |
| 680 | + uses=Second, name='second', needs='previous').needs_all() |
655 | 681 |
|
656 | 682 | with pytest.raises(RuntimeFailToStart):
|
657 | 683 | with f:
|
658 | 684 | pass
|
659 | 685 |
|
660 | 686 |
|
661 |
| - |
662 | 687 | class ExternalDeploymentDoc(BaseDoc):
|
663 | 688 | tags: Dict[str, str] = {}
|
664 | 689 |
|
|
0 commit comments