Commit 18fb903

JoanFM, jina-bot, and alexcg1 authored
feat: Flow compatible with docarray v2 (jina-ai#5861)
Signed-off-by: Joan Fontanals Martinez <[email protected]>
Co-authored-by: Jina Dev Bot <[email protected]>
Co-authored-by: Alex Cureton-Griffiths <[email protected]>
1 parent dbbecad commit 18fb903

29 files changed: +2240 −525 lines

.github/workflows/cd.yml (+1)

```diff
@@ -140,6 +140,7 @@ jobs:
       - name: Test
         id: test
         run: |
+          pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/unit/serve/runtimes/test_helper.py
           pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2
           pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/deployment_http_composite
           echo "flag it as jina for codeoverage"
```

.github/workflows/ci.yml (+1)

```diff
@@ -421,6 +421,7 @@ jobs:
       - name: Test
         id: test
         run: |
+          pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/unit/serve/runtimes/test_helper.py
           pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2
           pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/deployment_http_composite
           echo "flag it as jina for codeoverage"
```

docs/docarray-v2.md (+93 −10)

`````diff
@@ -92,9 +92,99 @@ If there is no `request_schema` and `response_schema`, the type hint is used to
 and `response_schema` will be used.
 
 
+## Serve one Executor in a Deployment
+
+Once you have defined the Executor with the new Executor API, you can easily serve and scale it as a Deployment with `gRPC`, `HTTP` or any combination of these
+protocols.
+
+
+```{code-block} python
+from jina import Deployment
+
+with Deployment(uses=MyExec, protocol='grpc', replicas=2) as dep:
+    dep.block()
+```
+
+
+## Chain Executors in a Flow with different schemas
+
+With the new API, when building a Flow you should ensure that the Document types used as input of an Executor match the output
+schema of the previous Executor in the Flow.
+
+The following tabs show a valid Flow and an invalid one that fails to start because the Document types are wrongly chained.
+
+````{tab} Valid Flow
+```{code-block} python
+from jina import Executor, requests, Flow
+from docarray import DocList, BaseDoc
+from docarray.typing import NdArray
+import numpy as np
+
+
+class SimpleStrDoc(BaseDoc):
+    text: str
+
+class TextWithEmbedding(SimpleStrDoc):
+    embedding: NdArray
+
+class TextEmbeddingExecutor(Executor):
+    @requests(on='/foo')
+    def foo(self, docs: DocList[SimpleStrDoc], **kwargs) -> DocList[TextWithEmbedding]:
+        ret = DocList[TextWithEmbedding]()
+        for doc in docs:
+            ret.append(TextWithEmbedding(text=doc.text, embedding=np.random.rand(10)))
+        return ret
+
+class ProcessEmbedding(Executor):
+    @requests(on='/foo')
+    def foo(self, docs: DocList[TextWithEmbedding], **kwargs) -> DocList[TextWithEmbedding]:
+        for doc in docs:
+            self.logger.info(f'Getting embedding with shape {doc.embedding.shape}')
+
+flow = Flow().add(uses=TextEmbeddingExecutor, name='embed').add(uses=ProcessEmbedding, name='process')
+with flow:
+    flow.block()
+```
+````
+````{tab} Invalid Flow
+```{code-block} python
+from jina import Executor, requests, Flow
+from docarray import DocList, BaseDoc
+from docarray.typing import NdArray
+import numpy as np
+
+
+class SimpleStrDoc(BaseDoc):
+    text: str
+
+class TextWithEmbedding(SimpleStrDoc):
+    embedding: NdArray
+
+class TextEmbeddingExecutor(Executor):
+    @requests(on='/foo')
+    def foo(self, docs: DocList[SimpleStrDoc], **kwargs) -> DocList[TextWithEmbedding]:
+        ret = DocList[TextWithEmbedding]()
+        for doc in docs:
+            ret.append(TextWithEmbedding(text=doc.text, embedding=np.random.rand(10)))
+        return ret
+
+class ProcessText(Executor):
+    @requests(on='/foo')
+    def foo(self, docs: DocList[SimpleStrDoc], **kwargs) -> DocList[TextWithEmbedding]:
+        for doc in docs:
+            self.logger.info(f'Getting text {doc.text}')
+
+# This Flow will fail to start because the input type of "process" does not match the output type of "embed"
+flow = Flow().add(uses=TextEmbeddingExecutor, name='embed').add(uses=ProcessText, name='process')
+with flow:
+    flow.block()
+```
+````
+
+
 ## Client API
 
-In the client, you similarly specify the schema that you expect the Flow to return. You can pass the return type by using the `return_type` parameter in the `client.post` method:
+Similarly, in the client, you specify the schema that you expect the Deployment or Flow to return. You can pass the return type by using the `return_type` parameter in the `client.post` method:
 
 ```{code-block} python
 ---
@@ -117,20 +207,13 @@ with Deployment(uses=MyExec) as dep:
 
 Jina is working to offer full compatibility with the new DocArray version.
 
-At present, these features are supported if you use the APIs described in the previous sections.
-
-- All the features offered by {ref}`Deployment <deployment>` where a single Executor is served and {ref}`scaled <scale-out>`. This includes both
-HTTP and gRPC protocols, and both of them at the same type.
-
-- When combining multiple Deployments in a pipeline using a {ref}`Flow <flow-cookbook>`, there are currently several limitations:
+However, there are currently some limitations to consider.
 
-  - Only gRPC protocol is supported.
-  - Only linear Flows are supported, no topologies using bifurcations can be used at the moment.
 
 ````{admonition} Note
 :class: note
 
-With DocArray 0.30 support, Jina introduced the concept of input/output schema at the Executor level. In order to chain multiple Executor into a Flow you always need to make sure that the output schema of an Executor is the same as the Input of the Executor that follows him in the Flow
+With DocArray 0.30 support, Jina introduced the concept of input/output schema at the Executor level. To chain multiple Executors into a Flow you need to ensure that the output schema of an Executor is the same as the input schema of the Executor that follows it in the Flow.
 ```
 
 ````{admonition} Note
`````
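For context, here is a minimal client-side sketch of calling the Flow from the chaining example above, using the `return_type` parameter the docs describe; the port is illustrative and the schemas are assumed to match the example:

```python
from jina import Client
from docarray import BaseDoc, DocList
from docarray.typing import NdArray


class SimpleStrDoc(BaseDoc):
    text: str

class TextWithEmbedding(SimpleStrDoc):
    embedding: NdArray


# assumes the Flow from the docs example is serving on this port
client = Client(port=12345)
result = client.post(
    on='/foo',
    inputs=DocList[SimpleStrDoc]([SimpleStrDoc(text='hello')]),
    return_type=DocList[TextWithEmbedding],
)
print(result[0].embedding.shape)
```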

jina/clients/request/helper.py (+6 −2)

```diff
@@ -6,7 +6,7 @@
 from jina.types.request.data import DataRequest
 
 if docarray_v2:
-    from docarray import DocList
+    from docarray import DocList, BaseDoc
 
 
 def _new_data_request_from_batch(
@@ -80,7 +80,11 @@ def _add_docs(req: DataRequest, batch, data_type: DataInputType) -> None:
     if not docarray_v2:
         da = DocumentArray([])
     else:
-        da = DocList[batch[0].__class__]()
+        if len(batch) > 0:
+            da = DocList[batch[0].__class__]()
+        else:
+            da = DocList[BaseDoc]()
+
     for content in batch:
         d, data_type = _new_doc_from_data(content, data_type)
         da.append(d)
```
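The branch on `len(batch)` matters because parametrizing `DocList` with `batch[0].__class__` raises an `IndexError` on an empty batch. A minimal sketch of the fallback pattern (the helper name here is illustrative, not part of the codebase):

```python
from docarray import BaseDoc, DocList


def empty_safe_doclist(batch: list) -> DocList:
    # Parametrize by the concrete document type when the batch has at
    # least one element; otherwise fall back to the generic BaseDoc schema.
    if len(batch) > 0:
        return DocList[batch[0].__class__]()
    return DocList[BaseDoc]()


assert len(empty_safe_doclist([])) == 0  # no IndexError on an empty batch
```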

jina/serve/consensus/run.go (−1)

```diff
@@ -359,7 +359,6 @@ func run(self *C.PyObject, args *C.PyObject, kwargs *C.PyObject) *C.PyObject {
 
 //export add_voter
 func add_voter(self *C.PyObject, args *C.PyObject) *C.PyObject {
-	//TODO: Instantiate new logger based on JINA_LOG_LEVEL
 	logLevel := os.Getenv("JINA_LOG_LEVEL")
 	if logLevel == "" {
 		logLevel = "INFO"
```

jina/serve/networking/__init__.py (+8 −14)

```diff
@@ -203,15 +203,15 @@ def send_discover_endpoint(
         shard_id: Optional[int] = None,
         timeout: Optional[float] = None,
         retries: Optional[int] = -1,
-    ) -> Optional[asyncio.Task]:
+    ):
         """Sends a discover Endpoint call to target.
 
         :param deployment: name of the Jina deployment to send the request to
         :param head: If True it is send to the head, otherwise to the worker pods
         :param shard_id: Send to a specific shard of the deployment, ignored for polling ALL
         :param timeout: timeout for sending the requests
        :param retries: number of retries per gRPC call. If <0 it defaults to max(3, num_replicas)
-        :return: asyncio.Task items to send call
+        :return: coroutine items to send call
        """
         connection_list = self._connections.get_replicas(
             deployment, head, shard_id, True
@@ -378,12 +378,9 @@ async def _handle_aiorpcerror(
                 details=error.details(),
             )
         else:
-            if error.code() == grpc.StatusCode.UNAVAILABLE and 'not the leader' in error.details():
-                self._logger.debug(f'RAFT node of {current_deployment} is not the leader. Trying next replica, if available.')
-            else:
-                self._logger.debug(
-                    f'gRPC call to deployment {current_deployment} failed with error {format_grpc_error(error)}, for retry attempt {retry_i + 1}/{total_num_tries - 1}.'
-                    f' Trying next replica, if available.'
+            if connection_list:
+                await connection_list.reset_connection(
+                    current_address, current_deployment
                 )
         return None
 
@@ -460,11 +457,10 @@ def _send_discover_endpoint(
         connection_list: _ReplicaList,
         timeout: Optional[float] = None,
         retries: Optional[int] = -1,
-    ) -> asyncio.Task:
+    ):
         # this wraps the awaitable object from grpc as a coroutine so it can be used as a task
         # the grpc call function is not a coroutine but some _AioCall
-        async def task_wrapper():
-
+        async def task_coroutine():
             tried_addresses = set()
             if retries is None or retries < 0:
                 total_num_tries = (
@@ -500,7 +496,7 @@ async def task_wrapper():
             except AttributeError:
                 return default_endpoints_proto, None
 
-        return asyncio.create_task(task_wrapper())
+        return task_coroutine()
 
     async def warmup(
         self,
@@ -557,7 +553,6 @@ async def task_wrapper(target_warmup_responses, stub):
                     for task in tasks:
                         task.cancel()
                     raise
-
                 except Exception as ex:
                     self._logger.error(f'error with warmup up task: {ex}')
                     return
@@ -568,5 +563,4 @@ def _get_all_replicas(self, deployment):
             replica_set.add(
                 self._connections.get_replicas(deployment=deployment, head=True)
             )
-
         return set(filter(None, replica_set))
```
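The rename from `task_wrapper` to `task_coroutine` reflects a behavioural change: `_send_discover_endpoint` now returns a bare coroutine instead of scheduling it with `asyncio.create_task`, so the gRPC call only runs when the caller awaits it. A minimal sketch of the difference (the function names here are illustrative):

```python
import asyncio


async def discover() -> str:
    # stands in for the wrapped gRPC endpoint-discovery call
    await asyncio.sleep(0)
    return 'endpoints'


async def main():
    # eager: create_task schedules the coroutine on the event loop immediately
    task = asyncio.create_task(discover())
    # lazy: a bare coroutine does nothing until it is awaited
    coro = discover()
    print(await task, await coro)


asyncio.run(main())
```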

jina/serve/networking/replica_list.py (+33 −24)

```diff
@@ -23,15 +23,15 @@ class _ReplicaList:
     """
 
     def __init__(
-            self,
-            metrics: _NetworkingMetrics,
-            histograms: _NetworkingHistograms,
-            logger,
-            runtime_name: str,
-            aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None,
-            tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None,
-            deployment_name: str = '',
-            channel_options: Optional[Union[list, Dict[str, Any]]] = None,
+        self,
+        metrics: _NetworkingMetrics,
+        histograms: _NetworkingHistograms,
+        logger,
+        runtime_name: str,
+        aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None,
+        tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None,
+        deployment_name: str = '',
+        channel_options: Optional[Union[list, Dict[str, Any]]] = None,
     ):
         self.runtime_name = runtime_name
         self._connections = []
@@ -59,20 +59,21 @@ async def reset_connection(self, address: str, deployment_name: str):
         :param deployment_name: Target deployment of this connection
         """
         self._logger.debug(f'resetting connection for {deployment_name} to {address}')
-
+        parsed_address = urlparse(address)
+        resolved_address = parsed_address.netloc if parsed_address.netloc else address
         if (
-            address in self._address_to_connection_idx
-            and self._address_to_connection_idx[address] is not None
+            resolved_address in self._address_to_connection_idx
+            and self._address_to_connection_idx[resolved_address] is not None
         ):
             # remove connection:
             # in contrast to remove_connection(), we don't 'shorten' the data structures below, instead
             # update the data structure with the new connection and let the old connection be colleced by
             # the GC
-            id_to_reset = self._address_to_connection_idx[address]
+            id_to_reset = self._address_to_connection_idx[resolved_address]
             # re-add connection:
-            self._address_to_connection_idx[address] = id_to_reset
+            self._address_to_connection_idx[resolved_address] = id_to_reset
             stubs, channel = self._create_connection(address, deployment_name)
-            self._address_to_channel[address] = channel
+            self._address_to_channel[resolved_address] = channel
             self._connections[id_to_reset] = stubs
 
     def add_connection(self, address: str, deployment_name: str):
@@ -81,10 +82,13 @@ def add_connection(self, address: str, deployment_name: str):
         :param address: Target address of this connection
         :param deployment_name: Target deployment of this connection
         """
-        if address not in self._address_to_connection_idx:
-            self._address_to_connection_idx[address] = len(self._connections)
+        parsed_address = urlparse(address)
+        resolved_address = parsed_address.netloc if parsed_address.netloc else address
+
+        if resolved_address not in self._address_to_connection_idx:
+            self._address_to_connection_idx[resolved_address] = len(self._connections)
             stubs, channel = self._create_connection(address, deployment_name)
-            self._address_to_channel[address] = channel
+            self._address_to_channel[resolved_address] = channel
             self._connections.append(stubs)
             # create a new set of stubs and channels for warmup to avoid
             # loosing channel during remove_connection or reset_connection
@@ -103,20 +107,23 @@ async def remove_connection(self, address: str):
 
         :param address: Remove connection for this address
         """
-        if address in self._address_to_connection_idx:
+        parsed_address = urlparse(address)
+        resolved_address = parsed_address.netloc if parsed_address.netloc else address
+        if resolved_address in self._address_to_connection_idx:
             self._rr_counter = (
                 self._rr_counter % (len(self._connections) - 1)
                 if (len(self._connections) - 1)
                 else 0
             )
-            idx_to_delete = self._address_to_connection_idx.pop(address)
+            idx_to_delete = self._address_to_connection_idx.pop(resolved_address)
             self._connections.pop(idx_to_delete)
             # update the address/idx mapping
-            for address in self._address_to_connection_idx:
-                if self._address_to_connection_idx[address] > idx_to_delete:
-                    self._address_to_connection_idx[address] -= 1
+            for a in self._address_to_connection_idx:
+                if self._address_to_connection_idx[a] > idx_to_delete:
+                    self._address_to_connection_idx[a] -= 1
 
     def _create_connection(self, address, deployment_name: str):
+        self._logger.debug(f'create_connection connection for {deployment_name} to {address}')
         parsed_address = urlparse(address)
         address = parsed_address.netloc if parsed_address.netloc else address
         use_tls = parsed_address.scheme in TLS_PROTOCOL_SCHEMES
@@ -185,7 +192,9 @@ def has_connection(self, address: str) -> bool:
         :param address: The address to check
         :returns: True if a connection for the ip exists in the list
         """
-        return address in self._address_to_connection_idx
+        parsed_address = urlparse(address)
+        resolved_address = parsed_address.netloc if parsed_address.netloc else address
+        return resolved_address in self._address_to_connection_idx
 
     def has_connections(self) -> bool:
         """
```
