From 788788a734682542d07e94dc54cde26020520505 Mon Sep 17 00:00:00 2001 From: Joey Teng Date: Mon, 12 Jun 2023 10:00:42 +0100 Subject: [PATCH] 0.3.0: Performance Improvement (#2) * refactor(jit): make all jitted functions "inline=True" This should be more beneficial, as followed by the discussions in Jax repository, see google/jax#6584 google/jax#6681 google/jax#9298 google/jax#9342 * perf(pipeline): try to render 4/2/1 rows per batch using vmap to reduce fori_loop iterations * feat(_meta_utils): simple way to add multiple trace annotations together for functions add `@ad_tracing_name` to most functions to assist profiling also bump to Python 3.10 BREAKING CHANGE: Now requires Python 3.10 * perf(pipeline): big refactor to not updating per rows, but renders all rows then concat and merge * perf(pipeline): using scan + unroll (equiv map + unroll) This is very similar to map + vmap (minibatch processing) as the inner loop is too complex * build(pyproject): try to relax jax{lib}'s verion constraint to ">=0.3.25,<5.0.0" * test(pre-gen-brax): example inputs for profiling * perf: try to eliminate all `lax.cond` under `vmap`, `lax.cond` are lowered to `select_n` in HLO which leads to execution in both branches, thus fails to 1) save computation when possible; 2) prevent unexpected values to be produced/unexpected branches to be executed (defensive), thus let the non-dummy branch to be executed anyway and only rule-out garbage value at the final stage all together to try to improve performance. See google/brax#8409 for more details about unconditional executation of cond under vmap * fix(pipeline): gl_FrontFacing: fix its determination in pipeline `True` if NOT back-facing * perf: added extra stage in pipeline, aiming to interpolate and shade only one fragment per pixel * docs(changelog): expose option `loop_unroll`; dependency version change Bump minimum Python version from 3.9 to 3.10; lower minimum jax & jaxlib to 0.3.25. * build(pyproject): bump to 0.3.0 --- changelog.md | 8 + pyproject.toml | 8 +- renderer/README.md | 12 +- renderer/_meta_utils.py | 25 ++ renderer/geometry.py | 92 +++-- renderer/model.py | 15 +- renderer/pipeline.py | 361 ++++++++++++-------- renderer/renderer.py | 26 +- renderer/shader.py | 134 ++++++-- renderer/shaders/depth.py | 5 +- renderer/shaders/gouraud.py | 8 +- renderer/shaders/gouraud_texture.py | 11 +- renderer/shaders/phong.py | 11 +- renderer/shaders/phong_darboux.py | 14 +- renderer/shaders/phong_reflection.py | 20 +- renderer/shaders/phong_reflection_shadow.py | 20 +- renderer/shadow.py | 35 +- renderer/utils.py | 12 +- test_resources/pre-gen-brax/README.md | 6 + test_resources/pre-gen-brax/inputs-2.zip | Bin 0 -> 87861 bytes test_resources/pre-gen-brax/inputs-30.zip | Bin 0 -> 221985 bytes 21 files changed, 582 insertions(+), 241 deletions(-) create mode 100644 renderer/_meta_utils.py create mode 100644 test_resources/pre-gen-brax/README.md create mode 100644 test_resources/pre-gen-brax/inputs-2.zip create mode 100644 test_resources/pre-gen-brax/inputs-30.zip diff --git a/changelog.md b/changelog.md index fbcbe20..1652f21 100644 --- a/changelog.md +++ b/changelog.md @@ -49,3 +49,11 @@ 1. Refactor `Scene.set_object_*` methods to be a simple wrapper of `self._replace` and `ModelObject.replace_with_*`, to expose APIs of `ModelObject`s and allows manipulation and rendering without `Scene`. 2. Expose `create_capsule` and `create_cube` APIs. + +## 0.3.0 + +1. Fix `gl_FrontFacing` computation in pipeline so it is consistent to comment: `True` if not backfacing (i.e. frontfacing & side facing). +2. Add an extra stage `Shader.primitive_chooser` to choose which primitive to be rendered for each fragment. The default implementation is provided, which assumes that the depth is just the interpolated `z` value in the eye space. It just picks the values of the single primitive that is closest to the camera and is not discarded in the previous pipeline. +3. Expose `loop_unroll` static option to allow unrolling several operations (row rendering) within a single iteration of the outmost loop (iterating along first axis of the canvas). This may be useful in some cases for performance improvement, but careful benchmarking is needed to determine the optimal value. The default value is `1` (no unrolling) as it is the most general case in larger canvases (benchmarked on `960x540` using [GPU T4 in Colab](https://colab.research.google.com/drive/1xhkYNz5WjvUCjQWpp72CLf9SIy3i5PnN)). +4. Bump the minimum Python version to Python 3.10 +5. Lower the minimum jax & jaxlib version to 0.3.25. diff --git a/pyproject.toml b/pyproject.toml index 1dd42c0..6f06b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "jaxrenderer" -version = "0.2.1" +version = "0.3.0" description = "Jax implementation of rasterizer renderer." authors = ["Joey Teng "] license = "Apache-2.0" @@ -32,10 +32,10 @@ include = [ [tool.poetry.dependencies] -python = "^3.9" -jax = "^0.4.4" +python = "^3.10" +jax = ">=0.3.25,<5.0.0" numpy = "^1.22.0" -jaxlib = {version = "^0.4.4", source = "jax"} +jaxlib = {version = ">=0.3.25,<5.0.0", source = "jax"} jaxtyping = "^0.2.19" importlib-metadata = "^6.6.0" diff --git a/renderer/README.md b/renderer/README.md index dca8a55..0d213ef 100644 --- a/renderer/README.md +++ b/renderer/README.md @@ -37,10 +37,18 @@ The fragments are generated for each position in the zbuffer, and if the screen The processing is done by iterating along the first axis of the buffer, and `vmap` along the second axis, and `vmap` along the primitives. Thus, all fragments at the same position are generated together, then mixed, then written to the buffer. -Interpolation of all attributes for each fragment is defined by `Shader.interpolate`. The default implementation is provided, which simply linearly interpolates the attributes in the clip space according to the barycentric coordinates of the fragment. This behaviour is wrapped as a convenient function `interpolate` and mode `Interpolation.SMOOTH`. The interpolated values are then passed to the fragment shader. - Currently no anti-aliasing strategy is supported. +#### Optional Early Depth Test + +This is an additional stage in this pipeline which may be analogical to the early depth test in OpenGL. It is implemented in `Shader.primitive_chooser`. The default implementation is provided, which assumes that the depth is just the interpolated `z` value in the eye space. It just picks the values of the single primitive that is closest to the camera and is not discarded in the previous pipeline. + +Custom implementations can overload to change this behaviour to achieve special effects like occlusion, transparency, etc. Note that the number of returned primitives must be static, i.e., same for all fragments, as required by `jax.jit`. + +#### Interpolation of Attributes + +Interpolation of all attributes for each fragment is defined by `Shader.interpolate`. The default implementation is provided, which simply linearly interpolates the attributes in the clip space according to the barycentric coordinates of the fragment. This behaviour is wrapped as a convenient function `interpolate` and mode `Interpolation.SMOOTH`. The interpolated values are then passed to the fragment shader. Currently only `Interpolation.SMOOTH` and `Interpolation.FLAT` are supported. + Reference: - For interpolation modes, see [Interpolation Qualifiers](https://www.khronos.org/opengl/wiki/Type_Qualifier_(GLSL)#Interpolation_qualifiers) diff --git a/renderer/_meta_utils.py b/renderer/_meta_utils.py new file mode 100644 index 0000000..ece334c --- /dev/null +++ b/renderer/_meta_utils.py @@ -0,0 +1,25 @@ +import functools +import inspect +from typing import Callable, ParamSpec, TypeVar + +import jax + +ArgT = ParamSpec("ArgT") +RetT = TypeVar("RetT") + + +def add_tracing_name(func: Callable[ArgT, RetT]) -> Callable[ArgT, RetT]: + """Add tracing name to function.""" + + members: dict[str, str] + members = dict(inspect.getmembers(func, lambda v: isinstance(v, str))) + annotation: str = (f"{members.get('__module__', '')}" + f":{members.get('__qualname__', '')}") + + @functools.wraps(func) + def wrapper(*args: ArgT.args, **kwargs: ArgT.kwargs) -> RetT: + with jax.named_scope(annotation): + with jax.profiler.TraceAnnotation(annotation): + return func(*args, **kwargs) + + return wrapper diff --git a/renderer/geometry.py b/renderer/geometry.py index a8c584f..99d0ae8 100644 --- a/renderer/geometry.py +++ b/renderer/geometry.py @@ -7,6 +7,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float, Integer, Num, jaxtyped +from ._meta_utils import add_tracing_name from .types import Triangle2Df, Vec2f, Vec3f, Vec4f # Transform matrix that takes a batch of homogeneous 3D vertices and transform @@ -28,7 +29,8 @@ @jaxtyped -@partial(jax.jit, donate_argnums=(0, )) +@partial(jax.jit, donate_argnums=(0, ), inline=True) +@add_tracing_name def normalise(vector: Float[Array, "dim"]) -> Float[Array, "dim"]: """normalise vector in-place.""" return vector / jnp.linalg.norm(vector) @@ -48,7 +50,8 @@ class Interpolation(enum.Enum): """Perspective correction: linear interpolation in clip space""" @jaxtyped - @partial(jax.jit, static_argnames=("self", )) + @partial(jax.jit, static_argnames=("self", ), inline=True) + @add_tracing_name def __call__( self, values: Num[Array, "3 *valueDimensions"], @@ -92,7 +95,8 @@ def __call__( @jaxtyped -@partial(jax.jit, static_argnames=("mode", )) +@partial(jax.jit, static_argnames=("mode", ), inline=True) +@add_tracing_name def interpolate( values: Num[Array, "3 *valueDimensions"], barycentric_screen: Vec3f, @@ -115,7 +119,8 @@ def interpolate( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def to_homogeneous( coordinates: Float[Array, "*batch dim"], value: Float[Array, "*batch"] = jnp.array(1.), @@ -138,7 +143,8 @@ def to_homogeneous( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def normalise_homogeneous( coordinates: Float[Array, "*batch dim"], ) -> Float[Array, "*batch dim"]: """Transform the homogenous coordinates to make the scale factor equals to @@ -152,7 +158,8 @@ def normalise_homogeneous( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def to_cartesian( coordinates: Float[Array, "*batch dim"], ) -> Float[Array, "*batch dim-1"]: """Transform the homogenous coordinates to cartesian coordinates by divide @@ -192,7 +199,8 @@ class Camera(NamedTuple): @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", )) + @partial(jax.jit, static_argnames=("cls", ), inline=True) + @add_tracing_name def create( cls, view: View, @@ -240,7 +248,8 @@ def create( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def apply( points: Num[Array, "*N 4"], matrix: Num[Array, "4 4"], @@ -275,7 +284,8 @@ def apply( @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", )) + @partial(jax.jit, static_argnames=("cls", ), inline=True) + @add_tracing_name def apply_pos( cls, points: Num[Array, "*N 3"], @@ -305,7 +315,8 @@ def apply_pos( @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", )) + @partial(jax.jit, static_argnames=("cls", ), inline=True) + @add_tracing_name def apply_vec( cls, vectors: Num[Array, "*N 3"], @@ -343,7 +354,8 @@ def apply_vec( return transformed_normalised @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def to_screen( self, points: Num[Array, "*N 4"], @@ -368,7 +380,8 @@ def to_screen( return normalised @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def to_clip( self, points: Num[Array, "*N 4"], @@ -390,7 +403,8 @@ def to_clip( return clip_space @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def to_screen_inv( self, screen: Float[Array, "*N 4"], @@ -442,7 +456,8 @@ def to_screen_inv( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def inv_scale_translation_matrix( scale_translation_mat: Float[Array, "4 4"]) -> Float[Array, "4 4"]: """Compute the inverse matrix of a (4, 4) matrix representing a scale and translation, in a form of: @@ -483,7 +498,8 @@ def inv_scale_translation_matrix( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def view_matrix( eye: Vec3f, centre: Vec3f, @@ -524,7 +540,8 @@ def view_matrix( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def view_matrix_inv( eye: Vec3f, centre: Vec3f, @@ -577,7 +594,8 @@ def view_matrix_inv( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def perspective_projection_matrix( fovy: jnp.floating[Any], aspect: jnp.floating[Any], @@ -621,7 +639,8 @@ def perspective_projection_matrix( @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", )) + @partial(jax.jit, static_argnames=("cls", ), inline=True) + @add_tracing_name def perspective_projection_matrix_inv(cls, mat: Projection) -> Projection: """Create the inverse of a perspective projection matrix as defined in `perspective_projection_matrix`. @@ -650,7 +669,8 @@ def perspective_projection_matrix_inv(cls, mat: Projection) -> Projection: @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def orthographic_projection_matrix( left: jnp.floating[Any], right: jnp.floating[Any], @@ -694,6 +714,7 @@ def orthographic_projection_matrix( @classmethod @jaxtyped @partial(jax.jit, static_argnames=("cls", )) + @add_tracing_name def orthographic_projection_matrix_inv(cls, mat: Projection) -> Projection: """Create the inverse of a orthographic projection matrix as defined in `orthographic_projection_matrix`. Since orthographic projection @@ -707,7 +728,8 @@ def orthographic_projection_matrix_inv(cls, mat: Projection) -> Projection: @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def perspective_projection_matrix_tinyrenderer( eye: Vec3f, centre: Vec3f, @@ -734,7 +756,8 @@ def perspective_projection_matrix_tinyrenderer( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def viewport_matrix( lowerbound: Num[Array, "2"], dimension: Integer[Array, "2"], @@ -766,7 +789,8 @@ def viewport_matrix( @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", )) + @partial(jax.jit, static_argnames=("cls", ), inline=True) + @add_tracing_name def viewport_matrix_inv(cls, viewport: Viewport) -> Viewport: """Create the inverse of a viewport matrix as defined in `viewport_matrix`. @@ -782,7 +806,8 @@ def viewport_matrix_inv(cls, viewport: Viewport) -> Viewport: @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def world_to_screen_matrix(width: int, height: int) -> World2Screen: """Generate the projection matrix to convert model coordinates to screen/canvas coordinates. @@ -804,7 +829,8 @@ def world_to_screen_matrix(width: int, height: int) -> World2Screen: @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def compute_normal(triangle_verts: Float[Array, "3 3"]) -> Float[Array, "3"]: normal: Float[Array, "3"] = jnp.cross( triangle_verts[2, :] - triangle_verts[0, :], @@ -817,13 +843,15 @@ def compute_normal(triangle_verts: Float[Array, "3 3"]) -> Float[Array, "3"]: @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def compute_normals(batch_verts: Float[Array, "b 3 3"]) -> Float[Array, "b 3"]: return jax.vmap(compute_normal)(batch_verts) @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def quaternion( rotation_axis: Union[Vec3f, tuple[float, float, float]], rotation_angle: Union[Float[Array, ""], float], @@ -850,7 +878,8 @@ def quaternion( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def quaternion_mul(quatA: Vec4f, quatB: Vec4f) -> Vec4f: """Multiply two quaternion rotations, as to composite them. @@ -877,7 +906,8 @@ def quaternion_mul(quatA: Vec4f, quatB: Vec4f) -> Vec4f: @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def rotation_matrix( rotation_axis: Union[Vec3f, tuple[float, float, float]], rotation_angle: Union[Float[Array, ""], float], @@ -908,7 +938,8 @@ def rotation_matrix( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]: """Generate a transform matrix from a quaternion rotation. @@ -936,7 +967,8 @@ def transform_matrix_from_rotation(rotation: Vec4f) -> Float[Array, "3 3"]: @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def barycentric(pts: Triangle2Df, p: Vec2f) -> Vec3f: """Compute the barycentric coordinate of `p`. Returns u[-1] < 0 if `p` is outside of the triangle. diff --git a/renderer/model.py b/renderer/model.py index 6ecd564..252922a 100644 --- a/renderer/model.py +++ b/renderer/model.py @@ -9,6 +9,7 @@ from jaxtyping import (Array, Bool, Float, Integer, Num, PyTree, Shaped, jaxtyped) +from ._meta_utils import add_tracing_name from .geometry import Camera, transform_matrix_from_rotation from .types import (FALSE_ARRAY, FaceIndices, Normals, SpecularMap, Texture, UVCoordinates, Vec3f, Vec4f, Vertices) @@ -162,6 +163,7 @@ class MergedModel(NamedTuple): @staticmethod @jaxtyped + @add_tracing_name def generate_object_vert_info( counts: Sequence[int], values: Sequence[Shaped[Array, "..."]], @@ -191,7 +193,8 @@ def generate_object_vert_info( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def merge_verts( vs: VertsT, fs: FaceIndicessT, @@ -219,7 +222,8 @@ def merge_verts( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def merge_maps(maps: MapsT) -> tuple[MapT, Shape2DT]: """Merge maps by concatenating them along the first axis. @@ -269,7 +273,8 @@ def merge_maps(maps: MapsT) -> tuple[MapT, Shape2DT]: @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def uv_repeat( uv: Float[Array, "2"], shape: Integer[Array, "2"], @@ -390,6 +395,7 @@ def replace_with_double_sided( return self._replace(double_sided=double_sided) +@add_tracing_name def batch_models(models: Sequence[MergedModel]) -> MergedModel: """Merge multiple MergedModel into one, with each field being a batch, with batch axis at 0. This is intended to facilitate `jax.vmap`. @@ -403,6 +409,7 @@ def batch_models(models: Sequence[MergedModel]) -> MergedModel: @jaxtyped +@add_tracing_name def merge_objects(objects: Sequence[ModelObject]) -> MergedModel: """Merge objects into a single model. @@ -442,6 +449,7 @@ def merge_objects(objects: Sequence[ModelObject]) -> MergedModel: @jaxtyped @partial(jax.jit, inline=True) + @add_tracing_name def transform_vert( verts: Float[Array, "N 3"], local_scaling: Vec3f, @@ -470,6 +478,7 @@ def transform_vert( @jaxtyped @partial(jax.jit, inline=True) + @add_tracing_name def transform_normals( normals: Float[Array, "N 3"], transform: ModelMatrix, diff --git a/renderer/pipeline.py b/renderer/pipeline.py index d47d58a..a638c84 100644 --- a/renderer/pipeline.py +++ b/renderer/pipeline.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, NamedTuple, TypeVar +from typing import Any, NamedTuple import jax import jax.lax as lax @@ -8,14 +8,18 @@ from jax.tree_util import tree_map from jaxtyping import Array, Bool, Float, Integer, Num, jaxtyped +from ._meta_utils import add_tracing_name from .geometry import Camera, Interpolation, Viewport, interpolate from .shader import (ID, MixedExtraT, MixerOutput, PerFragment, PerVertex, Shader, ShaderExtraInputT, VaryingT) -from .types import (FALSE_ARRAY, Buffers, FaceIndices, Triangle, Vec2f, Vec2i, - Vec3f, Vec4f) +from .types import (FALSE_ARRAY, Buffers, CanvasMask, FaceIndices, Triangle, + Vec2f, Vec2i, Vec3f, Vec4f, ZBuffer) jax.config.update('jax_array', True) +RowIndices = Integer[Array, "row_batches row_batch_size"] +"""Indices of the rows in the buffers to be processed in this batch.""" + class PerPrimitive(NamedTuple): """Input for each primitive, using outputs from Vertex Shader. @@ -45,7 +49,8 @@ class PerPrimitive(NamedTuple): @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", )) + @partial(jax.jit, static_argnames=("cls", ), inline=True) + @add_tracing_name def create(cls, per_vertex: PerVertex) -> "PerPrimitive": """per_vertex is batched with size 3 (3 vertices per triangle) in clip-space, not normalised. @@ -63,13 +68,10 @@ def create(cls, per_vertex: PerVertex) -> "PerPrimitive": # an arbitrary number for numerical stability keep: Bool[Array, ""] = lax.abs(determinant) > 1e-6 - mat_inv: Float[Array, "3 3"] = lax.cond( - # an arbitrary number for numerical stability - keep, - # may replace with custom implementation for higher precision - lambda: jnp.linalg.inv(matrix), - lambda: jnp.zeros((3, 3)), - ) + # although this may result in NaN or Inf when keep is False, + # it will be discarded later. + # Perf: Remove lax.cond to reduce extra operations `select_n` in HLO. + mat_inv: Float[Array, "3 3"] = jnp.linalg.inv(matrix) assert isinstance(mat_inv, Float[Array, "3 3"]) return cls( @@ -81,7 +83,13 @@ def create(cls, per_vertex: PerVertex) -> "PerPrimitive": @jaxtyped -@partial(jax.jit, static_argnames=("shader", ), donate_argnums=(1, )) +@partial( + jax.jit, + static_argnames=("shader", "loop_unroll"), + donate_argnums=(1, ), + inline=True, +) +@add_tracing_name def _postprocessing( shader: type[Shader[ShaderExtraInputT, VaryingT, MixedExtraT]], buffers: Buffers, @@ -89,28 +97,43 @@ def _postprocessing( varyings: VaryingT, extra: ShaderExtraInputT, viewport: Viewport, + loop_unroll: int, ) -> Buffers: with jax.ensure_compile_time_eval(): - # loop along first axis, for memory efficiency - # TODO: benchmark if this is actually faster - loop_size: int = int(buffers[0].shape[0]) # vmap batch along second axis batch_size: int = int(buffers[0].shape[1]) + row_indices: Integer[Array, "width"] + row_indices = lax.iota(int, int(buffers[0].shape[0])) @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def _per_pixel(coord: Vec2i) -> tuple[MixerOutput, MixedExtraT]: assert isinstance(coord, Vec2i), f"expected Vec2i, got {coord}" + ReturnT = tuple[ # + Float[Array, "kept_primitives 4"], # + Bool[Array, "kept_primitives"], # + Float[Array, "kept_primitives 2"], # + Bool[Array, "kept_primitives"], # + VaryingT, # + Float[Array, "kept_primitives 3"], # + Float[Array, "kept_primitives 3"], # + ] + @jaxtyped - def _per_primitive_process( + @partial(jax.jit, inline=True) + @add_tracing_name + def _per_primitive_preprocess( primitive: PerPrimitive, varying_per_primitive: VaryingT, - ) -> tuple[PerFragment, VaryingT]: - # PROCESS: Interpolation + ) -> ReturnT: + # PROCESS: Early Culling (`primitive_chooser`) # For early exit when not keep primitive / determinant is 0 + @partial(jax.jit, inline=True) + @add_tracing_name def _when_keep_primitive() -> tuple[Vec3f, Float[Array, ""]]: """Returns clip_coef, w_reciprocal.""" # x/w, y/w, with x, y, w in clip space. @@ -136,16 +159,17 @@ def _when_keep_primitive() -> tuple[Vec3f, Float[Array, ""]]: # END OF `_when_keep_primitive` - # Prepare for interpolation parameters - # clip_coef here interpolates to 1/w * target value - clip_coef, w_reciprocal = lax.cond( - # an arbitrary number for numerical stability - primitive.keep, - _when_keep_primitive, - lambda: (lax.full((3, ), -1.), jnp.zeros(())), - ) - - def _when_in_triangle() -> tuple[PerFragment, VaryingT]: + @partial(jax.jit, inline=True) + @add_tracing_name + def _when_in_triangle( + clip_coef: Vec3f, + w_reciprocal: Float[Array, ""], + ) -> tuple[ # + Float[Array, "kept_primitives 4"], # gl_FragCoord + Bool[Array, "kept_primitives"], # gl_FrontFacing + Float[Array, "kept_primitives 2"], # gl_PointCoord + Float[Array, "kept_primitives 3"], # true_clip_coef + ]: # Prepare inputs for fragment shader z: Float[Array, ""] = interpolate( values=primitive.gl_Position[:, 2], @@ -164,78 +188,121 @@ def _when_in_triangle() -> tuple[PerFragment, VaryingT]: assert isinstance(gl_FragCoord, Vec4f) # Ref: https://registry.khronos.org/OpenGL-Refpages/gl4/html/gl_FrontFacing.xhtml - gl_FrontFacing: Bool[Array, ""] = primitive.determinant > 0 + # True if not back-facing. + gl_FrontFacing: Bool[Array, ""] = primitive.determinant >= 0 assert isinstance(gl_FrontFacing, Bool[Array, ""]) gl_PointCoord: Vec2f with jax.ensure_compile_time_eval(): # TODO: implement Point primitive properly. - gl_PointCoord = lax.full((2, ), 0) + gl_PointCoord = lax.full((2, ), 0.) # this interpolates to target value u, not u/w true_clip_coef: Vec3f = clip_coef / w_reciprocal assert isinstance(true_clip_coef, Vec3f) - varying: VaryingT = shader.interpolate( - values=varying_per_primitive, - barycentric_screen=true_clip_coef, - barycentric_clip=true_clip_coef, - ) - assert isinstance(varying, tuple) - - # PROCESS: Fragment Processing - per_frag: PerFragment - extra_fragment_output: VaryingT - per_frag, extra_fragment_output = shader.fragment( - gl_FragCoord=gl_FragCoord, - gl_FrontFacing=gl_FrontFacing, - gl_PointCoord=gl_PointCoord, - varying=varying, - extra=extra, - ) - assert isinstance(per_frag, PerFragment) - assert isinstance(extra_fragment_output, tuple) - - # enforce default `gl_FragDepth` when it is None - per_frag = lax.cond( - per_frag.use_default_depth, - lambda: per_frag._replace(gl_FragDepth=gl_FragCoord[2]), - lambda: per_frag, - ) - assert isinstance(per_frag, PerFragment) - - return per_frag, extra_fragment_output + return (gl_FragCoord, gl_FrontFacing, gl_PointCoord, + true_clip_coef) # END OF `_when_in_triangle` + # Prepare for interpolation parameters + # clip_coef here interpolates to 1/w * target value + # Perf: although this may result in garbage values (NaN or Inf) + # when keep is False, since it will be discarded later, we can + # remove the lax.cond to reduce extra operations `select_n` in HLO + # as the computation is quite cheap. + # also see google/brax#8409 for why `_when_keep_primitive` is + # always executed. + clip_coef, w_reciprocal = _when_keep_primitive() + in_triangle: Bool[Array, ""] = (clip_coef >= 0).all() assert isinstance(in_triangle, Bool[Array, ""]) - built_in: PerFragment - attachments: VaryingT - built_in, attachments = lax.cond( - jnp.logical_and(primitive.keep, in_triangle), - _when_in_triangle, - # discard out-of-triangle values - lambda: ( - PerFragment(keeps=FALSE_ARRAY), - # dummy values - tree_map(lambda field: field[0], varying_per_primitive), - ), + # Perf: although this may result in garbage values (NaN or Inf) + # when keep or in_triangle is False, since it will be discarded + # later, we can remove the lax.cond to reduce extra operations + # `select_n` in HLO. + # See google/brax#8409 for why `_when_keep_primitive` is always + # executed. + # TODO: change back to `lax.cond` when it does not force execute both branches under vmap. + r = _when_in_triangle(clip_coef, w_reciprocal) + gl_FragCoord, gl_FrontFacing, gl_PointCoord, true_clip_coef = r + + return ( + gl_FragCoord, + gl_FrontFacing, + gl_PointCoord, + primitive.keep & in_triangle, + varying_per_primitive, + true_clip_coef, + true_clip_coef, + ) + + # END OF `_per_primitive_preprocess` + + @partial(jax.jit, inline=True) + @add_tracing_name + def _interpolate_and_fragment_shading( + gl_FragCoord: Vec4f, + gl_FrontFacing: Bool[Array, ""], + gl_PointCoord: Vec2f, + keeps: Bool[Array, ""], + values: VaryingT, + barycentric_screen: Vec3f, + barycentric_clip: Vec3f, + ) -> tuple[PerFragment, VaryingT]: + # PROCESS: Interpolation + varying: VaryingT = shader.interpolate( + values=values, + barycentric_screen=barycentric_screen, + barycentric_clip=barycentric_clip, + ) + assert isinstance(varying, tuple) + + # PROCESS: Fragment Processing + per_frag: PerFragment + extra_fragment_output: VaryingT + per_frag, extra_fragment_output = shader.fragment( + gl_FragCoord=gl_FragCoord, + gl_FrontFacing=gl_FrontFacing, + gl_PointCoord=gl_PointCoord, + varying=varying, + extra=extra, + ) + assert isinstance(per_frag, PerFragment) + assert isinstance(extra_fragment_output, tuple) + + # enforce default `gl_FragDepth` when `use_default_depth` + per_frag = lax.cond( + per_frag.use_default_depth, + lambda: per_frag._replace(gl_FragDepth=gl_FragCoord[2]), + lambda: per_frag, ) - assert isinstance(built_in, PerFragment) - assert isinstance(attachments, tuple) + assert isinstance(per_frag, PerFragment) - return built_in, attachments + per_frag = per_frag._replace(keeps=keeps & per_frag.keeps) - # END OF `_per_primitive_process` + return per_frag, extra_fragment_output - built_in, extra_outputs = jax.vmap(_per_primitive_process)( + # END OF `_interpolate_fragment_shading` + + args = jax.vmap(_per_primitive_preprocess)( per_primitive, varyings, ) + chosen_args = shader.primitive_chooser(*args) + + built_in: PerFragment + extra_outputs: VaryingT + _f = jax.vmap(_interpolate_and_fragment_shading) + built_in, extra_outputs = _f(*chosen_args) + assert isinstance(built_in, PerFragment) + gl_Depths = built_in.gl_FragDepth keeps = built_in.keeps + assert isinstance(gl_Depths, Float[Array, "kept_primitives"]) + assert isinstance(keeps, Bool[Array, "kept_primitives"]) # PROCESS: Per-Sample Operations (Mixing: depth test + colour blending) mixed_output: MixerOutput @@ -246,52 +313,27 @@ def _when_in_triangle() -> tuple[PerFragment, VaryingT]: return mixed_output, attachments - @jaxtyped - @partial(jax.jit, donate_argnums=(1, )) - def loop_body( - index: Integer[Array, ""], - buffers: Buffers, - ) -> Buffers: + # END OF `_per_pixel` - _valueT = TypeVar('_valueT', bound=tuple[Any, ...]) - - @jaxtyped - @partial(jax.jit, donate_argnums=(2, )) - def select_value_per_pixel( - keep: Bool[Array, ""], - new_values: _valueT, - old_values: _valueT, - ) -> _valueT: - """Choose new value of the pixel, or keep the previous.""" - FieldRowT = TypeVar("FieldRowT") - - def _select_per_field( - new_field_value: FieldRowT, - old_field_value: FieldRowT, - ) -> FieldRowT: - """Choose this pixel for this field in the PyTree.""" - return lax.cond( - keep, - lambda: new_field_value, - lambda: old_field_value, - ) - - # tree_map over each field in the PyTree - result: _valueT = tree_map( - _select_per_field, - new_values, - old_values, - ) + @jaxtyped + @partial(jax.jit, inline=True) + @add_tracing_name + def _per_row(i: Integer[Array, ""], ) -> tuple[MixerOutput, MixedExtraT]: + """Render one row. - return result + Parameters: + - i: the index of the row to be rendered on the first axis of the + resultant buffer. + Returns: one row from `Shader.mixer`, `MixerOutput` and `MixerExtraT`. + """ keeps: Bool[Array, "height"] depths: Num[Array, "height"] extras: MixedExtraT - # vmap over axis 1 (height) of the buffers. Axis 0 (width) is `index`. + # vmap over axis 1 (height) of the buffers. Axis 0 (width) is `i`. (keeps, depths), extras = jax.vmap(_per_pixel)(lax.concatenate( ( - lax.full((batch_size, 1), index), + lax.full((batch_size, 1), i), lax.broadcasted_iota(int, (batch_size, 1), 0), ), 1, @@ -300,44 +342,91 @@ def _select_per_field( assert isinstance(depths, Num[Array, "height"]) assert isinstance(extras, tuple) - # vmap each pixel over axis 1 (height) of the buffers (per row in - # matrix) - buffers_row = jax.vmap(select_value_per_pixel)( - keeps, + return MixerOutput(keep=keeps, zbuffer=depths), extras + + # END OF `_per_row` + + @jaxtyped + @partial(jax.jit, donate_argnums=(1, ), inline=True) + @add_tracing_name + def merge_buffers( + mixer_outputs: tuple[MixerOutput, MixedExtraT], + old_buffers: Buffers, + ) -> Buffers: + """Merge the rendered row into the buffers. + + Parameters: + - mixer_outputs: the output from `Shader.mixer`, `MixerOutput` and + `MixerExtraT`. + - old_buffers: the buffers to be updated. + + Returns: the updated buffers. + """ + keeps: CanvasMask = mixer_outputs[0].keep + depths: ZBuffer = mixer_outputs[0].zbuffer + extras: MixedExtraT = mixer_outputs[1] + + @partial(jax.jit, donate_argnums=(2, ), inline=True) + def _merge_first_axis(_mask, _new, _old): + + @partial(jax.jit, donate_argnums=(2, ), inline=True) + def _merge_second_axis(__mask, __new, __old): + return lax.cond(__mask, lambda: __new, lambda: __old) + + return jax.vmap(_merge_second_axis)(_mask, _new, _old) + + new_buffers: Buffers = tree_map( + lambda new, old: jax.vmap(_merge_first_axis)(keeps, new, old), Buffers(zbuffer=depths, targets=tuple(extras)), - tree_map(lambda field: field[index], buffers), + old_buffers, ) + assert isinstance(new_buffers, Buffers) - # tree_map over each field in the PyTree to update all buffers - return tree_map( - lambda field, value: field.at[index].set(value), - buffers, - buffers_row, - ) + return new_buffers - # END OF `loop_body` + # END OF `merge_buffers` - # iterate over axis 0 (width) of the buffers (one row at a time) - buffers = lax.fori_loop( - 0, - loop_size, - loop_body, - buffers, - ) + # iterate over axis 0 (width) of the buffers + # (multiple row at a time, according to `row_indices``) + # Not using vmap due to memory constraints + # TODO: using map for readability when map supports unroll. + # Reference: https://jax.readthedocs.io/en/latest/_modules/jax/_src/lax/control_flow/loops.html#map + mixer_outputs = lax.scan( + lambda _, x: ((), _per_row(x)), + init=(), + xs=row_indices, + unroll=loop_unroll, + )[1] + + buffers = merge_buffers(mixer_outputs, buffers) assert isinstance(buffers, Buffers) return buffers @jaxtyped -@partial(jax.jit, static_argnames=("shader", ), donate_argnums=(2, )) +@partial( + jax.jit, + static_argnames=("shader", "loop_unroll"), + donate_argnums=(2, ), + inline=True, +) +@add_tracing_name def render( camera: Camera, shader: type[Shader[ShaderExtraInputT, VaryingT, MixedExtraT]], buffers: Buffers, face_indices: FaceIndices, extra: ShaderExtraInputT, + loop_unroll: int = 1, ) -> Buffers: + """Render a scene with a shader. + + Parameters: + - loop_unroll: the number of rows to be rendered in one loop. This may + help improve the performance at the cost of increasing compilation time. + Default: 1 + """ vertices_count: int gl_InstanceID: ID with jax.ensure_compile_time_eval(): @@ -347,7 +436,8 @@ def render( assert isinstance(gl_InstanceID, ID) @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex_processing( gl_VertexID: Integer[Array, ""], # ) -> tuple[PerVertex, VaryingT]: @@ -381,6 +471,7 @@ def vertex_processing( varyings=tree_map(lambda field: field[face_indices], varyings), extra=extra, viewport=camera.viewport, + loop_unroll=loop_unroll, ) assert isinstance(buffers, Buffers) diff --git a/renderer/renderer.py b/renderer/renderer.py index 23ef1ae..97548d2 100644 --- a/renderer/renderer.py +++ b/renderer/renderer.py @@ -7,6 +7,7 @@ from jax.tree_util import tree_map from jaxtyping import Array, Bool, Integer, Num, jaxtyped +from ._meta_utils import add_tracing_name from .geometry import Camera, Projection, View, Viewport, normalise from .model import MergedModel, ModelObject, merge_objects from .pipeline import render @@ -100,6 +101,7 @@ class Renderer: @staticmethod @jaxtyped @partial(jax.jit, inline=True) + @add_tracing_name def create_camera_from_parameters(camera: CameraParameters) -> Camera: """Create a camera from camera parameters.""" eye: Vec3f = jnp.asarray(camera.position, dtype=float) @@ -140,6 +142,7 @@ def create_camera_from_parameters(camera: CameraParameters) -> Camera: @staticmethod @jaxtyped + @add_tracing_name def create_buffers( width: int, height: int, @@ -180,7 +183,13 @@ def create_buffers( @classmethod @jaxtyped - @partial(jax.jit, static_argnames=("cls", ), donate_argnums=(4, )) + @partial( + jax.jit, + static_argnames=("cls", "loop_unroll"), + donate_argnums=(4, ), + inline=True, + ) + @add_tracing_name def render( cls, model: MergedModel, @@ -188,6 +197,7 @@ def render( camera: Camera, buffers: Buffers, shadow_param: Optional[ShadowParameters] = None, + loop_unroll: int = 1, ) -> Buffers: """Render the scene with the given camera. @@ -198,6 +208,7 @@ def render( - buffers: the buffers to render the scene with. - shadow_param: the shadow parameters to render the scene with. Keep it None to disable shadows. + - loop_unroll: passed directly to `render`. See `pipeline:render`. Returns: Buffers, with zbuffer and (coloured image, ). """ @@ -249,6 +260,7 @@ def render( buffers=buffers, face_indices=face_indices, extra=extra, + loop_unroll=loop_unroll, ) assert isinstance(buffers, Buffers), f"{buffers}" @@ -272,6 +284,7 @@ def render( up=shadow_param.up, strength=shadow_param.strength, offset=shadow_param.offset, + loop_unroll=loop_unroll, ) assert isinstance(shadow, Shadow), f"{shadow}" @@ -288,6 +301,7 @@ def render( buffers=buffers, face_indices=face_indices, extra=_extra, + loop_unroll=loop_unroll, ) assert isinstance(buffers, Buffers), f"{buffers}" @@ -295,6 +309,12 @@ def render( @classmethod @jaxtyped + @partial( + jax.jit, + static_argnames=("cls", "width", "height", "loop_unroll"), + inline=True, + ) + @add_tracing_name def get_camera_image( cls, objects: Sequence[ModelObject], @@ -305,6 +325,7 @@ def get_camera_image( colour_default: Colour = jnp.array((1., 1., 1.), dtype=jnp.single), zbuffer_default: Num[Array, ""] = jnp.array(1, dtype=jnp.single), shadow_param: Optional[ShadowParameters] = None, + loop_unroll: int = 1, ) -> Canvas: """Render the scene with the given camera. @@ -319,6 +340,8 @@ def get_camera_image( - colour_default: default colours to fill the image with. - zbuffer_default: default zbuffer values to fill with. - shadow_param: the shadow parameters to render the scene with. Keep + it None to disable shadows. + - loop_unroll: passed directly to `render`. See `pipeline:render`. Returns: Buffers, with zbuffer and (coloured image, ). """ @@ -363,6 +386,7 @@ def get_camera_image( camera=_camera, buffers=buffers, shadow_param=shadow_param, + loop_unroll=loop_unroll, ) assert isinstance(canvas, Canvas), f"{canvas}" diff --git a/renderer/shader.py b/renderer/shader.py index 4c36845..a5c28c0 100644 --- a/renderer/shader.py +++ b/renderer/shader.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from functools import partial from typing import Generic, NamedTuple, TypeVar import jax @@ -7,6 +8,7 @@ from jax.tree_util import Partial, tree_map from jaxtyping import Array, Bool, Float, Integer, PyTree, Shaped, jaxtyped +from ._meta_utils import add_tracing_name from .geometry import Camera, Interpolation, interpolate from .types import FALSE_ARRAY, INF_ARRAY, TRUE_ARRAY, Vec2f, Vec3f, Vec4f @@ -76,7 +78,8 @@ class Shader(ABC, Generic[ShaderExtraInputT, VaryingT, MixedExtraT]): @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name @abstractmethod def vertex( gl_VertexID: ID, @@ -132,7 +135,95 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name + def primitive_chooser( + gl_FragCoord: Float[Array, "primitives 4"], + gl_FrontFacing: Bool[Array, "primitives"], + gl_PointCoord: Float[Array, "primitives 2"], + keeps: Bool[Array, "primitives"], + values: VaryingT, + barycentric_screen: Float[Array, "primitives 3"], + barycentric_clip: Float[Array, "primitives 3"], + ) -> tuple[ # + Float[Array, "kept_primitives 4"], # gl_FragCoord + Bool[Array, "kept_primitives"], # gl_FrontFacing + Float[Array, "kept_primitives 2"], # gl_PointCoord + Bool[Array, "kept_primitives"], # keeps + VaryingT, # values + Float[Array, "kept_primitives 3"], # barycentric_screen + Float[Array, "kept_primitives 3"], # barycentric_clip + ]: + """Override this to customise the primitive choosing stage. + + The default implementation is to only keep the primitive with minimum + `gl_FragCoord[2]` and `gl_FrontFacing` and `keeps` (interpolated `z` + value in window space is minimum), i.e., the closest primitive that is + kept and is not back-facing. + + Parameters: + - gl_FragCoord: batch of coordinates in screen space. (x, y, z, 1/w). + - gl_FrontFacing: batch of bool, True if the primitive is NOT back + facing. + - gl_PointCoord: batch of 2d coordinates in screen space. Not supported for now. + - keeps: batch of bool, whether the primitive is kept. This is used + to filter out the primitives that are not visible, or with garbage + values. + + The parameters below are batched values over primitives, with each + value same as the input given to `Shader.interpolate` + + - values: values at the vertices of the triangle, with axis 0 being + the batch axis. It is expected to be a tuple of multiple batched + values. + - barycentric_screen: barycentric coordinates in screen space of the + point to interpolate + - barycentric_clip: barycentric coordinates in clip space of the + point to interpolate + + Return: + tuple of values from kept primitives, in same order and structure of + the input parameters. The returned fields must be batched. + """ + depths: Float[Array, "primitives"] + depths = jnp.where(keeps & gl_FrontFacing, gl_FragCoord[:, 2], jnp.inf) + assert isinstance(depths, Float[Array, "primitives"]) + + # when all keeps are false, all depths will be inf, and there will + # still be a valid idx generated, as promised by argmin. + idx: Integer[Array, ""] = jnp.argmin(depths) + assert isinstance(idx, Integer[Array, ""]) + + _get = partial( + # use `dynamic_slice` instead of `slice` according to benchmark + # https://colab.research.google.com/drive/1idBbgEDbxI6wi5kzlHF6kzWryoFSm8-p#scrollTo=-bHrz3kZ5A0p + lax.dynamic_slice_in_dim, + start_index=idx, + slice_size=1, + axis=0, + ) + + _gl_FragCoord: Float[Array, "kept_primitives 4"] = _get(gl_FragCoord) + assert isinstance(_gl_FragCoord, Float[Array, "kept_primitives 4"]) + _gl_FrontFacing: Bool[Array, "kept_primitives"] = _get(gl_FrontFacing) + assert isinstance(_gl_FrontFacing, Bool[Array, "kept_primitives"]) + _gl_PointCoord: Float[Array, "kept_primitives 2"] = _get(gl_PointCoord) + assert isinstance(_gl_PointCoord, Float[Array, "kept_primitives 2"]) + _keeps: Bool[Array, "kept_primitives"] = _get(keeps) + assert isinstance(_keeps, Bool[Array, "kept_primitives"]) + _values: VaryingT = tree_map(_get, values) + _screen: Float[Array, "kept_primitives 3"] = _get(barycentric_screen) + assert isinstance(_screen, Float[Array, "kept_primitives 3"]) + _clip: Float[Array, "kept_primitives 3"] = _get(barycentric_clip) + assert isinstance(_clip, Float[Array, "kept_primitives 3"]) + + return (_gl_FragCoord, _gl_FrontFacing, _gl_PointCoord, _keeps, + _values, _screen, _clip) + + @staticmethod + @jaxtyped + @partial(jax.jit, inline=True) + @add_tracing_name def interpolate( values: VaryingT, barycentric_screen: Vec3f, @@ -170,7 +261,8 @@ def interpolate( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], @@ -212,10 +304,11 @@ def fragment( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def mix( - gl_FragDepth: Float[Array, "primitives"], - keeps: Bool[Array, "primitives"], + gl_FragDepth: Float[Array, "kept_primitives"], + keeps: Bool[Array, "kept_primitives"], extra: VaryingT, ) -> tuple[MixerOutput, MixedExtraT]: """Override this to customise the mixing behaviour per fragment over @@ -226,7 +319,7 @@ def mix( For the default behaviour, the values from fragment with maximum `gl_FragDepth` value AND `keeps` being True will be used as the output. In the default implementation, if no fragment has `keeps` being True, - then mixed value will be the first fragment's value for both + then mixed value will be the an arbitrary fragment's value for both `gl_FragDepth` and `extra`. Returns: Built-in MixerOutput and user-defined extras. @@ -244,26 +337,21 @@ def mix( - [Blending](https://www.khronos.org/opengl/wiki/Blending) """ - def has_kept_fragment() -> Integer[Array, ""]: - depths: Float[Array, "primitives"] - depths = jnp.where(keeps, gl_FragDepth, jnp.inf) - assert isinstance(depths, Float[Array, "primitives"]) + depths: Float[Array, "primitives"] + depths = jnp.where(keeps, gl_FragDepth, jnp.inf) + assert isinstance(depths, Float[Array, "primitives"]) - idx: Integer[Array, ""] = jnp.argmin(depths) + # when all keeps are false, all depths will be inf, and there will + # still be a valid idx generated, as promised by argmin. + idx: Integer[Array, ""] = jnp.argmin(depths) + assert isinstance(idx, Integer[Array, ""]) - return idx - - has_valid_fragment = jnp.any(keeps) - - idx: Integer[Array, ""] = lax.cond( - has_valid_fragment, - has_kept_fragment, - lambda: jnp.array(0), - ) - depth: Float[Array, ""] = gl_FragDepth[idx] + keep: Bool[Array, ""] = keeps[idx] + assert isinstance(keep, Bool[Array, ""]) + depth: Float[Array, ""] = depths[idx] assert isinstance(depth, Float[Array, ""]) return ( - MixerOutput(keep=has_valid_fragment, zbuffer=depth), + MixerOutput(keep=keep, zbuffer=depth), tree_map(lambda x: x[idx], extra), ) diff --git a/renderer/shaders/depth.py b/renderer/shaders/depth.py index 0fa8a09..083ace8 100644 --- a/renderer/shaders/depth.py +++ b/renderer/shaders/depth.py @@ -1,8 +1,10 @@ +from functools import partial from typing import NamedTuple import jax from jaxtyping import Array, Float, jaxtyped +from .._meta_utils import add_tracing_name from ..shader import ID, PerVertex, Shader from ..geometry import Camera, to_homogeneous from ..types import Vec4f @@ -33,7 +35,8 @@ class DepthShader(Shader[DepthExtraInput, DepthExtraFragmentData, @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, diff --git a/renderer/shaders/gouraud.py b/renderer/shaders/gouraud.py index f3ac064..c5a0d28 100644 --- a/renderer/shaders/gouraud.py +++ b/renderer/shaders/gouraud.py @@ -1,9 +1,11 @@ +from functools import partial from typing import NamedTuple import jax import jax.numpy as jnp from jaxtyping import Array, Bool, Float, jaxtyped +from .._meta_utils import add_tracing_name from ..shader import ID, PerFragment, PerVertex, Shader from ..geometry import Camera, normalise, to_homogeneous from ..types import Colour, LightSource, Vec2f, Vec3f, Vec4f @@ -41,7 +43,8 @@ class GouraudShader(Shader[GouraudExtraInput, GouraudExtraFragmentData, @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, @@ -73,7 +76,8 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], diff --git a/renderer/shaders/gouraud_texture.py b/renderer/shaders/gouraud_texture.py index 6783ea7..f8fcdaa 100644 --- a/renderer/shaders/gouraud_texture.py +++ b/renderer/shaders/gouraud_texture.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NamedTuple import jax @@ -5,6 +6,7 @@ import jax.numpy as jnp from jaxtyping import Array, Bool, Float, jaxtyped +from .._meta_utils import add_tracing_name from ..geometry import Camera, normalise, to_homogeneous from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader from ..types import Colour, LightSource, Texture, Vec2f, Vec3f, Vec4f @@ -48,7 +50,8 @@ class GouraudTextureShader(Shader[GouraudTextureExtraInput, @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, @@ -83,7 +86,8 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], @@ -123,7 +127,8 @@ def fragment( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def mix( gl_FragDepth: Float[Array, "primitives"], keeps: Bool[Array, "primitives"], diff --git a/renderer/shaders/phong.py b/renderer/shaders/phong.py index c06053b..f146d8e 100644 --- a/renderer/shaders/phong.py +++ b/renderer/shaders/phong.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NamedTuple import jax @@ -5,6 +6,7 @@ import jax.numpy as jnp from jaxtyping import Array, Bool, Float, jaxtyped +from .._meta_utils import add_tracing_name from ..geometry import Camera, normalise, to_homogeneous from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader from ..types import Colour, LightSource, Texture, Vec2f, Vec3f, Vec4f @@ -55,7 +57,8 @@ class PhongTextureShader(Shader[PhongTextureExtraInput, @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, @@ -86,7 +89,8 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], @@ -132,7 +136,8 @@ def fragment( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def mix( gl_FragDepth: Float[Array, "primitives"], keeps: Bool[Array, "primitives"], diff --git a/renderer/shaders/phong_darboux.py b/renderer/shaders/phong_darboux.py index a5a9086..a3c7b4d 100644 --- a/renderer/shaders/phong_darboux.py +++ b/renderer/shaders/phong_darboux.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NamedTuple import jax @@ -6,6 +7,7 @@ from jax.tree_util import Partial from jaxtyping import Array, Bool, Float, Integer, jaxtyped +from .._meta_utils import add_tracing_name from ..geometry import (Camera, Interpolation, interpolate, normalise, to_cartesian, to_homogeneous) from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader @@ -79,7 +81,8 @@ class PhongTextureDarbouxShader(Shader[PhongTextureDarbouxExtraInput, @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, @@ -132,7 +135,8 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def interpolate( values: PhongTextureDarbouxExtraFragmentData, barycentric_screen: Vec3f, @@ -167,7 +171,8 @@ def interpolate( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], @@ -236,7 +241,8 @@ def fragment( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def mix( gl_FragDepth: Float[Array, "primitives"], keeps: Bool[Array, "primitives"], diff --git a/renderer/shaders/phong_reflection.py b/renderer/shaders/phong_reflection.py index 78e7d85..953bd92 100644 --- a/renderer/shaders/phong_reflection.py +++ b/renderer/shaders/phong_reflection.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NamedTuple import jax @@ -5,6 +6,7 @@ import jax.numpy as jnp from jaxtyping import Array, Bool, Float, Integer, jaxtyped +from .._meta_utils import add_tracing_name from ..geometry import Camera, normalise, to_homogeneous from ..model import MergedModel from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader @@ -75,7 +77,8 @@ class PhongReflectionTextureShader( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, @@ -108,7 +111,8 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def interpolate( values: PhongReflectionTextureExtraFragmentData, barycentric_screen: Vec3f, @@ -127,7 +131,8 @@ def interpolate( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], @@ -186,11 +191,7 @@ def fragment( use_default_depth=built_in.use_default_depth, ), PhongReflectionTextureExtraFragmentData( - colour=lax.cond( - (colour >= 0).all(), - lambda: colour, - lambda: jnp.zeros(3), - ), + colour=colour, uv=varying.uv, normal=varying.normal, ), @@ -198,7 +199,8 @@ def fragment( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def mix( gl_FragDepth: Float[Array, "primitives"], keeps: Bool[Array, "primitives"], diff --git a/renderer/shaders/phong_reflection_shadow.py b/renderer/shaders/phong_reflection_shadow.py index d05a195..1e7d1f9 100644 --- a/renderer/shaders/phong_reflection_shadow.py +++ b/renderer/shaders/phong_reflection_shadow.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NamedTuple import jax @@ -5,6 +6,7 @@ import jax.numpy as jnp from jaxtyping import Array, Bool, Float, Integer, jaxtyped +from .._meta_utils import add_tracing_name from ..geometry import Camera, normalise, normalise_homogeneous, to_homogeneous from ..model import MergedModel from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader @@ -85,7 +87,8 @@ class PhongReflectionShadowTextureShader( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def vertex( gl_VertexID: ID, gl_InstanceID: ID, @@ -124,7 +127,8 @@ def vertex( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def interpolate( values: PhongReflectionShadowTextureExtraFragmentData, barycentric_screen: Vec3f, @@ -144,7 +148,8 @@ def interpolate( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def fragment( gl_FragCoord: Vec4f, gl_FrontFacing: Bool[Array, ""], @@ -220,11 +225,7 @@ def fragment( use_default_depth=built_in.use_default_depth, ), PhongReflectionShadowTextureExtraFragmentData( - colour=lax.cond( - (colour >= 0).all(), - lambda: colour, - lambda: jnp.zeros(3), - ), + colour=colour, uv=varying.uv, normal=varying.normal, ), @@ -232,7 +233,8 @@ def fragment( @staticmethod @jaxtyped - @jax.jit + @partial(jax.jit, inline=True) + @add_tracing_name def mix( gl_FragDepth: Float[Array, "primitives"], keeps: Bool[Array, "primitives"], diff --git a/renderer/shadow.py b/renderer/shadow.py index 4892b6e..6e52c41 100644 --- a/renderer/shadow.py +++ b/renderer/shadow.py @@ -6,6 +6,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float, jaxtyped +from ._meta_utils import add_tracing_name from .geometry import Camera, View, Viewport from .pipeline import render from .shaders.depth import DepthExtraInput, DepthShader @@ -27,7 +28,13 @@ class Shadow(NamedTuple): @staticmethod @jaxtyped - @partial(jax.jit, donate_argnums=(0, )) + @partial( + jax.jit, + static_argnames=("loop_unroll", ), + donate_argnums=(0, ), + inline=True, + ) + @add_tracing_name def render_shadow_map( shadow_map: ZBuffer, verts: Vertices, @@ -39,6 +46,7 @@ def render_shadow_map( strength: Colour, offset: float = 0.001, distance: float = 10., + loop_unroll: int = 1, ) -> "Shadow": """Render shadow map from light source's point of view. @@ -57,6 +65,7 @@ def render_shadow_map( the light. - distance: Distance from the light source to the centre of the scene. This is mainly to avoid objects being clipped. + - loop_unroll: passed directly to `render`. See `pipeline:render`. Returns: Updated `Shadow` object with shadow_map updated. """ @@ -85,7 +94,14 @@ def render_shadow_map( buffers = Buffers(zbuffer=shadow_map, targets=tuple()) extra = DepthExtraInput(position=verts) - shadow_map, _ = render(_camera, DepthShader, buffers, faces, extra) + shadow_map, _ = render( + _camera, + DepthShader, + buffers, + faces, + extra, + loop_unroll=loop_unroll, + ) shadow_map = shadow_map + offset assert isinstance(shadow_map, ZBuffer) @@ -99,6 +115,7 @@ def render_shadow_map( @jaxtyped @partial(jax.jit, inline=True) + @add_tracing_name def get(self, position: Vec2f) -> Float[Array, ""]: """Get shadow depth at `position`. @@ -110,13 +127,13 @@ def get(self, position: Vec2f) -> Float[Array, ""]: pos: Vec2i = lax.round(position[:2]).astype(int) assert isinstance(pos, Vec2i) - value: Float[Array, ""] = lax.cond( - jnp.logical_or( - pos < 0, - pos >= jnp.asarray(self.shadow_map.shape[:2]), - ).any(), - lambda: jnp.inf, # outside shadow map, no shadow - lambda: self.shadow_map[pos[0], pos[1]], + value: Float[Array, ""] + value = self.shadow_map.at[pos[0], pos[1]].get( + mode="fill", + indices_are_sorted=True, + unique_indices=True, + # outside shadow map, no shadow + fill_value=jnp.inf, ) assert isinstance(value, Float[Array, ""]) diff --git a/renderer/utils.py b/renderer/utils.py index cdf7689..803472d 100644 --- a/renderer/utils.py +++ b/renderer/utils.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Sequence, Union import jax @@ -5,11 +6,13 @@ from jax import lax from jaxtyping import Array, Integer, Num, Shaped, jaxtyped +from ._meta_utils import add_tracing_name from .types import Canvas, Texture, ZBuffer @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def get_value_from_index( matrix: Shaped[Array, "width height batch *valueDimensions"], index: Integer[Array, "width height"], @@ -19,7 +22,8 @@ def get_value_from_index( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def merge_canvases( zbuffers: Num[Array, "batch width height"], canvases: Shaped[Array, "batch width height channel"], @@ -47,7 +51,8 @@ def merge_canvases( @jaxtyped -@jax.jit +@partial(jax.jit, inline=True) +@add_tracing_name def transpose_for_display( matrix: Num[Array, "fst snd *channel"], flip_vertical: bool = True, @@ -69,6 +74,7 @@ def transpose_for_display( @jaxtyped +@add_tracing_name def build_texture_from_PyTinyrenderer( texture: Union[Num[Array, "length"], Sequence[float]], width: int, diff --git a/test_resources/pre-gen-brax/README.md b/test_resources/pre-gen-brax/README.md new file mode 100644 index 0000000..8d77e2b --- /dev/null +++ b/test_resources/pre-gen-brax/README.md @@ -0,0 +1,6 @@ +# Pregenerated Brax environments + +This directory contains pregenerated Brax environments. These environments are generated using [colab here](https://colab.research.google.com/drive/1c_83TLtb_pOt4OSlWFQgIKzA9DxdBDTp). An example colab to render and profile with these environments are [available here](https://colab.research.google.com/drive/12yNBVOdwUqUOBRgmQHF0gl8eMzzIi_BH). + +1. `inputs-2.zip`: width 960, heigh 540, render 2 frames +2. `inputs-30.zip`: width 960, heigh 540, render 30 frames diff --git a/test_resources/pre-gen-brax/inputs-2.zip b/test_resources/pre-gen-brax/inputs-2.zip new file mode 100644 index 0000000000000000000000000000000000000000..fcb9d36d92600ac990a5ab04513f8722a0a104ae GIT binary patch literal 87861 zcmeF4cU)6h{{HDvDH$sr#&Hx2f`IglZBT>gvIWbi^pwuDHl6S2S=+bhTB=#uVI63LjOIy zxmj2_SinEVkE!bIWrnZ2N&;f{e6?q9BCE}*Jrdu4<*Rsq$9>mRtn!yQL|j-Ro+^E< zeD=(xaDhnn`Xj#zq`7h}%-ju#ni5bv5UO`IYL)#FG5ggI>AsFD`psG#D&DCLTyA(t@k55G)<`|Wb5}hhNok?XfJ;|nky^*4t3DSyu6=d7@|n8isOK{Qg;sOd-cZ#ny3|g6<(WKkMoAdlG2VSfSw11O<)br$!-$dfvx>UT@h5FEhn%zy=m(t4BfaG38uqYpuXbNpx#TYDxbNQTH~ENf z7axSLrYS3gPQF40*sBR;E+8<+*DD1)+sZ9IE-xcSGwpmk#|2uGtVI;&#+<{7;h5%Tuj(q# zb?0HK=Dm@11JPu3;$-yCO6jACN2=HGmbp&Nu?~*A2NpV{sMy-vPv_OCr1rfNpH_aV znRyg|Nb_+HolCOmE7QrErAkXY6!)Yz_t=;prF_3s?iBg>ic`}yABWx*r)w{ENGBS; zbWqJ>8)NAW-J3Py(l2+}7JkTMZ=R6x-rMY1M$JPNs$;C(&n&|5 z^1^Gg4@TDyPim=UCU>s;x%cv)Q{S`hpjd5n8Jlh%x=smPYhPczxVAdfcIQ%r_fZW7mzuhD3^|h<~?Vj&XehugGY|(nfTnU>(Mi)Na}p=bvxm$ z)9D2F_>632oA0%&Y`lY<)f8)7eLDr0oo^{te(Ek35)yjgBWc4ktzGJTFExK8VsUW& z-s1Z7YBT={smgXIRlCUy<Qc@dpQB>6eD!rNT=N=jT6p5H_}mOpPSU#*HZvY=EZh~? zOUrplmiFA?+qCrS!SsN#T$?nN0F7mO_Z|BZ*#&#^?)tGtQ!nwyk`?5#CO@?QTHAMB zxu2xRaJOC+EOVM=7QH2$6mv*ZStS;~@;<2$A<=Mlp{baJ7{wYzKcEYXE3+h zvRgwo|FyT9nBcnCvP9v9oI)Q$MZ?b7Ef>5Dc%-Gg+-7%U>Z_>YOG&7c>(?;dR!)YQ`U}o0 zDt6^w!b0FFUFE;vNGDXSXR-&Pa%5~#*FFh0nAM+Mm{5C<{PdHp>9yg;-e(sy<6)iJ zcIv`i;|2|x-B}(c!d*RQTLY&Pdmd@E=8QK8L_+VNySA`)-@{b6%O__|PyEWKa--SO z^WA3`?CA$PNv{??@=6t67!B*n8+nlOACXR5c0!+VKF#OE^1V5*n|*SIJLf&ZL6h4E z_M`dnG6wG$`-9n#$X(2an6Ko1tNfO3I3xbp0PF@#7j}HIiEyU;rjsG2Ey^y*C;OGh zeu9ITuA>at<-Q)d7v$kWrV7c7w*qf)>00|)kCpdyhw5uF;6ppTr;kpre#a+YDwI%f z>0}sF)1mf??t=DQiw|(jmG$)qe9{;?o>r6$KIwCx4&P!a2g~`UYcI+Rl(T0f&qPzE>BEOe6|rxjPI~8cu<;%#Y*`oF zuDZIOfr)DmDZW%K6xdey7z~LHwOGRW0*NNVGcmVs&=`Zh$R6(kXeU0@-yES740K4I z^(eVSLD*oZ#z~=@fnjOpklW(@!^iCXmTA6|g6F3P+LIk@b9Ehk15Fd2 zB~W4>sVQ3X9fTV1hI-tzDm>X$0gcw@%XsE+ii1td(l&j-KemP)v4*Swc9tN$YKC0=0LO0^>2IU#V^iGntbKg z3|Adc-=gI|&w7TGlzcnJ0rplcmcB82&)Kc0)H{L??6@e_A}IHK@#lM+4G1vYXuvzK z=O4eP)On@SAKaid$3kK_}N{ZTOHK(t88YL16(?C^hdP|U`ZM}pJP z=Yt|3zV*sk%#5>~&xwFn>c$Zf6F7bH#&nnFTUM`b(rwlwa&;*xPUbnI0Shjo^fBtm zB_n9?Tl802e*)(eMkBEAdm__?3bz_DXvd>$DPO!`&zGgFuY9Nrn!E&b`Lx;$gt1`C zFP9zO&uqLh+Cb>0S$g7R>@j9j#7LlmWFEWV1Qf{J^_`g{1oQZCv}`lhvqI^VlbZYt zraQ?^#aEmR;qfhX&s#i>j8-tWGsRnfH+!e4^gSDvvuTfRT1jX!Yba@8j1S9C7-n1A z99pq)$8NNM`J+C=HNJ+=FRx%qCoN*)IK}FNX|&a}RI|o${Mm&Q^-gW>8lk@wpmo80 zrMdk9O$A=(P4P2V!I>6)sGTt^F!E^YmRHbx5q{ieHl}Ixj-f_fwAHMHWyV2h*tp-C z(H4b6Ppz0YKMKxjC(!r6%#WWKoR_DVfZ0NR?b}hgzlqA-BjH}{gNB!p`y*R1X4hTN zRMF5APhW5ZXY^L==aX{M@`<+}>>Y)`DcOAGJIYs-1MWs|K_?|q9)Q=~IQ^;FGERTh zD$e}ibiC*hr)JCeizL)-|M<@{|K{m&k{qDeuf#zwO<%j!U+m#`n>q< z=s1aboEndq#}9|ZB@w>8@6VeI?^rg;Z_bdhXsuh4h-~>7I)YrE`9+Ge9_KrTR%r$k zU(c=XcHifPA`@@vd8qEJX>0X(c(hifa?Nd#cbXpHB}>q2V__=IEwI6MYJ>lK%}%~( zzjm=GJN~}iC^3uin7LAA`>PqElA2xX1c8`({q^@Ie2Yb$)N$}fLjkERYiXJ*scmk5 zJNCE-_3&6p!t_n`dHp0SQt-5m`aH%bRjkP$OFGkco%UeXhVCft;fKu=T|`~4$_`t5 z-bcGLXYW@8ZLiyjd5CmXB8)XBtt^VfM76K7pmgkQ7KAX%pM~rFXwag=@SDsF$Lap*Z)Gk=v4Hq{Lvw2e?Vqx~n4%lE*7r2r} zE=Kk{|F}|Q;S{q1S80yYjEm{G(C`Ghjni<(W}-sc3P{@Ej_oSV@#qezlF{7UB|~Uu zNBX2ZKXIWK_o*Ylc}KX;u;%37&c%7=fiaRf*9rHCuA_YyH}_6_^`edPtG!LV=<8;8 zJoC{W^bV*mVI}B=4lL=8?;BWO`-+J)xC`Fhu(&AVe<7n^(u?4EF}-DlFhF*9Qf$_N z2fErGF4Q0mW%WBhbhYyv2jc}-Q_%jRBys_IHE?x?dLjjG2?3&(U~K?h#~9UL=#!;hd*pS57CI<@-R+XTL9P`boV)2BQgelvh@ z)LeZF_wWnMUtL_gp~`$PW?FhcLPek8^ZoEB{d9j@*di)be8N+sXz9T~gOKK`0e<1I z|3JaggEsNu7T?ES*gT7gh3*4jM1wEdU?5s2PpLhhyiU%XykDY^PrsOfFgGd+sE`_I zjmkFaizyxKSgFg(HORI}G;v;_V9hhf5cqSIVv}W0#akcbHh3m6f3MLS(Cbu+c;MP_ zdmw!|QgeN7RP!I)If`APphsxwsc+#|YHWSqBqj-6byz;0qmXjk7SM{y40jFbl*ebxpAGX3u}D!g7I$P(;1CXu zz#T#wbin30S8J_PGzlV(tXx6IaaWNJdAx*uJRE(&r;&aq1bA3ggHn+KhYhhWROyEZ zE_nwC+rzhE{dqlED?@%p?ji6_@tqA-uZ^rtMoQ-?5@6x1w2;+W;}d-0Fj^>kZ7{Kc zC_((%`t@vzHDB`#gg&f5H;Nab36+#HePglM9` zn!`3Emy%#2+Xc?d;5g##6WsLd71Mh*LFiC9oC(qKgm5Bz8>f4%&k$CDgLD^hjSI`d zAL2v&hOFZvW6Ks3q=T~}x}DIZv+`1}P^~>&_Ld2mvFCDPx$wc9L|+&7vYj6|ZgG_c z4F;LwoN-P>PqL~DtZerO&Ka(yAip~4UqbH;nNH$XiLB(WZMo?B8g~YJ{kGA?z;w7B z`!|>vdt|7@OT`bNnn9W&nYEonL9(PbpBmhPC6Mg_?g4T;RnB{R70()*1uG^5UyHy^ z5_!poyt%7*2Ur8yXG5@|fk8{Pnm8JfoxGpA&7WtDQ-c*9(q21C97G@kf{1Mg5eL*EDqPB3ApuM&CQx1%AsDwT_#ow_vo98NPf9(sTe4fe}pKp+S zZ9MXLo`Nwfo4cD`{j_w4a5h|lMK~m)wjQ}V@5mVYN(c%m*CFo6SHi+x8;v}cCr{vV z2qrw%Bw(?t@ux3SrT_7Fu} zq+;3I*hfRliELzUm;JFYDK;zzf9N5X?PYt+xTIhd++Ly)@%+EpzqkP8eq)Xq7(57e zTJCW0p2s*X;u7&Xd7JkhH7+rjevkxCyjC1%K(r(Cc<)n#CBvkHG=pCg&*X8?1*IW_MFeATibPYgs4-U!MvCrq%Cv*82qs#)19_%n=WVtm-1!vlu^m^rvVv|Q zuKfxJi-zr9E8=@(R&Naj-`Sn(<3+#^I>%JWPJjIM>G73s zCFm3Ft%HuPI!F3ZN%O-wHv5dsO0MVcr(p7_Ym~zKXm#w_p(MZA<9g`BhS~MPi9*X8e*gqP8qi-sbyTfmbef<<36l3z!Vp z>r**l^s}WtVL3hG+Os!qI2q3NSDU6%YO6ym ze)UU@vE8UfDVSIh&tG`c;w7vm7$tM6|7zWER7N`G*sQtb=y7CWireZNO8>hQSY<%w z61T#^-py~h(e!^Nc36eqg~2?lEr(-nU* z0kI4fjbe+OX3#1H`!$ErmS_?CTY8$z68MV}2LQllkJ=led9w2%vT_I>FHVvpfL)N~ zYVCH#My&bczlL*B<6Gz>f%fcc<;M_)(HWH2xo}eRiN{cUSOcB_b}o3UiUO%A|1wpxlxrC^ugR5#`|q`Wsh3XW>F`KR`x5gHVI6y`qzLiOx^h0BKqxFkEzP@1*MXW zz2=VsuH{BdTo>s}9QEh2iw?L{nBAY=UF_$01lrYa{o~abHE>n8f4t0Y21YJPtVV}d z@ALaiw-RWV8bTqC4DhRo+tcOi8oR;3gW#(JsStohd%AM_W&3*Fr<)zSU*gkp#Z8*b z2Cn?Wp2f$(cemy+fl#{iNw|!rtXCHSs#GBkVfX^Vl=fmpi~%i1|FrDxAnd-hRjR5} zhbT7Xne%J+p_(+yLGKVabSy#PzR-!Eh-%KIrUGy>1Ai9XY|N4j@ zt57IvHbU!B{T=+FGO5n98%#UY*PuR)$*U|J*5pwU3^Y$3mR%RP1ePv}BR3ekMcdZ0 zgm1^o_98GfhU`OS%sC)_oFof1K|J|wNo>vEnBY%Yarh{c_>^F8`@b9bXlA)V72MnR z9SmdgG~2DC2rvte{E-1C+dKxCv&z2d{><|3qEIL)^B`L87z*&hA&qAi4r2Fv_d(e2 zOX4$Zuz`xwNP`2REvcmLc}oabUC=zV?ie|z&S4F5Uhk&LlExfKL88_ zCO|VAF&!)OHy<~F&gewJ=cSN8T%W-A(5z+-Fp-`P!dC1C>L%LS{mI)bTAwbMEnMMO z0VyKy_TM^Ga$WVK<&6}=x>j#m>oQ4@x)}`MqE;rOG+N4yqbD#M(If4#X!-mz+v)KV z(gIw~sXD&SzRlp_IcQzyBv(}3Pn7rd&P=@Xw+)Uz`j|`5=*XxFsfo#^!j|j60tFNG zMs8BNO^OY$TddqwgSf%Vx#gb*3~48s3lgD+C@x@Sqv;P8C^nnZrZ7Qm=_yPvSj&*b z=T9@o#}~Yn;!L{FdR6G1l%L^i?4d zng+F?r%DCHoChbm78~&8gvn$5yTQyAq1m|1bEDz&G}x8+?Z2e48oDQcv-tGp!<8b& z62()%c^`K}hq64JxmJT|>pKGm=I0fPOmR^(i8%xMlqlo6E!K2)PC%a2lv@L}#EeD^cmN{O(%!5gY(EP<%tOxVye9}l3 zhM>_Ku+;JHJ8$#bewMrD$>2`1Mz49m`N%Q-RgR<))w_Gr`bm{ZhCX@&s?|z!31a#5 zPp_mnyVKS}+uZUp`^(lI>}GCwr;GN9PIUT)0;n_EQYbCfv{GzA3SH<3>;Qy907g0C z{eFL(f1dxioxw4yh+`jZUP~p^d&|QI(Q2|$@aWLh%BYGx0Q7o^#=y@AfQ&##3v->#Qimg#VME8bybaL)ctfb%pLWq zKHlkw`a#k&6F`mVhU=D=W)MfQQ6_X=7|sNL6t;9nmKMoCe*1GTTGNzs3Aak37{uo@ zKRQfVjpug?o0hIY&lLu|pI&|GNdTy>Kx`%R)+~R0*7RFwdo|PNy{q9Y>iTMSjL%!& z+af)*I}jSe4FAv&H!Yn6pHfq1xAMcfXrx#;C*SmY*-n|^U&(AMV6dIZ0OwIoQO$X+ z{#*kl_DO<3&jS_I7`h_YAQ3K)%2lkm(7RwfJX!%59l$jKsp95v5A~AC2f1MkIuyW) zj}eqb-4;#WKk4S3^mAcNo&^MUDr)jjl{I;QB1$?W7gb|;y=AVO^cAnAmmciZhz?bs zE;4U4xA_*_uA2-`T)o>Zd^=i8-4nYL^kT93gyxu}{@P88?^H#NpxZE7)oL;W)ro^U zB*r8iP~TM&k}E3)41B=O)+7aIyi~G-HET6U9XlZvv7UY4kX0`ZjFF^zeN8}1XSuEF zbpq$Ll19EVI_wa5*L832YgG4~7Sxww1XQ=ydOZKcesM=lz}UcD8ph&%eePuCIM2cl z*ZWRCk)5>kt7|Xrmy}kgzfMhUS&0q9cw&|7)2GwHc!ANHG_ra%z>9!)!lbq==A*iG zAOJal0Ps#$`dZ{eRX}-3cNkqPuP=lxaW=cK{%JK{8+0(Df7l9e zz5W5jlxFLLhX){3H?OS;djIueAMF8y>w*VMu-@n~wdhpKol?GEZ$-*U%* zw6C}WeO})jpGNj+9*|ZE440d{G_d;8_c`7)5@6lpC2lYf<9rGzzvPDD(7DHYHOyna zi$i&szg--!c&w;jVBv%qKxs|FF%=i`lU=Q=yV}CzG3>{B?};RSjCzAK^vS%OWQr>u?W|5+}Z)f08F~ zPVG2*7xRhhKg4r*16x>8J08L$g2sv8yX=Z(y}*?jWPr0nT2ocr{vo&{`kzpwzmuhM zmVMR2&L4~jd2o7n8Hd>pDfXDq?yLN%d<#M2G`v;y| ztXS64kin3?+G(N}kMtXQ6;R&JA zwJ}6Z;xV#3mCK*IhP}CVsVf1?Ys=jmyqLmn%$mJTCFDYiobmSCtme2T>XFx zp^k?cvuAJbW^)MjsI5S%ckG^LtH!M$+4J^q8;r)v5I7untOV7hh!c3`VF{tfYgzIR z5_UN9maz7P1S3^XY;%Njh3-efPVC`f=lm(?8v#B~4%OhDNSi!v!rpL>wAx!puM?#KwX#JLhPiQX=3v3wUeB!UHrH(hop?M>ve4(Y^6lI1$KwXtZzY(ffa zdyooButaWB5T$k*DOJWoXEWPI3Vu*4jeDK4-Hc6&GbXqh=kLNv-CMTHj5C)DALM|; z;5?A5Wnv##ZgD6EC*Y0~6^U2LrY^!|qIA9wtTP;5!Ie05qJ@jN@wR=zG#oqebO+D* zuQ{#AUs2V)1=QHYI3t4_aJz};$S|@td5^b{8hbKld+;cZo2ZYJ@!l!P8ptIb9D_sD z9wS~M3sH}Gi&n7?a5x8-;#7&wM02ti8BP|XN_dO=vuJRL2S?(twGu=FvLIE2E-cB3 z2$IGH)*2BV$h_2j{;=ra*(~lr;f{Sp?9-=Vd2GfzZ*wG_vdP;u1}RW+M^l82_hxfl zh4r6?=W!YD&gSgqnhf&5`KCyWar%ZX;JA=lc{>RLj%*5?0jCfha7V6{&{o_&Bnx4e zBWDR$Uyu{-RGz2+M+qzx_f;Odz>d4YY(JdJ&-pF6gFsvRdev&Arrhty9RtqUoZTXE zpK0H}LHg)MYH|~*71Lu1R7B=zYt7cs@h!)8L0&o;p0o5d#*@ZuK3g+q7r_A-F(iWf|V<*psO3S&tFrpSr&5 z{FS@2OQq7V#U zJYWaUrOpbE$2cz?PF|i4`Xnr_afP7Sljhx7Gr30?zUn*n#7VWQCa|2Xek5D@fEz9A z(F@Oj%-dt8vsrKYu3+zr_`PhCcG}|^$lG335Gb3oYWHL^J4A$(?qRGnE9 zIG;>vC(lN6sS~DEs}Xre-fk78c9D`!}UaP#N5)&#Y=h(M&J`R^(M9yTapi zw}97&589ktvO0+CJkRFFtEtt1)qCiz@J~Fhw~w#d1JKqMwLr=*lj87e4P~lRFl}9s zXy6O=RXRf>vsw5#rjazAXZhWpm3J}3j}`NGNc%})v?Bep)7L=33WRlY07Eiy-=-X) z0^ahmiIi7?Zuv|g8Nxb9ErF5$40xXRN4mCjSkT@uOHa8`sA_d8VW>=;NwJ0yQw5NR zM_IaungbB;6G1q$uFC=-Xm{ae$$DV@Adr}pIQ|a8ZmI*TizDPMi%m$S)=<%EJ#EmE z&ct-2p+_xs2jeE=mGRftBwe?pbbQkRtW2m4a8?gh3jS;qP zoy)AI6Y7BoXbN|}U_D&nh2`LtN7Yv6x@R6ub-Dpf%P@hY2Gn(hQxD>S)}%PM-%`g1 zlDv^RKPqo%pgP0jm^JXo)auQtV!~(k^`)||_0jHDI3! zo7K}d)I;VDvYVA_>7vezTJ$}WFZ(sA{YadzeAr6MAOvoNLK}xpQ0Qt5a1VNwQ`zAS z;n{%i-OH6rg`-08Ks!{Jn0nLEP;pfHpwF!qCSV_-Kl^scIh!H#`>Y`HvC&1Xjf^7x z$`8$X3TA}nAV<-*TuNGSqE-am)t*%=DW(^eiXzv>l=xx!&2cUBh%CPNxaW+|iBTEryNOPNOIo<}6rDHCjSUQZ z_h~8g35MO4=igJ;34i}Tb>08eb^lY>{ZC!@KXu*z)OG(;*Zogj_g_%gWsK{^oIF|d zoKEU;qc3)hg&iGKQoq@Dfj6s|UUAji;UCL$hNRg}xfDafs~<7j;8?3aq&I|Zjywvi z2_YIO(qk|M_Ej7S$$f&waeE1yX{b|e3G!QU!BVS;ElmtT;0usY5H>6`}_U|-~<1?MA01mKG>SP(ld z5~p8#nut8XD#CUs=_|;L;~4J^p7#l-P3H=!RRZ2 z>T^6~cb7v-+=;x_q3*TiNIoTaBAj%0(BijJJAnw^n2_e$5hRMLBSb#ta6RDZepq~}<=27EIX3-DXP}>b<4d4oi%49WfR{C~HF6oei z#CGCAvMiO`o3m<%1$!VTCRDLDfq0y(;0>emR&lTKt+5WUuXD_X!fV5b2gov1BdY%C zV5Tl>mMp(Wj zPu7?dxQau%Y9~@4TQQ46NO5g1QZ!G>7(UNYeabmcgTUgrO@U1~RIfH81wO{T65M(! zFi)Al#=~M7l2hA_6wW&?u&snWw004BLV(SaMGJT0BFpx}Awx**yxjtHmLisj+s z3=6Kq?L_WS+9ksIkgFl+IZ{Yzp9t&^0qr3r(H1sM_6Xu1c6A7Mt#s`jq*&QD`Ytn` z_}YWS%PxXtA}uhZAVlyn+!dm|3vbyzx{w)cU(h6umw47?cbVV^HWd3UPAq3>a35~A zR+MAMR>iW$u`iS#w~K5;-sLSw=d5D0;KT%DL-06Zq8@plw{R8v8rJ}4AZ(pw zHaIY32?rx;k#|xB=-mEn8#oqsnaoMuMThxw)^H65Ipe&Lg4CJ{aw)6RY1Iy~+bqX| z_MX=6*nXSqSg=&eF=N4Owu;~@ID4c=hv+mJn0-m(nNix&1BsFo9-dF} z&CB%dxknL?)Sp^5K%SQm{dxZbUi6fik{gu^fIDpcuf^;Y{IfF)7H+jB zrh8@uT#=Ne0K^aV)rD@Z4W+Oc{qfby*ZWOj(SxhB?$^zwQ}Hdr{WTPn2C+9?ij<^kg;4ZdyOfOj@+qVPxf=VhqpZ+saq~3H-cf zua2p?U{YeJ7Gno{!zLlXGfF>pC`{C@?iuM^9Ar;R^F54MCsTj18>ct<;R!z61j%P+ znxFhjPXhbv;P`5Y-aQ4Tn=VtlN>f;GR=fQi3Gdsgmp*042cX$g>FfV-w@dy5bsSeL6O5!nR_%D2B4kVv>4t zb4|b3C)>K^=E@fFB554`f$182=LLZseJk@a($E}mn?WA5n%rz|zYDWQ5NS)!R*Zc8 zTT^Ts;A_qG3|}xX{Ae8QVyFx@9Ktrj^6Nf+*q?y4_e*ishs*$3dpsbiS+uNhz#%uN z&}pW|rYit>h^aaGj3EC&n*a#UG%6ejczU&guNATlGaa#oIY#crtv+%Mzid#R>u-D- zf2TNsND$Kx3;s@YS-NKqlW(=*Pu`-ioBT1yDH~|{WU=z@txuK|9MCtUjx#u&3_Rwh zm{g+~zNMUxVspo4~IAM;0*& zOU&yppe70Re(viz;B1bmafnW|Ma8xQw9L>{#xFCan+#KI6c{f|xAU8O{xwoPsFr`$ zvmT_n@7=SQvaWINR;`xE7IVm=I7)}qL`)g!rj+teSYjK2{~AJ%)x`+6(R|&OFZ*@E z|3c&YEamq$YYYt-u>FO_jgA6sBGV@J2C|7Qbjv;NtwyUed}8q~D=GYCwchWZ`tYl{C-k?{~;a~4jm`s z4IdG+7{!2!Zt6;U&uwKe^yD6a33X@R2x{j89sUy%fcmO zTkN?_!6CE&leW=-m>Qjf67k}>+)yS0Lywxv{Tz%&75(VR7E7?45eE$8O#8N(dH0r_ z51RBg+}|5B2k`W+tN3AC0)F}O(3t9Eh5#f0tbJex2UbGhs1W_uP^L1sPJX7bfW%y| z2tUU>vxQ9(wHq?aH!R}=@=+UM04lWfHk4%7mTFRkb^%fwKWEN+OrS7AxOTUqJ4$}hp(AfHz)QGK5Y+?7_X z?Qu2;wZ{t%)h$3Y2Q0KLYBGGJ>*QxzvdhTjcbyqmn}m$swBI*Zh?uh%tI_ni2@JC^ znwNyfkrRAdN-5964W$$*Y@yJ?NpbS6f1XHh>s^R*}r86Kmx%z!A%-7LS~QGy}x zlY;}HS88A@_rgUIO=2EemgI7ZDgmBj7oX3n>hiYD{x+da-O4YA*8R)|gbI^y}wfd&uw zXK8G?(sDQHV=8J|-3*Uw!1d~Qe)C9=fa|NC0oT=49vXM6+4&)a$v0jWlsK&FIO`Mz zNW%Pj?J6Htw*3HZhv5a4X}md#W0GOMQ>X)=7ws$G19fCFy|L;?PnBHgHs7KTO?_ea z?Pz&TaEE?}NjLpR&wO<+kLt`%U}p!mu6}#RU^=- zXN4jI{7vC8o4UsCHJ;(4MGsc9hk?r&^kv>Eup7%uzt_?LdOOGu-Qu68QBpcebeEmB zF0U&IJ!%C|-d(gqrHO?v^?{aF+=2`oF^zH$*Xh%Nly!?w0py@l5VI2TWuD4eOl$rIzwvNuy@tg;xPO57AEHwf)p_*q-LXT$Rk;6<>Z(rMJ?NjfCSPEmcym#em3Aj` zws95Wl!<01V2Q9cn0uY{@lZq%H%`rEC-q>Nz{y~p+H|6wx@92n&5>cb<;nJZR^9j{a(^nDJ#=z%bkwj;Iwgv}q3p%2Q6z=Jv-y(RI$%4R`gX zKdA1;g7Q^N`Q{Hp#pSDjf@O+3Js5w$tPbPd<#L&^J0LiuL1J33O@4z9DK0slb~_AV z5Fl#%vRXZ>F(|mt4LEQDe{IdbGDvz78EtN zVJn=VzGuKo5C*i2Ns==KGBx>Cf#2)102WL;uOV(Ko5!#?0=*kRFi4PigQ<^eg#2N^ zySnLj`h%)9E2#lkA~QB@a*V2VLpCTCfR}_q?7Htbh>0OAH)HH7@MSqOEEU% z52@bD{TZFm6+5{pRBOoG(pY`YIzyK(af8D9%jBQBfL;lp#{?$d3TzWq@pqO=i+-53 zo0Yq9zrOXStyro#yP1u0{SQ;@HuHb>+hi;B0aaF4qN2O*3uxEM#ERTajB%u21&MJ@ z;h!Vy9^Eqmc1p(I{k;B2bbTv}Z!3{&vql@}t&h*|m86);;mv%U3m1=&-tEF)1*DJS zat;#Qd9)@@o(POCD`C9fR##g8F32w4YSk&a{Gn6i@A*S~pqvpij7e`8P}N1`!Jfdg z!d>~ZM;iwUzIE_A$E;sn9ijU57n8*^tpMX)!_VZWt1&FvcArF;311YudR->^b6Z5o zE?%4Z5=?+VCk9;}ueAQdx^+}d=egkLK&P<8H=~n}rW-aSh=16#AjR(QEn*PZZEOFC zo4@q@i)=z*1Z*B{&v9z$$6IZ5>H;yW9X*o#j+yF&$(7&o{+ZGadWki^`scY$F0no+ zF-8s`=+%%!c_6KGR@ZvoPwR2#ZGCbt$An=O`(28i;Uk(N6OoJ?Y*^bDs{UgscExeb z6gvff=TmiQUSeiu?n4iU^a(YEHst(P zICNu5obmq{otj&J1=iZ~o)KqhBs%X@(Eb*-PrA(UTKaLL7ObY-G7ZYrIcEoWQ33azlp9aKSF5+(Fx! zItFqcr|S-XPPm&-bCLq>Bwa(@edw%~d)eDRluCtGm_q6d#AA!2ck}6I%_7`+bZj zlhOchy$l+{X+nW7*5oxGwL|E|H6tEU8o&6b%M}WKa8h`j;b$QySgzP;z<8x0P#Olsi4xTJ8wz7}v*Bb(hg?U#ELvkr?0d#~i*Fg<=hQRIjJe#V z8nqBHG1YK6VryNhaQ5Wg#QgYP#=MvLWj?_3h(Qp;Y8E&5)rNn*{RppY?&UU*mM2WL zG@KF!ranf(=$9w4&De($8Jp}3S!}wQ`$+ZkMLX}2-WMQ`QjaubNUC`3ZqD2Wf?sg7 zGviA)ngilX=i&-&bgWkO{#-PG$zuB3vkTG7vNOkK}Q!Z>5NdmJBq&ET(A@+)47YWjH9>&C6|Jm<||T9 z{&2D45k>IT!?^D$NapDvOBv9+&3yti>YD^^Gv6}*K zr4PWsq*Z_uB_r7b3dq}VJcH%N&3r;f5I+0;=d8(=^7@OT?&mWsT2nPBAfJ|z#u^(t zI$93&=0?_Pdq$wxP8G{6Fwhbish!Um^5PhGk!t+?M|DpV$H9ykLrE+6fX!zl6?_hU z7)_G-YrGDyzSWL)Z1j3$eeczETSrUBDOLGBF~@c94|x>Mh3#oRK)Yh`ywHDPZ=ZJM ztk|_{+b|DqIqQaqfjdP8v@5gp$9JwiN2e6}H~e&?sL#JN>L{ksFOy#a)p;NqkxWRh z^<_ewB4K=>>r3VarpU5VE8v4bh^stv>+=WuVvP7`9;$ey0!Ixk62jiVTM&--0a)P?s*xLa|m@ zb85c`LWK`XFKgM|rw`dItw^CDYR*z)0u4-=)i#lhMGvO8Yr-+}&D1^-`C}m94m}a* z2*TAeA^#-6RCLIjNv^qQz5aGHSb+n)Z~bJX%EO6%=}iOX0q`7Yz>Yjb#+3?J(<#A8 z5;iH0v;1wa_mizrkZE^-5ogv-GX|fjdUm=hWjG8VA&;g3=?#DETc1lH-p&B;ai5Y} z-@b4biZOGbn;G3!f4@M;m7Y&y8f4!tT!dOxAx&J7)vAxot&_WH&WDk6>*nvy@_D;p zjOXum2LIEB|7O)1RkBpfDt$=Op|Tr0U-Erl14J)Nj+IY_L$@I}?O+G+G`l@(m{I&L z;2Nex2V$~~=oTH@%;ED4K8r3Lgz^ndC~4ph4ghBLZTb7$jqrX)p`G2J)?p|hF^%*Y zBUo|Pr_Z51Ui|{B%D+hu^oH11p~$a0tft#f4s) z23W79pC6xwMCYw&RY%olm#qUFovGy|2_xxYSf!JXO~fhWt{KIKvZ88eydV)hHdp!@ zwsH_PHggbZ4*>2Zotb>=+8|4LRRxA4uq2bt4dWLhn#U2KgP(y8ZU*SlQGB?|7lae8 z-Pi~#d;lEJZ(VQBN;5JOA%uqlCfOv23ii>O7x`y*CrM1HJ)Lu3OSv_w<%jKnGyJh< zW-TMoYtJWr_h>lx(9VoTx>~C@fuuv}1>eb?U_=A=GZ=_+`ku+}7fw(+J7lgA@J_iLMmrGBD~9C9y=7jL_Yl%~8~sPFSZD9Q?ux~-Y`J1f@&5%^?7ye0tUJZ}ZOXR0p`y48NDhMF6q_gK0o>vL zket%kMdY7!HP6Wqx!QJI98uCmA(pp|bu>hosOF*(%Pqz7hpC0)W~?aATn?k)ESx;? z8ri``)R-j~rWhoQ(<55Da8r59_R{x#;JO9FvU>%o;>?M+T;OH=A7J}J&Ja;#7Voc4 zhVJ^|w4fRX3J%VALE1P=qAQuhdl#LRzEh1enX5g>2Io%nA+u3;8}nFj=+vTdIz%fn zKY54u-YTxO9TqU>AVr)h(TU7Q753(*i_+PvU~Ah4VB$e=oHo&tEKS}{-Q&-t0bA$N z!SNGy$l~O^R3U%%8qUFBDi;V=!(!6_h3PkiMF;|iX}Ay05j zvD^*0p2D*0qT}Qr1kY(YuMYzy-Byy_R#+WS(*3=M_TTZ(uJ3IRUwL>-wSspIf6JG8 zesZUSI%U4I@qmu=g?$Sus&6Mp{MR^)p3K&Zz0I#!7id#V79Lcp9uKE0mL~M?(rS(= z#TO*X6fr?v65e0E$A04F&1&Jr zI-8r#b_b96n0o3h36#6F(qg2MW^M6u-;`pHKDWnvoe!&7sJUNdzL51MUt)^>?(Fp? zJh|l6P>yr^vXL&I z>btv_TB8`3l{mX_v1V;(?p)`)&vqZ>y zp5$lqD^#Bz2;(_YrBE6OtfqguNZpX6fu(*%c>PA;`kPVJy6-Lk1gBr6T=G?}6JMca z#_J7b$?5&=XBTGrN9V3lE_Rd@F{MjT=K6U?w6!RD^D+hU>&5QYI^aqT8RCC;kpkeo zHEbNUWT*u}3QrUa_#6eNpu{C5IY?i!2}6KL@Q;Gmiz&vfv5oUz2Cr{K^QF(m1APg) zsluhTZWtg-M%a(J3IAf^00uHVs+l;8J(&D15C2(alGGTw83qR;(l?>o7eWNeB5FU7 z2r|(jGtsF`pLX+oE`|p15&rmu55P{ySACf^{_or-LxjjlMZ8)yRxsf5?bnKiORK%%pqf@?@`i$>V=|)d`CS)myRjZLO=%hEGEl z?hO|E_vWxfnW6;Xz3gvZuo}yZ`TOIZ{UeD9!#%wW(TEL|kQ)AY+tEtk!gZ)QkM+}AEY8m(_(TolHU|IQ>)T!IIV$b0CGKuE+V+nN=C|S8f%u7Ou z<%WLFKFjjM>uz3}WhzF6mT&l(@Jy*jfYpi&W~%3rFZIRt!=+5fvvG+Az=1-;vmU#P zZZ25V(XWGGPsRmHzKCeuOb1AEY61Ga7Uv_-a>a`PA&YWoW#V*{kJ6KJsSz`N<~C!| z>9;9v8&}-9Y!Ybj=cL8a0d_@~s;z3m{~vpA9?;aaxBu4KYOMlV#i&eGs}xZJBJ+?c zRq6zefHFjafXZaRFcVZNvtUI8WJpmFm04sa1cZPHf{GC4AqWvdh!98uneN(2Ky5we zJv}|=ci-Q={sAVqv$M0;+G}Th*7JNHzlKMdJ@y^xLm<1Pw)nYEljF5x%$Nt6Jws6D z&IzW;S!!p`KduaLWTyPuGspHvphiIB;~G&$3K$00LrOofkaAdE9AZ-8U?|e5zwuv` z_rWlD&ID8nm^&w*X0qxEn?Jt0ogUu;0bUR6;Xj4$m{VxCLC366dB2nvW$GcW&Db_q zrEsk!SZsUwKa78yhi81UMYSGw3#$S&wU2_1CMpWB#qmc_;d4mVm|k%EP4175y3jN3 z{jLAassL`WR}kx*EO~eGuIA;@Umvn6R8~Ov)LBzjS5vm^>`sOi#8+VjqmQ!1|!#zN8)&?1bvW zxmpQ`OGHn~Uq#U*r_bUcwBs?{0h6-lorE($rUa;lrfwP;Apa;ttxC-wL6d$|uw>Zm8n-?6v+j5|qFb2MP4d1qsZ6VsYkO&wY0&Yz^`B@E_5> zD?}fnZ~zcj01D1U&)g3Lc@wM@gUu}}g^FDL@>lRdH;>nufE{5U3i8xcAesWjX1tF? z;2F==SIg(gnmcO=WmF$hq#HBX=Mja~OjuttyTfo4giC;of+ymPf~AA4OenM93BN^u zRQn-_8NjwRl%D~@Ht;tH5RuS#2OF7jO^ZEKk}f^MBefL{<$!u+1_KlyVNtUaLc%j$ zQN*|mJh7>+BqbsIDKmj^&u&oU4pc=`Y63OuNHW4=W)mLPfYbqKQ}H4?I|NiUn+DQ5 zAtS1|NNrfQ7EvnMd<~;h-t_bY(jo`KJJv`G0{DrJ+IUmLw6mAAs60M5lXg<@s_Q!! zg@&pW&(7>O+3L{a0nDBzP()1vQm2*meFIvtQ)L>2hbRVaho>e-DMOEg0(!Afl|%dI z(clk4&0VUsbF;NI{Q-2c6&ODqwM_!zx>UP3X^{C5*tQ}D2pvDFAvl;e7yD6ws4o7N zhC)!;LX0&pZ@eZ2;wbWMpgM(_(KMP4JLy2#iSb<8@B0#_^Y8--F2nIyO8Ck z7gBN=p|#C3L8Zx3AX@_HB>m6OVL}Nd2oO2UE{9Z;h}wkdA3?oC{A(&GyM&ttrSCvT zY}D`YIYO%LHzi5C;Ij9p9w3xvpW$4vewL^RiZ{aDKq5$pxplflvvz*(C?ybpI)X5X z1ic#psVIT=PWUhHpXW}@F*{*GMROpXf@omvH@c(n(EsJsH61;T%(h9(lLPx7cKt(( z%Zx1maE9oYKA0;qORm7qtP{eVYIC2Up33o4 zJK7g=zZqM?{L(KdpmI$vyj?2t>%mIzT!i))vfp5rVZW68a&gd-{^hDG4tyC?mD;a* zUR(Ri0}G5V->H1rzeZL1z{0v<#PAo0uj_(VR$e&2L3>H#7m7hgDv0ShX4rtRMoz}~I@ z2mfTXIX>ub@|~P_xXWLP(^Q&7!-(M|!wqlu()qaz!F05d9`S2;N zT)j0C%r6jM`d!9X_MTEuT{O0oxzMgkTUCBcO7e@vmlc9=7lSIRPJOBU)jvrd>DPXD z^t`_If&+^cFUM3aJHK7~>!fc@7OGc`oLbFZd*-V@r;Y$j3$ywHyFleDM&<7QxPH`& zuT&O@Y|w_qU;cGjCA?T_>sJw#;v>Pw_7 z_`Br_`&YhGVlVTTkD!heIxwzvH$UV={v?B&;i=!um2d1^NY@_=7f zlv(j%VLK2#AbLMbJb&pw)> z!i{75N~u|IZsHN9T7pMj=w2fJ8I6w)c*IcgTh(IFrljfb+Z3dy-<$VIV=QtvFvcGL zz!>XZZ&3f3nSMMQndm1r#&X^@Bz`9@YG?cyzwzxRFhdPv4wOO?934S;y{&O4P`z=) z>4=9a9E<#}3-c(FJ7?C3oUz<;;EWyTO5}X|xAG`9!jQn+Da@m=Ld*sq0KyVlQLkOK zis0>|16(kD^?h?~{T0e0D{HWP{jY|_F1;?SbPu-wu>QKnlQw~%nkGExsWqR{P!2YMW*m<1^h>K;AsSqsr?^t6! ziu}ON=kQ%q&%1|^mW6Q|@3S>~X_d{qD8T~*{Hq`4#fs!rb2t-dPMU$|1LSIam?}D8 zNP~O0eHPk@c*19)9hgu5e@kd*{g;GxrZ4`w(2lg=v(WCd&`wm&^Rv+Iv(WB<^Jk&m zXQAC^q1|Vp-T#|HJJXh_d;8a|>pX^KoYTSjotZLcXUDCJWLpw^TN&{!+R1UH)i}V#)%nWYbrUm-4I45;hQha7d-u#xPr5NFxocwm$-7xTShS(z;xQrsxeAO+f_re~2V%Kxmu+$;T zZl@DXYdWJ{`@S<8ys_YDL_XHT!IzZf(ZZ}>Tz@&wASRaYeE=(?z7=Zku)rNEN|SX4NbGAdzQe0%Z~6qj3ol-MT9lTU%6$ryaUzo}Uz)yb zsb`}bklkfY%gY2)|4rSUsB{zZSSyr~$aJbl3-c%bV(F$&M^9LPW2XrvPsoo4jkZ63 z1A-@dAN0m#eNexySnA=wXNJvl71ve*W>`H8-3$c1nh&~}fw086)YzLl%z>@O;k__` zacXj<$kGNR0>b)WAYQJ?PQ&iryqgMm+Wu6*uEMBfK6)ZA!W}9r1`0WiHrhS^gBkWZ z$V~j}_3L`3hbDU!? zv}NdI4wom`9f54hSF7ISJ}SzcuoMOCiDq+VKwpwyn+c>?!sN-Pf%PCRaho_G^-oot z++6IQTe#euxymxz3#Gyhe{LHBKUkwzXLhVpOvn|jLi9=%>RsUl2g!X|y)~|yu>i>uv zR_5NkbrXsM{@))PTi_q);=n%x3G1@B9i8)TSf_>mlN% z>}NMDs4e%|4g1**3+kJFcEgI|NI$z_Kf7W7FLuK|(K-kS{jfg62^4RmRjfXuA-V3Y){)FvJV9j+p9O6|AwfpCk^lRjNO4h z8-zNF{#mWWn@Tp6$vkS3?rP>*I+*q5d6blt(^Ia~R>Lma(-ZFoVK-9}UzD-mIWC6f zu=`(5rC_^Yf)dAQay}{BCW$%r=f}+F*N>8PLCfU;!Q|JG4JSPIyz#W!beuM8 zGV-FR*Xw#w23Ao=Gb9<=dH=gA^_(%>0jI!|L%t*Tt(8fp?x}A~lx*OG|7w^QOg=CdW_ zAHP#~0RcQv89PW|Eb_-DV;^t32N$K)I5j;l4p+03X+{T)Mz^9aU7KG*O;$8_Nk<72YOYK zTFki*ld{DhKqVk)ZH4&y{TfvS%{5fSSDWUGj3D|ZYRT1HDEh){MFOmlt z#>F+!ljK zSQ{g%_aH^8dcRW9ig(&9j>*O__b2l#9=-?K%;ef@>Sjkp)c|hHuT-eM(J%OTrQ#+- zCv9i#@nF{g!|cZ@e_4o9{g8!o!xlwGdXYABIoW!buC`uBSM zq-R!f?%7L)&xNHI=0Z0XX9+Wo^RMa}dN7?IWD*`UVx0X%{bRF0?0^S_6%$J~h*Bcf zfn}uJ#mWs8JHVHO?Z{fZJ@EWL66Istbqn*cw{Yr@wT-(22WoNyW0D(9Sy}WunHwU zCxb0>rTOMoa$H_>NeNObKh99v?j;QsrO+YV2c;E1PT$C2cpe5F6B!4}M~E%Qb2%a) zFK5o(9GL}7+-K&U>QH*ppHp^Lf`DvuQ61LX!it@a&QM1a7VO*>t9}Di25^tM!#a7` z-QtFNGV~04?B;*5VyEDxMc1gpRhP)C0qMcyM_2Z$mAh*4uc3TS6P_t!Zst^=IA>Tr zM*EizU|!O$?tM`W{#ea9<29{ha2o+R9qDV1M0zIDJ&z1&_|T`wHEzB&lMO4&R#`+t zRq23{a&Q)IB6>0q>*?%u-xRV^-omLjS9t$OCHW>ufnskv9TR*|vW?aKheJ+JR!yd* zAI!R;k$nv(UFN)QuvOnbaPlc|$4WBkCZk7A$-Ykwa1f^kJeLAxzdBvp?#!Ihd7L=7 zzivJ6D6~sq5&+01c*g^I2gW>g+Nl!Om|6m;MK(Ygh{diAOuXY@-&r+zlox#!4a(L* zWzIgxKNv$m?#8Vl#k&${lwwp_hlw!k!a*K{9vBI3!A4UXSxz>O&9WPy-rdYa zU%bzCP+PCL3<-HLeHEOzQ#mYrGVB5u)INjkuUEun>~!IykUg~21a0`5UDEqj_bdq% zMWcp`+^<)#;yQMv^5?*(vSxj4g>-nzeu?o8&J-G2XP@|Lz)R(+vrlGe;PD07yiA{0 z%nHqZ;;;5Dq}jOMn`bz}!g8LQypn%|9)qHHOc|*3pyYHrZx3k2R?i0@DMWuJ#%Brb zf)(^wc8tOF8&!O})d>s!0BRd%&Mcv!{KiLpO^3kA7Q`>U_3hwYxvf%iWto-5JNWrg z@i#z4LC{CmA$W%zb;{v4*23p9J~DZ0Ymc9SaZW&G`aoNr;+&A;IoaF~@dw!xxalK~ z2{Wfay}ITq%IVJM?jbIVk3R#hSP2BlGrxT59V{exJJn9UqLVkgCay3}Xm3{LWS-+# zKzq|E58mEY2l78anuvjKAV_%-x|7E(98%(8!X^g!8{Ezcv<|O<a7r9rP8H_Vf@1mZFI>OU zx|CS8tzWC(?49L#nPNGvWcPycue>gmRT=d6zFRtSYQ4o*%gmySs=7uZ!+oIXHMCLM~TV4)lMc`nBwms4tC!>?)4|;pEb&g`f00 zi<=Knl2m8M}uT;FuN^&?zxiY6Rwtvfcjp9XCOR*B! z3zt?-_b)oXv3N=MGS>x!1x1(pE_YSZ`n~!k&np%$?Uoo{z`D$-oR0x^IP|5`k3k`o z^s4Iq74LrhL;7$3+P|A7^0natvAg+{(kG%stl3=`3xAh8Bc*#$x-$Itjc>ow{eZw% zK0Kw73q?L()Tq35YFn%i>d^j2c?i_4NzH* zDOg>f(<9I8Z`(L3`r?zCas3Y`d)hbZnz0nBXKFnLWzw~he-jLQ=3Ao9%uF&*-c%13 zQ@V&K)+SQV^vPENIlS<`Wq7Tgv6oWUg9u#@i$`>n<2W|D8XW=edWT%=FmO$Myj|*9 zt7eqZFwzcPx!xecGWf=e?3Osm$Z@sQ zXZ@aC4|kjSA@7qKx-CFMm$PABLsuQMe|+z~0z{TOud~}y1nSyDki9oRh>QYVsdc&&#c?74 zd7kHtah&~p1v`DM{!y^ih;bKcc7rHbaJV}%V+}*l(wi94aQv7HXZ=E>Z=yV=&{ z`{l#et|hsC|I5YijlSP`*df7w`IigA)|}pR=!WW+pkMc{`0A@A-lcjMFFyPJ+i$jP z`1VK1A8)VtZG?3Uhh>(qbbI>XeV70NyCRJgFrBTKVV-sJctM7*Ic{_3Aus1Vg0`WH zf_qbYxc6-d3`M`OL*>o#i5o8WyQCP+z9-Bdr@G>%6lPD(_}d?Ow)b%=zrCTQ>=mZb z)~9i*U47ATx+0EI+SzePL38Xoq`H)~^L{sddzn#Ywf?V$4>6Cxt7UPSahTmQqplkn zeXNv@7^7j$DZ3S>V8@8{f+XZ_{g=U|IU|Uv!|FIaLk}Ybnb9ovexFm=q~_BKF87Pb zTZbN22V~w@Tw4nrA;(sJ7gl(zNDu$u3Kdk8=4(02K8^)W?da&8<(ANn$CE;QUSE;j zkn%_ZGi9yi{)D@ObfwI*Bt2RJgKP}Iu z)yDeSLObQUqlMrAWpSy|=aQZ8cPWZb#<=Eo0gZ`O3TfPCX9qUdoxSYyWFnq=cUxb_;uhVyMZZ`ZjPkQ2Knb9T?(yVS)W`}RIRi^aF$3SN+a{U|Cq4L#( z(WN;F8LIN(@pmLJb!`J4`;;7px73)`taQ1rEGW?CueH>HK0U{le?oW+w3r!3e5P{e zA-anlF0b<{$FhiIRbG@B@X+tVV>4KkFhorZljvyzo_UDNj{O{P7)i zr0~;^vcd7XykR@%J5p`@OoaQIN zrA%v$v29KkWlrts1qG=owdh8l#OCjN+F=+wTK05yPm2=`oMccQancIfBj?u4GP`@B zI#367VPe<#88FA*x9!Z+iHqLPG+)2;>Up|^G3ajwx&_p%PLzDm0%ly7l(ea}>e9pO zMdu?ush@$_#fJfo%s&Ct&yGF6KruR>@2?b6}Yr&8nA z0nYb}#%I#zyK&<}s=YPnACpu{Gv39yEJu0hQl1)ULW7wJhGU~rldwJU1#LGb6=qAJ zW2e*N3R!ta8s8tfdaxw!=)URgA^yXlYj;W6eWR1KDg_7Ad>!!I>dM56EBV3KC8m4T zZ&dgZytAvUuC$k;gO5B)m2X0x^L8B*_2?vpR@J}=@uk^_5FJ&S`aOSsY%^hU#@tZw zjDnouD&qo&7&(9OkNmD?!f;cPU&o~pdSnZ(N+1w)oM2P)=*c`A2AzaGPg>9#OYbt| zZR<{2DJk1pZa7eBxh21E-N~z(uhX9gKA3FW-?a%pz}jaT7-6W`kxn%Mz1^ekj^;tk z^Ap(au_MAxh{SgL#p0tna(pMG>omIDr;oGDeL;=9nX)fDu4C!XS8nkbFRQps4R#S< zFUHeta#F$z4)puL+Gb(5zZdArj$W}AVdtQ5ML2v-d+InM?^363Jp=9Z>J=exTOQ&1 zP5#x1eE2Jj@4+F^eX;E~U%*5|P69(UThvuob0P*?nCmH#S)q!c%QaOeWo-x%Zp2d+ zhz2j_UO{S3hJi;Bk^ox#c~DB;k*4)^k+ii787(o*9S0HM9RUw>8cuyys104$*!Yka z#?#QJj~R>+UplAxgR&PS6S`hRi+27i*8?lO(kt!>UKlqy8gw3VBO=oShZoO`*%{%b z9g|B&&nsl8ONn|WG^TP_L4CAh81!Q68d;b7dWG7HyCM^GIX{XzyB@_>{`haz(9j|m zU~X&@uaL9!Dm2W1F}#zy3wnnhMH#tqJ51(M(6>dWtqosVrLU@43th`nb&hZE_$q=C zyq=qH(HT?>8+xozUu6}>KaU}@O=ZFqSW^h?RGV_$g*#4z5qXw)7lv_t-XHTVJR+hx{ zj#|5>F8@Uj><}mFI(yz)mlh zCnTzQu0M+J1YR-;0!dc&w;m1cYrsu2s^XM_`f;U312?zh7g>5 zEHFDUdnb>dg;-~4D#J;eJ&q{OIdU*ZhURyuK+V2-I0F-xgZE8*GS2&YY1u0(S8w>( z^%R?|vZP+eGAC0R8}=RC-7n_oT71e--Gr=UDIVYRF!n!cm)#O zx5JfG?O&YKNt+;|&j{MAJyW`|oO}hUikr{kBOmq=2&=Id8@GQJ^@bK84K&B(2&DsQ=_Na@F9`YgN>Vq+sI^5ZK3{L2qC<|yC^hfKF{n6#nR&_IE?+*FGjZH9?0^=$A_1TJIEQt!L z;YHoiMkc{Xb^=*rS;strHzJf(e_{*RoO5Z~gtfOwk>^)7g967ttna{OHv1^hm&ziT zb_Ur;(G~5{)5RvlnJ1hLbu3E-G#P#8TiWDgh^xi zO~rtQRRBb4Z)roL>AAQr57u_p4kV2(!N7AK_z`xFx*2Is1LId#n74v0w<6FYnpLfh z9INv!SE!wN;?u~c_34uCO!+7`!Fmx6zfRxp@kis^^sA=L_7-7|$8aGh#9S%E#J8iulOxfFvneqtNm*-)AEF3~O*#;y`eJzxs}3o~lWFbjqs;2^ z^9D`Lfs+gzm#c^n)c*?)`HlbtK88yiw0^-&9a#b!d*OB^#qDgG!h>=6qR>mIYZg{p z*WX;VKdcIMR_*7l8*fTCgr$y1E%@2J;P#r$!4a{q;*^P7P{;FbSy*STZ(Fso=9j#i z7M7VC+oY*s)5Jxn8*8>{u968;L)qnptx?vLP7HfaT#eeF7u$X^a%ETO1nP3T*@0Cf ztdbTwCf)61n%iX;#X^`~X#Ge7iM$b|sJ44-tLKWc@CPWZiNkqi?M7ot%3PcXTk@pK2NA7!D83ujwEsO|2Xz>!T}* zf3o_q;zyNz$djv)-yv7Wh316jgzqO>p%aR&tk!3*GEYPNz5=-t2_uC%5iw-n2?=yi zyrSzWt*}JmHnQ2o!HI3?M0R1Zb@%#;jpM7l!pewVL<4f~)RhZpTs(XnCKXDLn396NT&|Z)2`;xpMDHdW2L>NLD=|{_Q_*ew0t%WsAehR#p@D zq3rX*EskeyHiJ2ZY7oCsJCeDnZ8bG?%4z1Ww0_LbarZ+XwwpM^J(tH2qyB@X9baa?PxI&mhL>+M@dGEvxc9{1`?lJBG{f#No&0)9*3UMhJMvfz! zlXp)zOIju#wIQ>x)2yl86dKTn&zhtdSw-kln}!C%h(ZvP0PE7C-Vg z^{y7IriThby@-C~fQd`&+=6%T&9e~;P`|JFVKr=j=qBP3RIbIzWh)v&?TPNF`xZ8t z-%0^Tr1STC!d{|`@_t*hQFC=- zXddclUP8N-^ZLS7mSI^$?YygN;G08Th(D|Cy$}|DpV`d8l?&=W{x#EA1jebz$rL4e z!XJOB-}%$#;MLs5AOE~*iV{No*~fbpzR^tJMB|yM*ZiZf7Qy1i1BkN*KCy~8@7t5_ zo>(d~jFXNIOm8p@yq+=aW@l>is6f!(ID|@-|BW7Vp0|GOIXZjFp`-U{?CtJ7=^X>$! zjpaXj`yfP z8|wxjrN`-5Hu+itz>KuHXg9`fTI9Brq{NEC0CGRVidh zjj_+kgjrk3oqFS8W3DIskL-jHmBKVF5f+l)rO>JiKDB9->3a&|*F09g-a}?K51IxS z;jRTX<=h?@Paf|lG^YhiWC4zs#V#AcfyqZqz~j{3@cddUOmQ^kpU+V%Sb2ta_YwnE zp_HNNOEN<>y7Zb+4M1#B71TW@@wj20$&ewSAP2!bBGJx#2)rvEgUf{W>}UN8v(ZR2 zIdtpEaBBQF;9L*fG`eKVNEYw7(j3uJWK;?f%6P$zl5ibon&N)D>!N1FY8zH_w?M}k z@SGbWB8+~VG&;7?uJ!{Ar4CI~_1YR_3*CEZYI`B%xn`GSqF=SVrCm=aIvkqYu1*)puyl*?@7r&VF&3MzQZtqrjk<~st}yY- zUB|>-k-*T(&b3>%<$98GO`V*KdIsasg)_@uT@5&_)i!>T-f4huYl4pQ)AQ4iJk~xf zL&9rlewtwJ~AY?p%)ebAM7fxc0k*ZR=vFGWdV7Fk3dSt#64_^{a z_6=sH8I)el!X8rWWTkLYo5E$!s1L6a(aGMEPWxu0MBIeM+~37%|Dm?=1K9fZole>@ zxW`jRKDhLFssK9d9hmX4g&6JeHB#2vPCs@Wq8k{El(8a2b5$9~A2-R24sg?gDYZ@O z$x2;^>Shj0FpBOAdG|a(Dbln1_ley5NiLPZ<2t?McW&jrY-crLb7mZ!5XIDsqt93_ zIIQNWrX`D`&e`D3_iu~saFfW|BQx4EOgQEg6;m6+pwJLhmphb5)X+r`vV~^yj%Hh@Mj)tQ5=vDKX0L z*)vS2?#%m;gUx5b zt!ItVQ`=`g2DJ(zh$9+VvYu0JVAq%fl$Mks!cQsV3TUpn7%R>%CW(o4X;-^LW*a&% z!|yrKk2=;eHChm>`8p<;b6$>BT1v-=_#9XgSU!8dS2!j}%`GFVBKCj666wUP%uB|QDcio^G#l{u68ymXo3wi9@GF`C9q>YQ6BdsES8<_Uo z@tU^7mbu~1ZRqZk|ADh|Y_I%buzLuu75bhs*Qj_*U-ou$IVL9UgI&gp zm(vo!cLLFoQvpC$eNR`L)*6L0E2a4+7#T96p@k^5UuDgnFc^iR^`wWwLC(urj)_h4 zkbqW{>elY757uo$Ym-1XmUh$y1D-bG+$sL*kMZiXr{jTtbIsxX2nUudctyCtZLoj+kHFR;jKBdRROW*e0$FH1wFi~B%_JUD)XnccCGTH!!6olZ76+e>DQ;G&h-pU z8FpFwb-sI|!26+RNKf(D(z2^wI%zW}eRv+NL49&IH;~%izr;~f0b&D8N6W5JU7BPh z*bDGhXA{k2IuB5tMs57WfZ!vobv({D@;xIYS{U0ZJ*RXtBtdgiDK7l5t?&ocSp>Y>-J(eNlEKX9DXImEmhXqv;MvSE6&k$})){TnBCTSfK(&<0T`Hm$C- zVI<(e7(XAu$s2bMSB~w$Mk^O}(JmM;-N3O{(Yy?p&E*!R_5_@dtsW|Hh8NXG%a^A_ zbgiI73}V1m22HTK24=1tn8pj@jic_EFo*ygQm#y&mnF`S%K=ZMXp^icSt^5V4tcm@R1`b>s^I~~-gD7eQ@4UeqUs)0?I|TYf9?1S_`_;8lq=6F_aV^wLO=<+y z-Ui^C05lIBfHsgdR#7)~rbpP}vWgo$voJF>>B44+k`-pSvl5`kQ@RvsUSCT$+jt~{t zxzMA`;deO|(h3hICW;6EDJbhl=9K2;uqA>3TI+j`6c(e0Y9_^Uck)n_JI<0_p zIRU(8!h9ZJmJuN2s7W?QJ7^w4qW3_)2@Vi$GOVo{O#z(+Y7=nQ88z&VinT7YB)X$~ zM^rn(HLh!rsbkYt?O8K3n2kIOKO*!qz6vcAW8qPO^ zEMBz`!l4jtgFMwHSd(3?o1m>kKCRgjUGa{uwlUL~HXK1M=sj!T&db(nf?E9%V--Ed zN-74Ct?aBoN;_7F6g}`{?IkrZ3%cBL`BTp@&!!u{p>9J6>M+mvB__4g@q^E>@i0JQ zlOHJzUjej`AA=LTp&({vO#g+43r;Q(3XHr&U$(H_MHbz`kt&-SfLlDgmTP_Xg_{Hx-2X@yoV@{GR=#N z-^^Ui2^~QNx10MP@V6RMG+b|}#DpylT}a$XK7|g|KIXqqcC#eRICLZNF!>BRlKog) zljkT^Gcl5*TjN+0M*czVXwp`b75Oy}$U7!ZpiASE7lmmX`BzL&*QgU=H7Cf<=qPRT zB*i8eEOaRmM%+Ep)j-})ZXka*Aw6*%eRj6QJN0Ed-5q_UcstL^e|G=635^L1`f9P6 zm0~umZ@E^e4|xwdruann23P66Rej$Tg}Fs^5#`8wWD~Uctm%d@dbnQ=ek5^(Y^a9p zhRZMSgN}1)NZu{ z%V5f3vHt-o`$ zQabq07dFwtQYq%$oJXN*%aU|+6@zsISdaA4lso3Q$LxnrTKw{Q`K#bdAIiS!`Cx`R zQqHR9Y!PXT9(;oGEaIq^oB#A(`wSRXHFFv_2EjRc@p=F0IvK(5+ecu^nic%FWz)Y= zWsm!e_T70%a?`CS6-?!|e(hm=oTQBnszo;W-+jag>AO#4!B19}>JMJ7cj06tJX~-m zT6RZii`#XnGrY|u%wL0WNG%$F)?lU~`nV9z`3^McX^7xlwXqpk&z^m1YgD$#t-$z| z6EHG9tHzfKzr-tChc>7uD|CXBNmLkbSfIO74id?2BVq17KjddyWos-XWS&P;qe{DBJlv;x(Tp2-x^{% z#(>{#2Gl>zfkWhM4~67I=k|}~1hjddew7zwe3J|*Oq5n`zJT|B84r*ZEyQI>z+sVS zNk?k%p-8m!ze-gXtzUHfzon|9%6AsMSJmZm8YMXGkiQ$W}Pz>Qe8wO9x& zr`rs3Z-}O}^U^?|?P6p=>XTU21JS=7hu^|W|S?$_1snCUslZ=-3z z=rjl1sLe@6=RTRyQh@FhvIlHPG{59=oJLL-1bQJ7aT4M%Cz_1hgnB4186Cr}v5wL_ zzUQC~r4R;be3&aIn`e%T)o@a*aM3}jHN7C6l`^a5=jWzo2bfFvHLPnJU&lP_;&MM7 zE$h)ctp1A@b9WC=+Top8QzrrWWmfwK>kIb?(i5xWmBiv1yh5!_w87jgUo@Bf{=H!G z&vV&CN2P$;16O>$5s4F(6xsgP5j+dg3DGtLvr{2Yp3~j9oaXCwODTUOlMM1Rhx(x( z_D>Wxg*!Xr;U?2HKCdn@#d5i8J;Fup>6WFD_5q5e1!S(82NZgg$kb}ejU)KR^|%>; zIE5Hi@d||k0H;u6a$z7d^|TS5wJeR^3r}%U1!^JJq|_VE_d^(Q>r!O&ggjl2Vyw~F zcUwrq=Ea`9}Lj*9N187&au37CwW-5jUrGTVHr(wMu$v^&C6;$Zk7nt~ ziv~udx&W{`2nrt*bieESemnL(4W+p^Rs2F4(TKexF?}~J(9RMSxM>?tW^i6!SIW!X zmI@b(e<-$A(!w2WW5VjvX+}+J=VZ7#mb)b*M6zA+l|qZ*8zIg8hFu9*p=8d$E-lF@ zYHgQDN}QjF1!MaTsgtBeqgWR-Bn=jdP)CIt+&aZM+Rd%gc`LiUC{;nU^CdGBF|h{= z2+>K+zR=19!9>s~H7OKEHR*b^4j7s_&Odt0IC8;88VC%r)KuYBT(iP|(N=-F&yX!_RVEg7SiYz}&&{lrN^*0X^wn!vyglcOgrsj-6WE=*@kCIM{N&HN*BRTm7{92PE_Xx6Ts)umOfYmZy4`Py$hj@bL?QbF0JfRmPB z9V@afsc2@|kvi<7sDiSWWuDRJPSKCit;_ttr!zE45%ve9I?lFnj!O3CL^UtP_y@b3 zqKBxptQ5(^frR)FG_>BKkA6}@9IJ`jNsfP@I(Peujro=#T2T;b!bSb@15LQxc!=~z zH?Jyu2Dp-eg7`Qfy~pT6!SURkE1>L|6{dZDkpkMKtyl=^0Im+0@LI|OqW35BUr%~; z9GA%%rr#K~mz-%tKtjdZ|}n;HHmQVgb*pyION1SH< z7+`ox*~dIN2HD-NMJ+;R_+Hl)%fk*|p@s5>SbEr^X5S>Zt8o%MiUhWr)<(+0yCB*_ z=G)2ZMCm?<%~7INln4gdcMLZeI=Gbyqc5H?)Y(6zjr2$4_Kp~i0H)}sg&YINp9GX~ zd>#5=z@@=DF|p(1$&(&vBfXJt*g(x?mF|x?&9!kWhS1G~pY?$6Q=p@QGpnbsQPEb& zV)^ah6Ey`~BE_9pgRly5w6!+L$?wdBXUF*r_%==xr|ZDadD)Ui{X9H8Dxi?Oy{n*Y zsK~tC5YiI1XFuz4Gy{(z+6Tm2K>K*q08RL9)DG_qgPPN*ulz1b4~XVWA)y*To7n51 zVXHy9>0YmdnY}=LHf}SVGMvHF<)4*mq}{sYS*(wqU2Gi{F@X!HS)b-v{1GMfrZrT*U(POD5sm& zxUNKgGQfBP1I;D}0LRlW&!=YXwsOwlS$gSoXt7RYF-zvPm(61Br8dEeK*iWtNYy2S z&H}XwIO~lo2OT*Mk4KPz$)46j2}-S9qq`^|FBmY{I5Crj17YM7^eRf6u!9*i!HW>! z2kIQt+5}XfPs%&P74(gfeJm;J1a=JteC-m7>9nKQ5`gy6_lBVxq=c-C%#|-lglCg{ z)2E+Km-Pa*9n@=&mYcd89N-sc=LT5C9qGW92p@FzD1shNawSyFwkVU7Na`Cii@00` zk`jlayOv}ZrOdGfG!}q7A;A0BfZZJ2YZi=A5E0C@)CWs%A$AxHT$Y>7dcoM1Mf^EX z?Wbj{f@;X9C3$x(_AOhzKVm&+8*bvnk(#x_fYD2CGESMl?uPJj6kDg z)NBqc?+Opwz}xCIn=n3gFk|XqF0F=%TCAoMFmuxX=dmB|-9^4zT^4FhJeHTpzTa*= zww}4lFYGMJEiXD=#UV6+SXWa=#FAH%V^D$Ui`s`|e~?@q5^6+DudyYcNB^pQM0Qiu z>bg+V2s%r2h96@K>hI={RrGzTW+PotEFr zuaZhbmY5=IRz`;!6Ya<$WMgt289njy>=m!SFRvOQe6dLk>JAWW;Ga-fY0SV@e#B~( z)GW4c3l6ofd4$qYJGK>Wwp=;f7FBGqH*@RX$F%W3l>WT?7||2umlqVj)pONCO5M~AV)*$<14 zHA=^YH4vrA`_Qq)C;yGsCbDYg8@qJ(!PmQuwu)dZc`;zI_q`8r4{?9tTIj2!EH@2L zuQ9D@A*Yia$!jMvy-m1Q`f!`&gyn=)gzwE?Gq z#8;?Li#=PFZ>}^VPNKff6DeeCo}r}k9$FmCRB2o86uts=*J6L>wzd`2@adXa)Izm= zk*j6Guhy)|%WMBRQaUmGdCf3tXDp#c&1o~8>|C9&42ubTRZf}&7oI2?)mektOGL%hd!hj|naH)TtWN3>oV@4Dd|e|P#fw6zt3d{pKVQ5 zMjS_A60feyNy&AwqbO=+(lKqt=LoL=twX2<)Q`rSBCG{`LfONdtn;?K`x@qFGcnJV zV>`^F9xw&72dAUQDozewsxQ-+b@8FH+^ikX*10U?@ZOd9ZK;#c z$NTWeyNkrMF3^W_bKm1sL-+O<*&Ap35~gnTDdQbD7kZ$zdK=7i>RG2wdbcqhK8NcZ z^UZ0H&Mq3Sk5{&? zlf7AYAlkt06xMfACnqCWV%oM^WkQ4Tf>=0JEA)@fqn`Q?L?;3C1OuJZu11K}f8;!1 zjbDRNGN`TvSyRIaHx>g8)d41?oN79f4d$BftvAE#-=vF37w~Hd{i8j3)A;D=wd!f} ziYA|UW9z3vS(0%d21rj#J*FLPUmhI-O&Yr}Bf}oM*CV zycNg_3jdid%atDI{&lespdr7{WZ$Ln~tMt6u)F$u%5N#{5Ujj6J;kIF+8^C~ z4*Y8TZP7tRs!w4$gIaS1Tb=Ja^&PM7eH2_!7A%^|$_l5l|GM6eWcS`0>=Pd?D+LKb zE34WHNtEYZL%Ez%S{TVjYLp%6F!34#5D8LSo9%BU$w zASx0O%X;x#UXg(1cX!nma!Thc%`pmyz}h6T2*X%99aIuK!lrw%XZ;kEiU0Qs?chAu zp9l)wn4d61oI=yUCz_OB7aD8im}4v$tCqgo&3l<6 zrI3lRd*c4Vqt+fb6|~BZ342bQ0?Woe$Bvm{<4;6CFSH71<;A8D6rA}gl}UH|{20GJ zOX6hbye+kzTfI1p$JA8|Q4CFGNb>l*dzU&S1#`;&$Az zhzrR!M}Zl+NTG1k)NZj!TR$>Ggp+|ug#9?+sgB*8F4r+bBv$)?`@8~x>0EO+jRfyJ zWXky6G4qDUQ#Ct3k+|5i1ppPGzoXvrx<e6r>wg7@o6GzG0she) zaRNw@EGI}nR=iWxdrvq*!`&#*`a&jqhykugh&+-yhI4>qTT_*$-Myv{hr8o%a-l)d zxUUtGpHg>i6kij}(nHWYu3YI$xlU6GqElglAqer&2Jfh z=nQ9`RWsiNRwT9!t)Kdh1u8M!S47m$sfkmvH56qUk2l&nAx^A|5xs0gcV>QJ!i{FV ze@=p0hSVLa8;wkulM42DMWzCaYMpOFwOJGOL%YBHAMF0a9{UiH89*f01PqhV2R3Cz z1Ixd-w19X`1qj8&7FZW&bu(%)p%p(&#$%XgeW#(7iM_z>ALVD=!zvXi?`R=D-+aue zfH8hB|9MJQ)O1nQzhw1a$a(c7*tePYsNGUC*HcRx55gRCBJx7*Ex~SF9IT(2*Cy<|%7A9a~Tu z0(|ulX(&a?Ou{R$@FDx&^IznC#C9I`4yu1jLCMQ~^3mO^*y}jBlG_Xj`v5`~^F5J7 zytmZNJRw8!{aE*@@(lpJ6-OUm^Us$hl!2buQY0rno}>YYoL-?M0Kz4UfIK{7a9P@E zBRPt8k`w60;p)kd4OT(pziRbg&5N#`E#jYXHgFk>8#U+kSM1ZkyFRAxY^}W^k(kqF zUBMi6`L>J&+d8%{he3}&y4TmY0LY?nv}^+x?*RzVOeU{9ZDYkPxd$m+~qs8#k@102EL$kStC+DP%%V z;TQmOaX74Veq-?w?9DZgcCLUk-VHoBx^L#ZP6xsB={1$Q0U=lcs}nu0fPn#V+^8)U zkQe(hYbnNn`?;bzSY0aMG#x(f93cZ5^Jnvb*h|I^pgEs6B`v^F>Gz6|ZqC7QBjf`1 z0L2%=aG1S81(HDlSyrJP7qVYYRUGpabU;cy9b)V=S3+HE7tOO2f;KhG80T_MSqt~6ySJpveE-IA#{ z=s-+Im4JMY>ew7*E2$EBz~HbV#obbpkxb~*>0k);-PX*85sX|zaS4irEMi`>nUzCXTPc_MZ5IW?mR-k7IE~O_$SO;y0+`kPq2jpKk$}uN z+HIhs^naTecg(pC4XN&QI*PjJB88oH_FKgf6?Mb|YuDzUhVB;s0%RS6cB#Z#{e@(l zqKV-52DBPKKU(|lYsrr+#&S8lj%ZBW|I^-?Ks9wIZrhIk>er!46)`RdDk=&JrKo@c znW+m^6jX$$ERmohf?&X~CmF{oOB7U8K!MZ+QC1^6AuyE$sw`1K!cGlKgpgoZ637DI z@7@~_sqN^ropa_K+k=fQLgZdz?|a|(d2AgYRF3Xy*UhO=REnS(s-I>v2_&?Yl%cDt zVE+ouoHuJ_umw)O^1viNy4mpxeAOecf~_gZYSA|^lQ)Dv=#|EY`uj0*!YG0qGGAXu z)@70U(qba!6&EODUS13+?p3SsU*-FPlLa~^;Hw^plY6#Z*r?!3jH1Ax8Jkxjz{@lY z_%55dav$ZOxpUOj*(4JD96rIyLE{d*!6%0%57TSFR%_Ma62BM8MqSKXRog>APZG-Q zRN(3)-!bY1FCB!gW3hAU?)ZTM=Api9P4JW!gJ2IjYa=8(GD4qdany68`1^F543GzO z*j{4!2)(D{_;5>!dQfpv&Tp+A|Af)2TR{!`^q)C;Eh+rX(QEE-JY20ypQzSZ zOFMn=i6&UbrTX)*^H50!>nH}{1{5nMhKXS%;WF}fIqGM8g^5Ugq4g~6<42I*(r*jj zt{|n|BZF#v-A%#W4#6tSv)o&QdXjER0G#ElRotKtUmm6fIGQa>oM#hG4NGR(u-Hr^ z?m79 zB@rdz#;i7$16PYn#oZc?32=}*6lsrYx2pGrOIS0xrF9+HU4y9U%)c3aw^sl9S5^^8 z%)_+o1>D`}5@e^iUR@_MfNUiGGM0c$qyQT<*xqYX=7 z>EkelcD<`Cbp24lVdgAn+>LsNUhR97g*dFpP=u#Zhrb`bTaVRPI(y#C@UyIC+&!W- z_*fH1jb)MZEW_`x47oMj1Ne9o=lG=-+Sc=Q!sEC{@DZ^qt<{sm*%37?Pmw8DITyuh zSI@(KI^N+FRtz(SwUVpHP37)0#5k_fM`6PBSa#eKAJL+W$#)3BG;0gh z+teiCBV3Z`AH&73=Wp1HnnnAY*=o(XsuZm-OXeM};eWL{Ie+xG;&}B6iaIj_7rQ{; zey$$Hj+uhnKr1}3C2h_2u*%yF ztYcos>Cti)=;Y36XZ;^7J$Ch~aQ(e=UXBKiMN80Aei)O9(NNZA)M;)zcl?5&w8#|PI=#ey< z({gZx_qjnWCh#OHqfeL>OehoUVKOaSa*hcXEqND`txd4(XM;$j5h1POvos> zHMgzg_-Fk73o5Jf1ewQDW?ZoFdG^{Gz0Kl>A={2dwb>+6OSX(}>{-9k35n?dqwOXq ziXm^l3pLfiW>Y(KmS7QH^7SJ@`9!)V>`yBp!*jI6+t`%5z&n>YBsva6DS z1WzyeO2WX+!6W`xHScO^e{5jV2A$wZQ5;etK{dTtjO2I|!_P@A5(&di+;OR3 zmf`?Z20%Dwd;OJ{7PXbUeWmB?z~``+w+ynuyyuAp1ZX7`_^p z9D3@=(#RD@CmB9bc0AtGKVT?c_@lePb^hBz%t9~(Yg__@nfk;NeUd81wkcpHFfeMl z@R7y2t^&JKD2S`J1#!aq@eD4;nek_G=!E!66F_C^E-Svj*uy{FUy$ew?O1bWC&k;Mk)Ts@{?F``Z1C@Du9&@x>=r-z<>=TLyZSLUirCjHxj`-wpQX>f zpApg~?)HAw6@TjI`&~m{%xncaow{cyz@3ScqgL14oE!DKXO&=rqz;3oPb{2Zpw(v0 z9QPED>i+g)F)&n0+TY%q+DoV;y`*HCWHek)%`7Z=F%bHFRomyW(SyO>iF@LF+lJJg zj6IPc&I_HACk#N=mtZ1pL@I+oY#SB6rQ{1g4ZodE=2e)2n-9sH{QyxJ4KQ>k!xh&+ zHg|zayqp2xJ8)Mq_xoa)=tZoG55?Odi*fM(Gt^vFQ#UxcE%epw@>Dln3*{h!(1~mZ zl!}}Ladd?goEYw!gBPy(%r*t+$OW;~#E@z)DR}1^v*S)is&=_7_$C94PoPLMuYI z2qtEa<7pNsmI7MXK-i?F@Q)Vjx>D}uiSr;_W4~q^z(98BjEwU)YjV1doPBrXZ|`wd`XZkap_e=L35>W=b-^cXDeX4+HUG-@)=ZNCdjRFI})00Bc1uV(=l~_NB+^ zh+tX$*WZ50b6IJI_D)4^Bt<~{O~6wHq3IJXsfPhWgmz}{g^@t%KnsqVJP zKO9QZ8w|11Q+!8wr%9))xSR79S9M&vW!h!O^jOji2e5ohGQdRq{_FP_eZI+UeCxLC zoo)Zugn-FvVlr8lOlw?3{lVVle(JU?FYXIF%&)KwQe4_=HueI7v#A+;K*h-_B+%>X7gRi->D?NIWgFP=;3`PV2{}f zO>3l2SHu31^54Xj{oXR&WzstGM-(YPrE{Osx&MuHuHgi0;Y$y334Ju3=X{^)xTH2f zYDMuEjE+RvTWXZ-Z0~v!y1vyp_&dTiYO&Kur?zV?WS0@vQBX$MeO z%m=jJ%7)vJ=Ka?-plYQLC5ufiMZ)w`Ryr0jT;SA)%$1(>;0+p21GCfpnS;jspkb zcFAJCxA6yY)i?B_0av}z|K-EHG#5dJ3&L?KhdEBCU9q(nbLe$+KbuUD%l&Gz#HDwr z{_PMLx`3SyQQMGuWYjM%VDkd6zi@ktQ&05yC!3Q~AEo?)K(!LY^P!wJ!eFjn)=HrB zu#gwi4?tE3+)fFydXh;+1K^*aI72yL9_lD(Qw@_%U zLB|wN+F$zTa@9*Vs*u-<8Ch)}qYX`}!c9LE_7j_usk35zAD_8%gQKGnV0BdNuq74D z{=Far>35Cg&fVQM&v+}8U1DKC?Y##XrWH{)ol;<~UPV-J<|B@w%+8>gn}5t9 zqqt{uC%a8`BYE~=>lb~Dz;BOM6FHdEEe~fYgMs_4g;N0l!umn>tg%_k0k3o1&e5vR zn<&fK9FiL6t`uW{L&6Yx@dMj15b-ssHqDjZMq0PNnq_uK&H@G^A&nLt@k(b)U~HDM zg*M4S0MZLWj+;=-M}e7k;4>rpr=}g?`xCiaLGOkdTmwH4s?7E`%rA!}3=`BdpO8n& z@nj`A{%rLWIwyiD18Rq#M$$l>m_E!0_bHZAK_@Mbn`VT8G!GL=T5RYgD)Rn7YNLep z&*vS!X=v-F)Mr5UpF6rf@VyY_Fp_X<4rF5iqr_mR`_N~Q9k~N`iIId`SH~R#g&5=5 zaUdT2Du-+xC&^VFO7yCZ!4JW!T{N{LT6ufIhI1ntkjLPr z@WmbQFB>|YAfHHg^28YvTksZxC1O{?+u5>kbJ;UMl|zg0K## zZa^lhRc6;B*~C*?bCPA20L%#+V8o?L^TosvT(~iJVJe}ytp|Fk+ay~LgE)uozOryE zXJiP2osdc$@+vpknV{ry+Lucm${E!s4d4)WQqDS{B~(YAkZb&{%o@$KGemcnEesy% zVDGxc2k48HwwwUimP6>(;e&q%cK!Efh^>cObLbdwiWzkjAZO&j!m?e;D6&`4GhDQ% z+^v2FaOW7ALjiC#To{Hbl*%>pFABq0_ko2gBIEu#@yC}bn7@vHLEc-9d__9o^3ES# z&GZx~(M`lAfj_wr_lfeLdbi}Ig7lc^v}V$WMBszqUc8nwl15RUBoHazQi&_+0 zvJJlRHPMJvIz~imirHNvE($H=KsEg0CvLi%)>xxRE;Vwy43uyRZA?cT~^2PLS2YSRv=@1{{* zI-H=Y+gaLKVS~;+c+7DP+Um=OAHre+hu_q$NdiuYUhL*0-^H|pk*qhUxZQZ_l?SuS ztcSivgDvEvmW)u)v&(5}r?fwV{SE~RSp48;0UB;-!9_!`vNY>&c2DvvFz;@wfwq{e zNQ*n5%B7{{8!`t4&~Tt{HHn`_AL|{TVdk~9KNa|GhFE3bxt@}N2Sx^Uwheu5&h?S? zSYW2`m!$#w1GI5f63tkmN9LYj0xLv!U2eZ$6M_ebv_KaChr<_}7=YfOX{1Rtjw~X5 z)fb($KXQO2na)oTMtj*vC7@RYo_{6e8K4eWw4zK%Kt14?%!RsG5VP)_WE(UGttLE` zY$@v-I?q4i2-cJwbwn6}LnBqXbNpq{@2D9SK#V&lDkTKWuYvN9m(KL zwfr)ec{a(o^=^9lSxz2Kq~LP`U(<%BNI^<7@fh@&QGFt-?Jxrrtphw60ky-flMQk~Plzyb+st$rR)cvR?K45r0Y@xPig79J0+P)jC1>&#y; z)?lfz4-Al>wEA6WgNClHF;62oAk6eM*)yI%dxssK?(BX_Zao`yE{ff^!k zNvHxiU54ps(LTD9W_D-(v4)`nMt56-R}4yXw3oTxk62y0ncjMc$r3jLSwu9^m-_Mq z6>>i4gF9M&GSIgKP0*CkfeIx*u(<$;OQ2%{zUpR%q+IXqH%EtLahAE-;>@N_aXv9J zW(aTxNv6(F53gPBGiGQ2jw298KqnaZxeUC)Cr8wPo;Zg0MW33;;_)zX?4f2~DD3+P zKwlU*2X!J#2pN{$1NQL-MP$I6+e;te2BrHSiULpRB7D?)%~-M(oYIac{$HEQ32Uwi zV&Df4T8J-(9uWoN#c61UPlAuR(c_>B_85TEks(xJS3Ux%b4>0H1$*ZBk08~7;j5fY z8a0CMRZ52((nhLMw+)M*)qY-ujZZxF<7|cf>ybixfM@-U`RkmC`RmpI>GeNk{%YYb zWe5kar4C-3n7^txAoEvh{~tJi#VmRM{MGvGRNap^a+n;}2JRtzLcVjzmgc2h;fq4_K!663Z7T6;#kPTim z6NLrrH(NzoOw2ZWdw9;1S$F&pP7-kHg z4RDh0mG6)rly56qHmW;%ebjajw~xC{q$*+$8i~1B?n#liC~`O|08^yXiqZ;Gb%_4& z?<4Iz5p@X8QZ$Da9lPA`4?E@o_(*vyJTjaKD=Z($vrjRNA zOfr+sH5U~R|1G5c(C~kM;4I_4XU;0`YZ_bIh>eH(a38*a{fGwMNZj)=@b>zFB3b|5 zm}+bssVS&m*TzD$we!`vJo~OFljO(zZ6Tv8FjZPl5{r&_jL{O>0!X5=IA6BsSl4Gf zG4t7go{)Huxj;F6MOTrF`YRt9+gXpb=qxe{HFs#>ssbHmifdBTdHa|U4@wzVN1o=>;ctIGL-MoAwJPt{N75d3 z|5|MQ@qwQE+PPjuqdk|TwrckLfzfR$2V&eDzb-gQ-e9}o@ZjaM+x?}x`#(!>Tfs>@ z^=5#7uLWD8W#Mmwgm&1X*STFot_Eo1ZUgVOec)C%WRW0j9oroJ9*cx5LH1gKAes`N zN_u^4j=~-ZbH6tPWol>Dv%|&@>&Uxt#5QaELiKe{lL+B70Rg6)`~u>T4BU{NHv_$G z%7v;q@K+SxGlPAp13*;4FMD8ot#Pi*g9`iLrSogR^*hHb{6o6V!_|S9{(rsg+dYM9CvPtt-X}ps1QWGNmLFU@Q8C>L zSp{Cxy?`puJwwG`;A1mUxFfuh6EZ-Q2_zG}r?vFz&K+Tow-kU35#%vS(A#^tW=#Uc zzS(h;dkW>UY^>vH2dE{ZuzNzDGM3B}pGi_S?tTN+4Xh!9Em47$^@-zyFzEW{SehR) z2%}F%udkbeUWfQ+m5@CTx)j8WdI0SD_;z>drEF_cnT2c4Vo9D?W2YU7_COAGunO%vu9+u>1IAYZFEnqrwq z8*da~l?;7-_b=kKtsp#td!XsT?w^xXoI*xQF4knkXTO;7eKpyb3dq zFR6XgT#ua5Cg>EGKNY{;h`o6Zy-3#D(3eD2c`cb}+1b?p+m@Z$$F}SwXRD~wKelCe z_EXF5Q_D`t<^0sL`_!`gNbfb$vKxDTV60B-VtVpL#Xd!fRx{?oJ| z3oAIzjNLAh(SM$E>L!dX@zvd71NN}ieOWCz$+HaP(A2Tm_uHI@O#v~sH!`nr&=~XK zmRH+G&YcZCU7!gg%v$}7%FJx>u>G|rRnRO1zR=wO2*y6hbKLZND9f1#lo22V@=MeR z9r&8WI}0YRBmFf|C$BHWJY~<_vT-%{sHTM(VdokxFk$4(UihNZ1fD3GGVVQU62LKz zt$Ae>2=F}f_%6T-D-aKJO|qfD$HuD!UQ;Hh0w1eQun&D128{W z^r~8}4%{P0yL56Y1O>ry!8ah6N3!|Xc^)%>;)B_(b=p+Fj4YlL3XC7`Z~@zM#)TdY zRCh;^tx_S824J_HZR!JHzd%}I|0^-fCIsw3kkffYqBCo~VrHH+%{b^s^zxSa=Qeeo z3q@}>yyiq5cvGE`Hy=-9Knd@Wwr01SM&86=+Cn*)ekBKL2RKUg2PX@BO!z>!ae*qh z`TQOkPzWyr3Lzhayk6Ayqfi8kRZ<8`z?xa5d~ZA!KEVK zv0(L#a29JWZU^n^g0sa$OTq5%~M?{4sFz3;->bLdk-iv5qTC#*VRa$br)h@MQ3Xyr*a7(YA#VQMZ zu^RO&>O2``I`a&7DgL3!Rt;T?m9y1CDca1-+|~GOlkFjgH1vStm!kH5kH|U;o$9Zs zC|~A#v)Sc%+VGX(NBLVrwuRU^>UKp`Fgv+_#V6(ObX-@ws#rf`4kv=o9O5qK&M>rT zUhOs4nBu@p=icNxa5suh;_3NYLM)rt2Tumfwjv|Zg#ddq9g=<`${-9IhK<06qyCU-^W)#C^Chr?a75af1uM5xX0f*8 zF4Q0JTwWHjn{^E5OFLJ;&vRw&{C4I>Tv6=OTy-_pisvhJZ#Safkv$;p2Oha?R0qf- zej$hez6cm8xQ-S-G3#PaS(KiY^r=~vhv464JV&a?|n`1%o` z$wduyMmFD;4m6#(woWU~eo5hT%C%7c;Jo67SPwl;OUdawLG@RAbR;23O07gWssH{7 zPGXA_dV%Z+v7fcrJ~KF%&VLgfbw3f`>3eOV3dY@Yeb1ApY8-P6RaRr3YV0Z7`eB^9 zS&RL4aPZS@yq(eF{y`+HbAS5%7fwp6#CA}0`=wPP{Q>s|25FW&kiXGG`Bymw*bOi+ zPKu{@kMQKH%o`GDW!llFo&KYk4uLOG+g?EKL)ls2C%&X4U z)vxP(Xm-rmz3NFTG63VBaD)HDe}tl|U4m3TAQWBg!&^D*X7$MK>_-fSjG>3`_$*R*{0j;vKWr zVXy>p60dLO%@UQbb-E?3X1BiK74!89z3R^^mzev#Tc+Ne@b{GR&X3KwEHau9OnS%o zPW0t;fH9(Ju1pV@Ul%|a&X0T5Qx?$xevu8UF1)Cq=T_Kg}{P11h`uI@ieJXC47!&~4DMbsVU2f)pT zCpP`#fW~BxswaRd=3C!rbG>W@{L6w0t=|$i9fCRg##IBtnAtMgrnf6e$r9J&fEY(} z?s|A_pm4Gwr~F+*PR>aD3!pp_1}BW`t!ot8z>-ro#rQE{`DwcFa+A=T0Nr^=Al`t- z{Xs8)AgA3KBY9VEJ=t{HaVv=yW!;xmNz{Rx*mk$&=etX?KLBj^pil+v#puQzRh=IREP3upoRrnOtX0|^6Ht_`d{^TCTlgVHOjb3U3uPf53|B7B%tRi z5K35!GnagGw*n8VIemh5AI)o0Dga`&K$_9u#VG| zpw79PXV}7ZtbA)oI5m?=lZGXS&9HNzkS4c&O~E+FE2v(#f6GII1w7>~sTkkY}=PH2$@ zJ$81>Y4CmsM7{An;j)V$PCZJ85;!es*Gly!yd%`8vla`)!Ek#{3z#H!fVkBL=sA2P z-$>SD1so8C9`@pVnO)i$UW0z^<9K1LS9|svkb(lSY?(J23V5xKmU}>!kIs1*-7IuX zK7a+HedYDwfV|aobe0PIL3{=Rx=;X~<7kkbAGGMJHu^Yb^-7yt2m3tpDpXX-emGsa z1@BdYllM9Im$=WWfcS8OTq)R zT;#=M#25iW)FQ32k;hOi5{P>53xcm8GHk`AMW@WG z+2nD^+>;M@IO>KTJD6=SYq!%g`K|6RP`%H*&n0@huUjRhe1`jB!HW4PtFV= zli2b#T7L=aIKdj^bAUcz4Y6Qf+H!mK)bN+AMY!FD-(HyK5uU@^gma=L9>97o-5&lN zw}p0jf%eAvH^XQ+TUuPbqvx{Rc@E*}EEC$r1qQ3aPO`RMCSBjG*yS(1XD3}%k5~la zi$m5~Eccsh%dEzsXsNNferj6kbc#7M5ochyyLX*^8GE7K%3P5*Sd4>ev`6fg$U;riUA023sV+g<7$G$G(k;!Ns&P8@jjqSTh4F?bvx^wa%|mT!xFWz81GF!q0F(Cn_Yu#%V@{Fb!YVNQw6Af92QId(Y9Pa zt(c^k-<1CF)hZQ>q0u%JVQEv<4h~~u_uWi|*~vcX5<1hEo+v-P@}xwjz!tvxNsf+s zG1j)Q@b#PoWB=ZN zy}9?~obx>Q1cFZe9VE*Rz34?>|2C!KVjSzJ0{**ehTDA?85;bAMPv`0YzSz51N^x6c`NeSp~$H2D692V|;i zr`}GyeCqKxEc@>oPPc=;G_HH-(2&(aGg$Apw=3$$bhK1ha4@Z@%p?v~N&SW2;oxA8 z``xRqA&OjdQ=eVQanojD#%&K&mHepCZ*+rDiE5fY(?QSko9GjI{!f}e+$pJ{jEF=&Jj+-Gw!wu{jAx0(Me;id-KpKRX)edx`qDA zQz-_2BOZFYq?BeHydo?_Q1Xf7V*Etrez!DuBsjCbOqo*kJHN@DHM~~$*81SV z^e=!m)yP`~mFaC=DZ-3Bx&je?j5~Ay9@!W9GUh%nRq|N1FUUoIfp<50lJfDXt;&*` z$?u(Do$Px$L|^Y%6WrxRit9lVsdq+9FXtu-8WA)K$rm)G0{R{!y?C_v*|7$ysnAH) zl0;&m@) z>1#6>A4vquVkmbGp-3)(#F24n_&)2<&B^|18td_yIA#FxelkoYwX!G%H88|8r-Q$` z14=VM%5aJPt^J_-m3nq64RZxVSGcMD20*%~e(NFKXS!zJ?vMeC4|34G+)=bf5kip? zIs1CI$G^C<1mGGo^Q(^b#+wvlLITyC)e)%sEc_$e88N-@#0OY$I{D6MRg1L&W*8<9 z=9!Np|7Hi&nT;@Iq4SZDpZ?J)vwtYR{>#gj$^dw5xAzb?lOCU)qZ{{%EtD7Z?bf^g z(0{7Uc@*>1ODIx#x_@AR_ff42miSe3{CXXiiuyiO^z1*k>{l#z=7M)P z7}v}hUi;T9x#Sf0+!wv$Z(QV`*9mIIynkuiGNEcN?fq*e7&*7-X#aw@E!FXf z#PVpdJX$P|7R#f>!XU9cL0z7pEiN#{bBm`QxGg`oSbm(fIRAd?0TzDQEtcJ4 z*)428LHvVf*2`Za%U>eD@}tzUTP(Z9vRnM~-*}g`#lNR5wo%M`;<)4U5fdS*s$Ge+ajaPBQD;#SSDcWf{ z+?1$zBz5g_Ru}M`C^PNG4si-CKd&H;JUd1@-0kk8Wg8nOQ*I0^)#OJMK&qJCkWL0a9FM5=*W$6%*MJ9%x=Bb0ChK)^c|e>WGOg8_L8Pg5n%AoMgcD?K6tuRqd+OvKOGUEpvH=f|l#Y#KwH1;U zne`#m-cjEklbp_!rPvxjsvue?0}2ldTeg}KnNVZsTo9!%9MBVniqTMWLZlcZv&Y#- z)UNnmy;VBj0;tZ6!}P_%6c%H-M6uM|*s{6)lGEq>(YU7 z>!}jx8pVhZqj%NIeQgcDub|VtkfFAoMMa~)zF1!D8eRVwp?%&OrG__{{3KFa@vJr> z?Yj8b$lMIk^Ok)N;UQfy-{Z_cvcwXh_vx9wF1>=UnDi_zaSKA^8ztJszB1~u9}Des8$=8SQ) z7w^w^xzY88XK~@yNi%&kgZE`*CQK)f$`C?+mC;!51JJtK4aMRcP4zPC{p4U|Oe7Pq zKn}y9KY%_cnFjePUYF8iE2t1})#=9MH-y)7LwI*L(0fpA_1Y3cw^TW5*ij+92NjsI zWi$y674c=H4r3fDg!-=?r}h0zy^m69E3M-J?vQ!@2GUMzH|pdm=PyNB zkIMB>O&59(8fd~&%Fn5Np1wUw)KR&V(o_Ce6rU}hzal_MdOM+FpGPN`!RtA@q) zLp1Vlv}U)nFwhiWA~SLVIdB{i4ZqYM6NxZq|Pj41Jk$M@-myta0{s$Mh~E>o^;K&b?4q z#9UBU{!?KDPhssBI_vq<{tGgdsIvkO$;~MxU5!(?(h$!1vP#Pzb(@P!O%I&GPCtiu zW`n9iIhXm#)8bpMEHFRO;%K>EXrQYbZPMAW9>j0AUOaS0jg+R4^KLAvWW+$>(+V4# z>j`e9Ay)aG9>b)^mqRtflI=yrgL}l=DK>@Ysx_5| z31jw)=fCFZCnlIW-&Aef<^{>G7{Oh8q7jU`i%$zu6Q-&>3NgNaA)eBuqJCKB3a4 z?>Ifk6V=36lTC>xEcAp;K3&TfB`xe_s!c~f)|$>r?M~PP9+AdGbHt;D-8`G(7qs4n z1bzt5agT^$lc2Va>JAII_q5_i(x1}Ilm*Rf%)p<n>DT=9*G@>lxvf+8Z@UM zFKH-Y4%afAD3727;l)r`EL$E+0$y81~Ye-=?T!Hj{#{eG^#T|GFH zZZ-xa09$zBx<%Cu33o`_|GMgC+dt9po?}A!7LRbG6;PIzyE91(CTa|>1&s2zCw@#h zw{FS0qh@G|75lu?kBzAr&<*<*oLW8lnpNfcqfdfOb&q+Vn8q#O^An}`TXb@|g9uPb zV!n?p1^T5!aklA_zFdTVQZ7b+al_*Mk#N3`Ne4KcAz+Tnk9}}btY~m)ii*Dvb=SNn zOmGyXo~?@gqUy*e6Sk&u`6=jwC(Kx(7W=))qg$Z(WN1J`o}a>b9+l(nJ(+j6(uI=& zlR}(r4sh}3xvp(unlcjQ$4pqVdnR1#yISlP6Kb2U(*>4bF8QfR-uH|56oNKDV;@Yp zkK6PM?-Ys`K3stkZ<2_V59ND8!}BHZ7Svt`#13tB7*L`EG)I`Z{vs zk94VQaaI7$y`M}igrawsrNy$e_+Ke4mZyHp6XE6O0Ka?;Vhxx<7pMa`-wa(D8in{x zmxCW~i;x7Q98Lc)J^93R>uEQyA_%(lYy3PWo%n_1r{+x#D9jhMb3?5O2VLH)v-(D3 z;=M)2cAPR;Cr=DPFg(18;_JxvX|1)0*-y!k(6Q{Y9t++|2(2j^?tq>I8o5ToDBFw9 z>d|BxeLJ8N&Gu8Pg#^uIVWFwoIy_2S3lhp?2C4V){Db%sqIC?q#D`a$eqU8z)2Ez5ytv)nyUnhqVh^o=I&*lRTt^7$fNlO8|SZ_uORe{o2}AXM)6QPNIHK! zsZCcP?_je0Lr{6pfu?r3JH@_TZ+kbaRBLM2x8d4suk*Bqb}QU?N41H0naCS|bMGEh zC_bQrf(5pSw;GV@KF3q+T&-1}PLXK1T5Ex#V0hF#W=TeI!DL!8n<%w##>`saoZ_6~ ztnD3cp}JiGrapK<;5$3&Y>4*fl=Ij`sGG4%GRK*iXZ51lw9cet!<% zHbnqEyqE8y)M}47BAlv}( zo5h=36w>|BMS?pP<7P3uMW??01Jww$tBWjzTL1~V%7SX;Iq1~GCd1ywV1RUi0u%Gl zS~Z^t+A3ON*QPG0?v}fOt_VLXDiQD}oR`8!@w%oFX^YXY7Y{(!;Y%z&??~miHl{%o z!0Xy3m5XtpyQEvdWD=#tXhA!(%nCTx>J;u$FG@*3JNpd1-Ur3SmXwxw6+-S(gbwAB zQ`kg`)`fYjK&O?zgXY=Z7tj+V?c$?Q|LRO5!UNo3=p4GZ6dyUNXH&Wzf~{ed;VMdN z5UV1(OYyc8mO6tzJPMhRH}^1q4=QNhLkpmODTUn8i^4Q=2eFN)#$a2~sxl*jNAHl| zfC?iC==5YOWNEaB!3Jyw!bi4E#|E@1;6ijsJ#=w_{%v6>N^QTQ+!;SKYSz0pNgoQPC#y z0eMf#0q6&uH$g-t^gc9mZRO=D!Cvqt7!9@nc~+PiIZK&ZKCs*$%kHpzH~hPUflw+3 zgrkJh75G#rYeRaATlyi-#w5Jq{*@spI$_OSH?NYx_%Eh@P>pj{8~G=#NrFQ zTYmUQHv;eOEb`JD3PDcr1sIg_Ae)e_=RY6CuiDq^9slyqHAUG3>T^-LBzn~cz3;`Z ztFs@-GUp`l(`4i*b5-D>SJ!8osVlNv_|Ni#WxJx}ReN7`VC4)ot7NyM@~SXLPzFg9 zUGu%BvkIGWs24s3vtPL)he|!s{J8T4W#_6D)S#@uW`?Xvb5Ce_#K%KFVsz z*~H%`J1(1+t%~|Ey?ve6mN@6G0beY6IoB>^O^98HLko3n^S81elajsP^xx#~ zQM~ajesyy-KczW^e^fT3VJGd$AZP5#h&r@Uuo{zH&G(TFCZW8MTkP8T&&isr-eA8S z@ya%PxP3uZLDo6hm8b(5VHx;_9TDqMFQGP|)}dTbYf%Hl=lP?u?=(#sQ&LRBJM7ol z@3J?ueb{fZJ=t3=(G9-r0QM%fN5opmI*ALg*^)9>pefNvG)x&u)~^9HUu$k_7#dkp zWCMcud*Y747Y|ZV$y%`M=p%~Vxj|bHSz_dJ#rOt4bl$j zz+1^%4X#tI!Mgmyed-f>swZzY_{^BZO)Z_ z9aXCNOI1qmo2RIe{KP1O=22BpugBE7S$lI9HH(@<;`5uyQDD;h-rN2ConKi`J$Qr~ z?CtIUhV#qzYpB^lIYHTc{!6k4vUj82@ZagYA@_wmyF7=Gm3y1l^7ES?$S!C`lH%v6 znzKm}8OJgXXJmRKieC&~?UtR--z?i8yC{3OeM8B*5|@&-B~HP1rBr#gN3)J^*KEh{ zN{aP%FLn{ElIQ4}jZtBTT(_-s%vmS;vtsVUU-cdjX*x;HDD75sCu8=4%=UCmX^nC84jq^Z;7$tE?knp}+_>En#hjP$uu zjZlVeCNvZHyc|Pz9N)}e(Y%6BZKm>ZGLzm9szxn4DI3#V(TpbTo%_?=vAM%@nR9#Q@-!Wqiln$Xp+>2pC56u= z&Y|YWbGsS>EZZ$_Tf8m5x9qTNvG`lwu)JppZP;Xa$KpG}lM;FIcX=+&S+m8=iR1*qP-)=D&W7)unrY5!@sBo`IjFA5F)n6O*Y2^6d)J zg%SPG;4=6*F?qPPQTc7)zT?T!D{55mws{q4A#@I%Q!a? z&Gf3vtv7%*Gi7T8>cP69O(qIgwYkEPJ34!ZVo~t>rT*4pfu1ncMV2{n>r6aDMmm!R zrdiRd>hfv;oi=sqG?YAzoT+<`hL*B%3Ny^~iMmEF9WwRMq+ujcpJOW5@HMt3F*DKp z7g%ItlYY#JJ2;^48l`o^II>ebN-s{HuCm_84r!{oT_qGn@W5>Fq2s2>3RnCtIjY`L z?nNo`G9dIt&ki{FOgdH!gTzGnv=BJJ0?a{bT(G=ho#7~Qv~;bq?0meaNxH{ufCH*N zS~Ozz``K13t4Tkho|2$J zK(p5t#h|fj?P0IeB=hBZwdU%HO7q|}XIPRNCEprC`0FV%+F+EY0!Bk#q7=McREwNe z!;ZS7-$7WJG)gc=%VFuUcm?Y)jA&34Bo^|-r+ULG>%9%I20pPa>oCx^ULcUN%sJN6-e?f`6UnO zz+r9nrc4+~{d&c-!Mt1PkFY`oJNCqhex^wqbl3}Fx+J2P)S)TbpLBi%THTSmR-`4t zF^!D3sw#@Oby4b4dP76{!~OMB39d{^DRygop3C%gsuVwzTR}laQ2`F3u3Xo#l^if^ zgyCE9w2$e;57glx{t!=p$6%Ah9(au&vsP zG%Tnq6(}vo+p}$*WN!B>KmfTX5{n{uBupGtQX69Ko2*80%2OU?`_=Lqg_WG>5L7*- z)5Xt#w4BoUIy!TWr+A>219f9E@A81!GI~j}Vka?E(I4V6gamqZzGew~VA2hUt(H9R znDD^&3hB;c4`h0jyb{1-%B3lJK1^NFb9oc=BxdML@S(Xui4(yAtG+cgC0G@zmLe_Z zT53`H33#n5D54FpTrTsJHEt`s!M1?mV+i${1kC4&iEClG(cOq>8WBNmTvlU{T5)jNI0p7bN-_1 z%SU#aA6W9XA!=)*z%DVb0@r6ak6RzPLnlu!lf%{VDISHHDpc15lc21Fj8Y0-PsW-L zscEnXNh25qizDfdQ)?2zafC7&HGNf*)>0yEoN2(t`h%HZX0W_cF^fglbzcI!q@+XY z{?Zz=o4?v3aCf!vdrBo_^JDk){;SsOEBvdWM~Pb%>@kM{;hfnKG&Kuo$X?vEJ5T@J z9rqAYOdK;$iqXdqB76JNgp?^4_RIs$WmuANGAWpSD4W(z?n<>*+S25|C;XI8+P?pPyyQ=^-$K|-( zAzhB-yJzj87O|w^oZEI%ag&whm?}V3eW2J}3JyG`&)QGGJb|1Zinqlgxatr=xvQg$ zDBFDrE{%cUEX@C^;cN3yr;{aqpIK{HfJ2ze-0CYc%`31OT4s78Nn^K{ur{Qrs(QN4 zsbkJe2{?+!#JA+Rrgm;Es%M&7qaCG&F*gf%}tY-M554w_%7)%r=_M+BDdqvIb#+92~l?qY&l8qK77lR zSCKcHhUGo>k>l!+lVpz)@T-%83ATclrbYrM3AJi(-5t)&XJ_N-X}z}u9Yvr<#`0(M z5W21qX7YSYtdq#;?_+4g>+j!}ngMbcMr%MH=9p%cCU=01XM*OQlQ50d&d<%N*Oy`Y zACp!L&1IQunUwL1Ct>vht)P4M!bKQmKs`05{pJ95o3j~-h(#Sc=DkRjI0k4R;ocf! zUi{tNv?65P12L?3$|ZMbrXzZwlt};{iql=DuU%&%QwBmEv9;H_ZDCT535|cLeZb9& z%Qdt~PY-jF$vXW)Tz4o6mt0NvN&9Ise^1E*O&}qHvw6tIA&#Q$I*+A~BK+hN8bY#u zTjV$c{dR{BiD-&T31$}OKY=3FQYyC@p6qf33c4+%%@$xcOKNN?Kw;sLm<-|15n6Nc zaV64|V~Xqo+JMOtYM>ntMv$!rcpPV>dqplA3UQSo~APcnIlZgmj3K!~8hXTa&x-w6pVdp{&e>q<7G4Q9w7;z{xiuqJ< z6BK5HYRM#1j#s~B!(1?leoZ~ZNTzHY6Wg&A_*NEP4FNdV#?i7iG2Z2{v9|mPnqXro ziWXQ33St3GfTC>T?kkWU9^(Hg^6k3b@^GQ$2oeZI^oUE#F)3uI&k!W~rZ~@_+|%WO zOYuRC!6@o6(i6x=UY*IL8X{~VNra+{(H;noO$+SKLX=M^GN_^^XM(&3vU@k7-glll zk=kL6bNLusWX~dZVrbsCH{I6r2wHy1Xs^bwJB5s*E`^qbVD$HVM}FooY^cj+$s!7{ z$YCrtc3^70u|%wKJ_u?CHG+G!GUQ^n%KPy?4=8fKR9<#Q5vtyYU0^VNjt1*ad35h^ zPn>r^Nbb{Z1&9Zc=%r9UycJ?(=5dL+enZTd|3QSusKw%X;Aiu4sbPq+GQ|RPQr3=n z&(<#sAVHW8MFREHP+U~mq;5M~Zj1jL9B>AMH#ExG2bVaVJiWSy4id4Ke+^4|;x2Rj zV(;PVCg0J4Nv)rkb*RS}aC(dKoU{UpxbshLD1iRmO2oA`-4mkzhCEqp_1}{x2{@CL zsVb~Pq*QF(e1s^(oA^2`-75NWgvZjY((#`8icnIh(-A&fieI}m z)+3At{C`{omWC#4-hzJf&aAV=nrd;jyQM&1Z1WfrSUq*+{@nck4AO>SnzLf#pI0q)s=oxgLI4qQ8+ zgKVIP@VaHc#3*&(+JF?(#kQqxE2W4}l~eR6hxrIww3&OqFnZJ9L`*^h!}b>5erTv; zMoDx`TUZ=U8}M0NLqk6jW+SW?!X*#>J#;cK`H*3Yj0?5yg)5y3JQpao64Par$>@T3 z!m<7AF!3(=%b_;Qp*Ft^wONjgT#k%fj*R@@h@Xk2fgsK(iH>vqx1!)|D-Y}}v)?d~ z0Q6GPmeMp`0TQq5_8W*W{5hD|`1pV3PNc-4Vkt2|Z2Q_AixT#i;Vah_sq5oM7SY!t zwV`x&ZHexzAcpBjnyih$jRrj)T(=+XRX7aGlz**UC+5_;pdH$LmcAHAeW%&hWaN|G zr!f^DoRo^#?kf4x=O4 z=#qb5fPWb0y~x-8od8c4S*CJN+X&wd;q5|U+xW&Y;YU75X&{R3sc4L!^tHh)x|Ho! z9EuuD-a0xjG9fVdDU$hb)i)YkectnaTsSB1SA~_)1W_5H>IK4^f;ux5V~x_B%K$2gnSxkg?MjV*GBt!wFU9sNSpH3 z5~4o&86L2p#^>2uAD;`GT1*p$W=QNT3Mis)*$Vs82aGy_;0HImT zTVE`Z*+dBX8M^=Zi5&&@kVX`bj3(7N|0o?gv81@~E+}s6j*!`Iwdyv6o3K`QO+7@L zSp=mGTmL7jTi?xdG)+jWPV9gL`2E6pe7e0b6q#p)<|sDW$)A&M{m05%83wHRloH#I zcX>jLao$owMt+*x|0le+{O@`1$$9MQqO|o_dGFE`R=)UH?$fmQymy+6Pf)I>C?!tE znP?^PsmVS>vlAUsEPgT2Xtr@I0JKfQ76>KfF_0*N(l9_%F*zf@DaQt{Cu39YlQ(hp zKr=Rp>3ZM>A2(P?Krj%d_|bFf$6(9|)I0j16TixNYp45`Iqzl6eVOuKPI9oEf%sTx{heX7eRam8{vr7_Md3e$sv|vAH-zvpkm>Q z&s3Mo`M(J&-hRt<)a))mCQO=%=nhjkuT{V|cPKRC6nhwiijYzOrCUPRJBIly=PS5V zS|E-r;XyI*He|t#8G;9(yid3Sd8=X{7mG?a5=VtmP~80dr=Tc{gaKT!4EB@F1w{Sfm|#1j+I z5?(TAfkGJ!)sgVl(Ub$IbUe*QhRlO5R0wo|qWk5z%`HXhcBxSD13Czb^5>ydJT@2N z0`Qxem!XXz+}07!ObV*aIwsDC&e44;;jwd#DLcdvngTT_CQ_>#5$=PIi>~;GG?D@5 zQDK;sJnDZfMLRXeh04Hxk}EuyKdRpWlIj*A06HaC2LUs3y4lxJD=&pW3nnF`2MK=y zptLq#fN&IuBLM;qXtzr#Mp8Z{M8wO3u9CVN4N(9o0wix1KOs;c{siJRc+SFnk`K#${j+#QGX8N zhuZN*T&)zmltv~&)Q+$kbJ0ty+c}S2Jk;&6xr)>Fo={;zqwZJeq4FTe%EorxdB~R`Uxgr(A=y{w zHZfU5o+QvvWa9dx@th8d5IR5+d+?G!f!u<7%bt>PpOs@chsMr(B81KqR$oP<7w8ZO zVd^8IQD~c;Z73DQ>*x;HJUXnv;%`g|wcQ>IR)?Hes5q-e^0?8lc#6%REj{ zNpOY;zKLv`_I0pPBoJFA;G5)cg3fX?d>&1@JSGjdu}XQ-H>5!}wh5wSduK>bsa60uli%F}#(wBS-N)AD)?r<_1i;xiOr1fw4jEB|n zK!}1GH5mUZvwVMRAgtRpj2+;$PkQf)n-Q=BrWFN$RTC;$2m1>sE&6R5XjC zijt&N3BB6};qkDvm3I6~NucJds_2Z*dpG;PJ>{IXYB)+<72X@*?>+U>?CK7M#BXY? z>&vbiXV*B`nX|3@-^l)|(IkCx$oVkC#aH%7^XH^Ldhc*v@4CTtoxMX)RuHu%>sw84 z(wE-v_;2$^_`kkoRmkd)7lNqK&H4PBGFbEe=Hs&KvZJyI*<8|Y?{|w|3AQiI!en7` z_Vd5YIGB-^5uB0ey|vi2#8I$P=m1{{U%i)qQC6Om=IvJOB(RfHW1{d4uSL9iFv~If zWb+-_9oY|>AFAGC?v)bySH#8gg4$Vp#RTd;`k-5oAWf+-U<|r$W z@tR#^p)#>dpeafUolBUDoFmNz&PC6uHKiK4ra&W}fC2a@+ID2QlgQ5Sf`YoHIVb9J z)$!in$8RhgZ%9A%;!X#{s_?9>{DV=7s*fA6y*uL_RgQHIep%U3`x+=2@AtmzzkTWz z*H_Z)Sy|InalKooTwR@KR}ykK&HvGukyDdQoAg0OWCp3>y@+*^7bUMsUXeIUUY2Z> ztOwqWaFMu5UXpC6T%~?NZ4bTz_*nL4gk)gn#2Ri=LI!2-{khNQKAB5u*cJgpyo}hG zxUwt|KEr=)@KwZmge&4D#0JDVgbQLV!U@<)+(twYUmtu02q3;Y=*)eE>zud>y$Zbs zZHIP{t^{A=Zs4vW{^3^@?<HR0f$pS`D z6toN{{rAC1zi5$A+Of`2I=oBClrtNns(q zqll}?o^?}YhMG;;1vMlq&!CrTXUp;F2Y0GF81j#VX<(kam#Is3dYh7aWVVO~{vd04 z`Z^L=hn%madxccy6=>(`E zgyXgW-6bc~7K^STdLU03q5Dbc>mF}TQTVJ)zi5-sd%@`Whk@x8sfgK}pwwJWTHfq= zx@7`Fb3|M;mZyD|GKCyJK($|wOxtIs%P9<=Id_;wl>mReU4iwf;&jyJshuf0lJ5%2 zm`ep~TQB*pl`_LeCc^Z$Zr*Omi}XKa5x9h-jr|{6<%zY|QuG1`SAQiE(bKQCd@H!a z8>mZZ&QTvpkOIb2a!^Y}bSKgp6^~w$W=C2XO5?yY;n#&8mrwemz6rD@`GYbm!8M^BI6KBsjGtz(5c z`@G7^q|a`mHwMXD(D)saBD$eyGAxT~rSJKO!(0L$vPD69D*%mSD3a{fp%j!Dwv_!cE$x=QV z*43%v(>w2;5Mr&5L?}O%G)#WlVWLq%?0ZCUZW0{?p*9~bK;SbfUcgK z4Hvr*=lt9U+C_NFBLy=f$xK5^YjwlJ_EY|VkOoxp!Z3GrNb{d536$Bj;cL#Tco|7u z90x3N!p*@s-m1se@yOa|S0Ee4Ju86c3F#=qHyl5;`Chfjk-3pRZAdrIy+xz=$Kk+h z==A3)VLL?#LtMB+2|)uNs4%O18^j&FBP8>=fmBx#{j%>09%-Lq)||L~Sgp1UQ?PV% z|A3i3G<)-ehpORHty*Q+=s&i1AT5!T!Ma_V+Yn;DOHcDH3WyU6zC(igzl0-a^EWWT z-qMI6iE>qu{dpyL@AgTwC8R<;tDZu$WwW#v=bOWi>fPeIN-TqGeUTZ4G1|IxHrt?Q zmirAGy~2UQ*=)xSCt^nH@QQku`Wv+!=LXi0&4bU;rfKU}*YT!AA_r8bikLDS>5E~1 zaT2IAPk9fkd8(^;S{yhx0MMlU*rd^5LaiKoTzxwjZP{g>f+N4w6s`!Vk!V^XeGTT> z+q9c5235e9@(yTph@muct+(8^Is2=X4jcSGFTmF1W{kp3a&$FW{w9pn`{+9t^-xKD z->@$xvMEMhCnjDp0D-`cVLjMgpQd2TbzXXh$TaYFt{Pdd?+NGjNDb#%{T1Q%F8(Sn zMLiSv$?7X@(rN@=C7h30<~}?`RJ^H3>}`A`P(H8DGbkU`cc4jK6EFoha|2fEz%yRE zy_RWBQ(VFXT+&S%rnf`G{e*msCy7ij3Gs@#oSq$`Al_ znC&Zhk6kJv(}%bzW|usbv*nI^xEtkY88m*ToCH{$;qT0AxEj_6C)Nfxmg)bheVfr^8sR(eethPl7ewfQ9*Ic5=qsL{`Vd zfQ+#|r_}Yz2J54{96`W%&%l<+LCqlFFPcL+7O7`lG^Y2~iEClCimNBA6q4ri7Pzuk zqE7L1;A*di4}@4vmu}PFt-8&EONCu6ZBV2^# zug`|pqngaNGynZ8%LfnhQy14>_+^%5Uy;syvdGg`D+theP%);+ToZ)!h!Y46--|IF znW&*cCA6?Z`&RS;h@AYW${z>S!wZRElWE7Sx}|8Y z`M`bGk1OP*+=Z4;m|5Dbz(2`WZClgD=h&5{Xm|;$$ChmQBP%MSy@l>?3qe~*xV)d|n0vTMF`qiO zxY)fVhvx{P2TupT1&RJ{nUk@9XPJ|n^3IE~os@OPoLCL;b%)p%n?>r?9Y*sHb6<_?9?d#1)c=0ar2%QO3td@A0a3=Nxgffs$uJqyuq2TSc^pl~<#Bs%k`w4pqzX?{hVhcJDY zP_6m9HF2OpPH|y*Ws~LXzRSqbGIF$x9Q{X(gn+Aw*=}$R*Zr$Y8@l2(zVlg|>cepN zdJ}am2+Lz+Z4uAX8b_D(jWq?OK5k!!aptu@6l0tz9~1hRj;XeF3EmV_SfuuPhs0b5 z8~tq^Q;b>06bmcnIS@psj+BaCA}qiM{l?zUzS&1 zSYGC2fq(vm`d(hif&Bj!n8U;Xv_ zD|HvX`l|E(%9X$CVqN%M&2ye_d!O{KeYWkvGxk57c;TmOZ%^L+4d%r=zqzyOlWxK$ z?~uRzCF7Nso?G|Qx)shZPB*$w)HV`u1~5fZEI*AF%6&PCSY_?Op#!YRHa(xGfm`gr z=H)-jU9jv0%QwRRva!Gv@(ArX7~gm~Zi_lM;pSrvBSHMCFK&o22DtluY)-FV8*Z8#IwbP9 z(u>yQQf6lKL2bTxpbnKgFdZ|viY4vbtQav4Wcnn6=Z}xF2lU#e1dg*DV2>5oyM8JX zD$bM7;@T|Gu>!<7(piLa>irM%P5W}Cv)}DM=prZ8TEj3_+6r?AY0j`WujFA~@?b7F z+N#JqOR`;e1^OJG^0<%s7CEf`Vf%4buqq^sS6(D!mFJuHaL&qiT^G(2Z?RQ-f()OpqcOeMfF2X_ z%%;Zg)%5p5u8Y;zp>oGI>rDX%K$JG+T`3YeQD_VJzi);^i<^{`OIz^ZDCY+$)kuoq zq1*h8qzh2UzxhpQu@eC5w-VAWJg@kSif(a5Y~mCQ@Gu|g$eR^o+!4cZGViQjVb~ky zCJh?(Hf*M=I|xI)n_Gk2Y`wvpbJqeYu_a_-s8fxKiO#(FCeI<+hB7RpEZ+={LARLN*Yh1LELE<)rk z9{>zQ$>>}OMw;fpF$rC)j&XMGYjSE%{#=9y(iwXFJG^{L@iqD;$f(ABlX_hh zES(+@R+hNxcFa ze`L`Tt_N(>jT!!gfR@22RQZI7I^nXJGv2@ zhgeG%Ufi?<0h_=x;62I+QVTQYE$=S(#w4JA`xkXbE-Jv6fKQKq z@ExBz)^S+RnFX0I6}lYus5y8<+3CyhVc-H4*<(2(hlCdk9s9TMK-TO#GTiwV@`XBr~jm4r^Rq^xdBN$&r2Ik2+sx++m#g#O-00{Ti3`zW%9y7{Axa z-yM{lYwxl)Gh_TE#)HmHfh{??b^-6qjCV%w-n+Yn>gtpkF`m`=Eu)26@?z!};~z5I zs3k9F9veT($lo2}Ky%LAGk%eQ$u4op3_h}DX0TJYTj;=fE%i_1*$l_*;5F9J@jHxy ztn)d+b~IPtO*37c?YrR)oHtTqkA#h%V}SOwb(sgodl=#@fjw=luUC!#On>Khor>K| z2hQuMsPWSbuWW$>$7kkzXZQeNd+M(71V%ko@=_*dyt$LLTW!yAn~`*ycdH#puchuA zf84nuupkTZiYaNlt21afYz=A4%yefABP0v4p~-*ddS_`MeHCfvOk?LdhI?RWHUeha zJ${mLjxqfS{Gmgni|Nzx!p>R-oXUOKbZk71!K89uGR2OEGq6UJx&gR0J zUQdl4U%`OoB)(vMdj`pXQOPei`OI{8Ze`H36W3V%XHIq|Fa`p7tF7(12B@pOii3cT3lH`CHd-%YW%zL^>_&Sdbip$E%Bl85( zb`RvN&sv{TXO{rW{NxC=GnYZz?O?BT?tgQKbH8Uz_z^>=Dl_*UX@C8pAG7{Q}2trlgY)sJ7!EXTI+IfN?SqwmR}<(rI4l7=sW9t{!qUg^mNAJU5i3dUsq_ELF6+V%=fCn&2bBM}|5x8I25mpi|aZ z_RiHE4hgUHd)It;q`Nbg(GeJ)Rku1}UB7S5$43S_zv*o5Ok*em9bQnr*#BC_<1>21!O%+y(}s`t`5 zDl_x9?>%t`n^&yLTOJ+$b4Q0oUzQ{9LW9JRsqAChiXvg@8S%HjyP4hIRJv8UZ6dd0 zgKLrBPHgIUv8vZHIwX4##=mSAgt-HDp!!_~zPKY(Vs(B}AWhy+*rtrA)fDFn!E`$4 zRe;g=gh^0En^3+oB|r)78;@Jz6xmPenJJ=?RF8Ikx8ispXfUBAd*X5{8z*=ATFU3V zIuWH2eLl%2!+i%($;9f`LHE1TWIm0yM%Y){s5<8~i!_~bJ5|mRUpBjOiQTc|s!c&9 zgz*!Fj1NJ-GCYVX+DQ+~2$HYDsQV(wKdlV~2y+D5+QcHU^X$nJCisBr>X`&Y9R)1l zW&Tl@%kdg7MmCL|RfE%TyFz?;L{H-vJZm8)!nfY=^>a>MJkzwEwicD~J~+IBa&-D4 z{VYX*I};Gx*lzGAY3{_Jr8v-x9XQr9b~8>j6t?|V&sjZ|+&ef%nPu~;lMDex z`xE}I1>3~aWQPsDhB`u`0HR@CU4}0ex*3_Br`)5w~j9QgU$Ej51G96o) z)*|%=a!E2JyE+a^+ES&;CAAiTBBUB1=DPNbN~WX8w5TW$(i9byOH_~ukJsZVqQeBS+Wd;JS;jpx?r z%6WqEBi;j{mC7L38ZRv`Q@>)p?uj%uW*vgFU)Isyo9ae(f2n#>a`n+MSH~B>uf^d znik?=?+#R24XHgdavT!T=p?<{zFeboB9)Wyu|oBw0j8QCyO7zOK(eK!&?mm(Npdjm z{_fc7Ey4Ye_AA+m$^(i)m?N54Ad!y*(jWC~A!0KS_bpe1Q}La4b(?7@Rgg)h-(J|C zo$mg#AtqEH>=o~u7?=JDoAN;HE>~5AJBC`xUbN)`9nupjHsU=`sSj0m~ZK^z=qeA?}mqe2EM^Ct&ZW~MgTPQse0r|uY9W%E9g zvfsUNi`U+8%E8Qq_gUWM7p`GLzlSV}x0X4d-lkb~*!%+)FHy9|6I+qD$>t#4EtCW( z7+lM-F!VX+O`7O!Zw=sbgUF6b?J4m94>CnsV#tDhZf{B?_L~;x2+DM??MYOUT|6Oo z$6&F!rTdipkbaL|TR`jZG72p+v4j|TBdf-1QPUK6zblKohFD{xh|*REad-=RJ-N#j z7V0ucmxU3OXRljt>>K?eff`GvCXyx<{A{M}io_k^@zAzZ9W?(8u5DKFTioHqkZ9}B zB^jr?#JN+~z~Fy|xz;P%?r?B|G~T|PC|-~|Nf+%fX^m4WvxQ(NdQmXV3-^0q_HNT-l23MHI;tYz|sE$~pp42^A9{L0QXaU! z2ZG(&63=kmN;fe`EsEBDLlLH_V>XPqb40loEgU*djzqh^)!r}JEhanO&xsNfbQ-p$ zK~)1u#Ukv%j+CWV@-F%|(mkwr-Q;+8v9nBGa(x;=Hi2kfk$5Xiu->%42=Zw8;>Sbe?~kEKiGSa(ORX;xz9{Lr0(7^J0llhInhr^VisH z)bVD*E_xSLxqT24Ej8+@1H+fMI2X6!mz5eas_1L_uD>!~6H;mzA(VD6bIsG{@r?;-b_OZR-4obv&VtQf#ryx1zWCkFNK6%311H?#40UQ&CQ0t%TQ0$T z{sC_+Byp;9E8{M$+kez&!!(>ebI!~Owjh1iL&jUcUQnrJmzZ@!>oe%0=gv^wzQHlp zMO6*8e^JGI%ke3AV@8#Lb(xhb>oC9y8@NK3TSI?AYXIXhpj?C`o^8?%)y_;dvHxhu0T(+SciwVB)UKj*&7IjT-UW1ih> zLaV@q*1BIcW=@_-#VhhM#p2$X{PCn}Y& zEwVX#OYN1s_pkG(hY{T6Z?}t+R@GO~&a6lL_uFL1<&ZTEla`yl8;-=g?B4U*)uOF? z55hkyTG!|9OI_)>JgDrt_1-FQ(f%qvUNdyV`u6m&c(USP=DcU$4zphrnBcv0e9s); zGspM*KgRdW-IPDkn{o!sYH~sU;9m>l{53B43$mZgQtd;f9k4Ul_HKgrK~hC{LtNb# zC7EnjQ`$}j#+Dsws&v*kw}aoovebeET~SIvt^uA@U33h}j8k^YHs9|iK#Zl0+@AvL6){`q~3b%dF|Xq@y~sw^8X+o$W$8Uw{w#M+58U!B_A@8#1$jr&=S z5$bl!6yLx@-rN#jHK6U(eUNr+wo0fnPLpv+p-u5_yT`|E2WR zikz7;^xjz=uc4sP#gEcUfK?bBW`i@e z%U_ddI~f-n-Y5Fk0{0x5&;7RBh*zv{JDOWh%q1Z!LT3X?K^bBb})2ux>N+G#<=KO@2|?cb#!Q_f_lZbJSY?PiEn0`&Za+K z?b%{K`v5i1UnaH|3LlgXaV9~?2Do^1DJa`p1%XshW^fKA+B3l}dT4t-c&?5uz&c)? zxh`Ox0QPe+suxrPB&}ChVA<`fpmk^5*u#zrYi!p{DcZWU1$Ef_8fT~Dvkw^GX%fe9 z1;u}zR{^RQC1o(06goLwE%G_TQ3EpJW!k1kVPUG$nQDiK<@@R)7h}4RUSN*-p=cRt zz4v_wl_GO4-$7e3y<8a_L2o1R>|+HaR2Wsg`bU*4&E_>wot3m5D%4@L?{vp=%b}V@ z=iQf+s|wC1&(`~H&qXsL;0vag82?@|TSltXrW?=eQh2^vmgX#PhPvA}`exqgs&=8Z z&TuL-&CM62So}pD=EG-Wgn9C8jLfp^xReMXL>t|eEM-jBnLk?3kdWHU+cVR%+g;aC zV7Te(EDHdY#+v6q?5@aWbbDf1`u85AX+}KVoBj z*VVVey1oYXt$D)inzLHXZ*1sZFzaUS!7yr1=Z{w(8^5Scyxq}Suk2~u z_!G5udKHFFJ2u362kL381q&_Lt@YwgF#l#J)a+n!4&_C)B(I-p{riiX4C4B@j+o~s z#CaVBwH_pC^t86HqAg%42-X#5?%1B1E?G@|nXDzb50<~y|F&FNCXvf&|)qSDfj z>vB<40OSTVHh8V4=2}S)>+p%_><>U=eP#NAuFJy$UNhEEC$@YnZb2c;(fdB_Ksh0f zqrhKY^KUmxai=|_xZyVpWd#;WS;0YU*8!S(#M|ZA$dKuHwF-V*KYH$^NTUk!``0x$}EoIG3xIPD<5Ca0v^ zw%U((2l3sXm2^AH+?OlK(Xtt?d=wXr_=Et*`3K73NqDG?9s;OuO@uo_m+Z3%Dz^ZH z+Z}E==}Zs;_8Zn(d=z;)$iZOTbKKu1E_=qIR|7S`kD=?_zXMslwl3LWS2SgFkV0HG z8TUKe*{V|CVPjP0^S)!;gq-jR0ZuXj8#oCk%{m1>pF&&%d<$MI}+ zM22dp%!84d!GXfSGyuSRkPb6q5mI<7Vn;S{6m}wTJ7hdy33T6}r?c_m%l>DO?93hF zlLDV2^qU zy({R2eVTPdaPcDYw=xQ@jPvg$gONBm zWU56=6^%mPpyf?Pv|qA_=s1rW@%X^}p|TH%?)RnQIs(A#0sH?xPzl<;|9YfU+7`sS zuK8K*iEio$U5Ql5+JZEe@fE1-l5XmG6msa+`NND6OSn%H$m`%u?pgwT0zaLfhctJL zc)DB}Dxa7V<871m! z&LJ=(fTW@Jq<8AdGzq=vu#z+Z-z7g z-N66g^c+e4QzXgfj(JXm`LPEigZ6nzH`TaWMDY*S9nYbca>1@zk;02?jBmu{6!pJY zAG|wkLs(N@mQX$I8A3fEaR+{Sp`i^3RdxqN_2iq{pLW~Hy#{Sq}zZ{tQC zkuDUjwO)`rxlt_juc$dRRMDFG!f>ly9vg37l=ccWS+C+cj;L2q33`}|I?~V z`mf}y_|Sl~<*q;ItGOjdF(IvQS-+5n;mTN;mtFb#^;}sICb0Fl`5F4lavhE^pRu-S z-G-2{qb_bOj__jIlGaG;h&(x-u$U2UwaKI6HH#RltwuRHUb8qYuoYt!${o0q={doXyn~nO=m~ zSl#l_crwfTiYrMk=d^re7p08Tr%G8VG;&=zki=G(fyu5qmhS;WkS@rTk zZb3XL-cZ!PP_no_bnrJ1Ffspki1DB5T)DWS^*7dLxe{ldFI=C$N1w(ev(!tdL_LG+ zU}1jQ`n&uC`or9W_|O*_oAR^uE4d{_n5C{m`fKvAc+R4<5Z8zLUass2W`!$N|5y1b zZf!gV#dy_vTizu3EMRQ3UY9Fyn)!_N)_OS~r&+*=%1_g~ zrcyN@%$oYi0fT_v)srf6fs|R73fdM ziSgtj=8D$U);75r2hCRnyY}kwTrA7`tJW~9LavF2Q1HrpqMpNb$0IMQLS2qq$Do~A z#qx1^Irn;e4y%8GWI_Fk-4Pq2nwVLK*c^+4VqaCp?GQK<*+l1G@%4Oi( zN9`!|hWy{_zm>nkeV2O}$2zitb6|TU~5i^T)RzewU%@JzZD^roo~T zNpM9}OnxGHzjxB~R9|=hXmxUlG+iB>()9T04*c(bm=3;K@#v+wZ-<}mx5La__H;~f z7B;6q*BAI=BMbp0;UBGS{(P6C>AdQ7)z#?sy`&CpaR9qFxI!$5jf=@Sl>bOyHzgiO zXxp1ow@Co!GIdIgDU~KW&#qls)jQ!tOcq2tStp9=R_2gyC4>o7i22PG4CBYu32nS` zW@Do)cxdy)Z}Rj=I?}DolHA{ztlG@tyo~&UY87iE@gH4FgPvrsh-9TtotY@^FISn> zPcKL*QRvG9!vhvtg%DDB+qFSGh)dEg%!CVWEGxH)8_O-?rgTgHN=WX8O9Ju|(Z-dl z5!W$^i~WdGAQcgL`%16#Ld4dodUuIA?i$v5k0w5(cN!9%-ma2d1^vl*guM_M#7LHe zh|9UqZ)N>By z=)}&M-i>OqrOlCznYiebxCtWnt2Bn?mN!f-&O^ptF01!QuH%wK z$gN_vg3{})u%g^OdnDD6*EPYx#o+`fm)c9wtx~d*BR3+(oFT#y%cA zb$=Jw^p0#uNg41)*+S^vmn5r+oA%O6f{WatY}hFEiifT)bD35yMZ>OQcLg+dAKn5P zznaitvd6q7G}U{JlYfn+5~!ZX+Kig<`?Q<;lr}NaZqCQtu!KndjJJ0XEc|-cJi#YT zP)}>waq(~zRM8`rX17IZkRb(ii|5mh3vDAm6XF695!ty`JeoRrrM)7y-`hXcmfl2x z$|^zzuzHT;c-$ujNll&D8H|j1M#^*D^Tfn=NTYqM<5;+cGp4e&ZY)1vgBC=aw_0dn zB_(#v)R*;cmWW>SWr^Uxm>1h(Y}l^*ZWnyaH1u`Lqw-^S#BS4wusGsuZ@#8*zXgk> z_e=#xnH1Qu0h&Dm>FJG8m@ow}!;I|IsZ$~i;^8s4;$7n4*N#csV-FABtou?n)EKRO zVxz?to9{zoTTw=h>(MZcMra#3X(lxGFJ>=p5PieHeKc%CCd0OIm{4_{0vUP4zWoJv zhqRX=4F+RSx$O@1qi?<3h>zZ$bcwx`KO&|!GD|y*P2nRw_ZOG1H1Fb(`(&pbd18u0 z`*KcPYfGLo`8YDR^_Z({=z0$5h4EPo(bdBS7t)p!m6QJYy)aq=Ds?DTq)JN8qx6>0 z5}1v9E$x}(4^yw3Wp_08#|)P++HJ*pow`EIl{*h$F=9Q%R`GU;k-ON9lC}7$2U81C z;LMMQ=BbbGZ^>z#Aj?9r`-X>ISVr$sxDj(hHMNAG7|+f)-OwSrImlr#$E z_O5`{J7JhiXmfVvH0P;rOxaIas@ihIP-ZVEfyLrO+T0|ESh+Vct7_`jF&Z!K=^jY| zHjPiVw<_&2Gu*PCdM`i{z&Ee;)|Tvl0GXmfd$Bf?qD0{8Mydeg&9v3+8_pn(GV(gB zD>w~T6U@AbGa8MDof6e%+8}OWBEua>rfChVOdlHMSnq6)A@K%2>GEP3WVzT#T;+Zc6Xx|eHk*OGz3!~ErfuyQN}dH zvP+Op`QDc@^E7p~zi3KA54EF(X3f)5>D|R!HbgM*cHDs`rPc00LfdwvUzdA7mOS7} z>L=KB^%WyI4N7ENYNG?DYt(6h-d6JDLh*xWmutA3fE}-zh$1v6AUkB}I&2eCnk;hv z`I3S-b&f|~W6Zs;VT$1jR)}|W0*>$|nMceiZ7r}1t|$hv;8R92R)7^4nqeG|rklzM zQ10t;kYjQxv^h$+(P0$lDm)W6Wx>$kDYCs#?do6;N5r-c(C>SNiriKb@~$~n0O=)$ z;yA}d5{D3xs$0qKknOHTy1nfc#}L{v^?KQvUP%6(q{0Y-AShFPG^=Rp&(&@n@oE~B z<*6PnNhOY1bQja2vQDVSy43GLF--%O6E(+?@6l0?5SN5qQKN23BPk{v^`1xJ^Q=hz z0~4GSMd+c(>MOSGLcZyMuURgxu#X~-fc=1Hw`LO`EfFaUxd$1#I&^4nC3EyJtxUjd zg4}@Oumatr-X7aL@JLQ(dZwfc74+MeWsV@ygUd2Pn{D2)C}qQ_kLhpQt}ChIBh`bh zL(aSh9o@6Ul~cNqb66u|^Izk27TZxV;a)ZmH-C|m{cTKv@IcU&12C1>8 zsq<~kd+&6gc?P*WHg{d)%s!I=UIU@{V0h^J*3`HvDLGC8;OA!mkt-kVW2yvC1!cu_ z!#uZ78Aw(Ys_(WC(l(5yju~tZbH&%DPTHC)nh<*UplnM72#K()DCGEP`8a3l<(#!M zoYIhGOtjpOrDqVqe*rl%ml14n=G|P)Gp0K`Q_}>G{owYelmpKIIP46_y9cFvgQDbckCIIN?3r6Bt-y+%7=kXK= zcT6_xCBOhcpMlXMZPAeybc}& z_Qa3To4?}Q+-IsU0H$mKv@zTJ!Q)RNgXMSt!L1*~j9{@Vb6vjk1pW$N@h9u(9Ll(MQs;pX`ilN|-v zMKg-YGSK=z+sEb1*cVnvAVMtaRkMivGtBG<;tcvesAi~kb+08igQB}&Zn&q|P<4LB zB8X#k^pXz6p$$7vXQstrJEl1;07U@?ZAQT8WE`4>dZM*%(sEhzIw6#NAu}!7F^x@W zFS$4~&^#dIU|!YBGx%7zxv{mse%=bam-O~3MLW4Tj2K;|Rtx$SoBM{zxdHjb zuCw&}T~<>F^t}qcF)`}n^urvDW0sluLBiJ3_RX+*+lS#EA)opqHN!@`=fL>|o2mv? zmk-wKktO;}1iyEZFwFzP_|o&+S~6pRIx{nF%2KP@FyV?x{Cmnt=An?nga67n86DUN zW*CmnQv}s)+a2Y2JJai<%PddOAEHxJM|6k%21oi1NBbW>SzHQmw-**=Tz1s9Y&Pw< z&Oh}9cN?%~UWS(8TMmp)dq-7M_Yb+?fc|Mu=-$~`R+ABP=g1t3GRLC)mspfJq>nH+ z1am{+ry6aXZ^dgKW@g!tfAH_eho);>wZH9z&!KI7JnoJ$mWe%=Y9dooDIYBfmBrYy z+v;5Ov0ltjv;}Xe0L*$s%h+k@Vt&wIzk8(Zee-pWa>A(d`)!uHI+LoQpnvnc%IF!3 zHY(;VQ8mH?}yu#uuVvQR64S?b2*26?6sB!!H2a(eb4)0xe3#<(B!Q}T! zW#?O$PDA{3DVk@hohkZu5_ojz^u!LCW#*?D2&EOdvj9G*1+X(^=-Kr$LYiTpjY!VA z2zF&Oe{2)3YhgL7-Waxo9E5f{EcbR;eC{e7yp1kpWKFi5e5SIdRjd8>#aXlLIi1v? z`FM1XDYqba8umY1GX{BQK>EpOu(RqOSvn{s0M6ZQ(Rt4?6G17z#63QeE_oVc82S6I zY%O%ScWAxv`D$;B{NP+jfM0seH&QQ!wO~htavyi?ml+Bu5mr#M14njT26l1Y@!h%S z>sY1S=w)|yO zB&c5amL`&RmLVr%JBZgfv>lGmzr!cChRyi3YG$6U(~)U&FZas>OVkBF6fNBaXK$9+ zCZCf`y03TB&H(Pe&!7!Sv9coZ*kW*?QmZeBWov`aMoi z^87o+Y=mx;L^zlr(^PHlS&wOXY5Mp zS`DZ7P;)vJvG@Grf~>7HA{eX!`efaI;*Zym$yjGoHn?iE5(SzAP=Dc*Fxsd}1!`Df z0+v#ve!*;-HwFkO*=5=6l!xm~XJ^+r#SX)HD8kw!-sA3cpQr7!`6U=?Iviqw-}W#5 zdaoPKS3U8bWEM^&THjJ|9bS#=Yfn7I0X?eyek^xu-JTv$-O?&tE(jm*qyD@nhc;Bge{b-av;iU6;k|*wY*Fh;HBE;ZQUKQ2&Y3 zggAW}`8!io{%nqcI3qRMTqqRe0nwlZ%Zt|4vrgJ9t>;5C*z0`q!FhZ}=y}wMAt`2rPyD=TfArox9BO5I5S z;?`Q$7@mMnu?xE4Q|9({kgpqk%?y(5YfL{0UqfCOD;YhMlfy=7+PYCem$L~0C#ns1 zb8Ld-M9=}>Q?DXyc$n-*?ER`lH)#8Q5yorG$NM7LnBf#>o{cnld{YBRKcKufwjN;K;IJ93ngqW- z;SJWoWloD(#9Z?PSOiI90Qa70T#pA42%gk?4m@}xXf_|`izRTc5{DTa2hBy(l^SrMF8G7~ zPRIMc2K+jjB|WOJ>i~vl)D-Rt;}8IDUv|JLh~Y4I_@5zFYl}g*`(^<8Hzc`>eRBcB z^NT048*WeHvNCh+w1zS@>PfExRwHmw=!6l) znCI@&rD!9-Ih`PkG_I2Jz}50S1EO}f^OP@v0t6@&Oqy$bGnCN@M2rSsMu%u7c7wzL z0=Sd?w!{{$kumP+-X~}d0?(z(c?#AROi~_zzypn6SutpmTbN+jF_%G7fv20EQRUAv z@POB6q!#dqA?z~0h7B9j{Q+E5=1D3XR|jMP@JV@&rcGql+;RCENB3tTYbV0{ z2kVaKI6E!{{56XQBmZYRjK4=yyN`tYYA~vakcG%`xOqnc76t!$_pf6%<^Lh;K-Oyy z7Hs6=aIB*NsNk1k)>*H}-{)52I*x>dQ2#k=gY~ZbDt9=(iR@IW_3(-+lf!uDyK$lO(|+B$$$C@%Iu~^m6V$rF`lWmeE_5LyE`OJP2^U|q2E};OIx5eM=PYKd z$$wYhCEvl7u|&UgeUg=t758AF4VTOk{knBkey+Zo%R3@k?%J(4aEV7m%UnD4O0M~c zXhrLqe1@LJHM8>Ot6p)vqi1m|;;}Enczp+#b0jYS-jJW5&*Y}yuqgPq`6+t4JOU?v z5st~3$pFT2w84leeHDF9w;t&wwZ&;r#GSB{>@6|=lAwMJQGat0SVYC^%Z z#~s+!O`Ed{vn2A}-0L{Q5$%!w#r1*ZD|Ta>-nq!g-~UQ^=-}^~4rGV4I{4g$-Rk&a=t8lS#=^(mkOO{NY&K2VR zyUvyOqW*u%xCs1TWn56Wi*Xj*ildPMxxX6R(6lSdB45f4$E6%?4#-_H7}@ljrnj5Y zvfT2+@v%n=7a4y|{X9!2U&DPQZ{Vup+gSGb#$|)6o5)!lIWFE^WM60urDo~>MgAx5 z;rL_m^H>23!UKwr#Blc&zd9FN@RLt&aF*Dyf#br#7Zw+k6_ypSI+oVGl5(StHrQkO z%Ji`E(bbl{mdz|FXirX12ZyjXwtj9MyFUQd)yH-QiN+S(eJ-|OiRHDp$Ix7#`k(j^ zaFACJ;8WJ3<^NchdeEev7s3__Vw$N{CGkkEj*a+{-<>%X@N4liV zWkDLq{zgTimeYtQc^8ulMeeS}6;ZJJ!o-qvQFfy}y~Ko4F`i7lP8hm0Ef>hv2&ye! zVPtLyWUqn~u!Xl*rY)z~_WtfkQy((!ARf-|6dWAKFVimlh;fQbV zV7DRHzHT`vhFv9e&)JoHWXdIdcVT2+BK^_NB*U@utcmXX3nbU@^$M}AlRZI%j1g=X zXeaimW}jg-aXii#Tw%j*EOgxUMHUP>m5Kt)!TZ(a>t&oM=C?X;AI19!Zwwvud_F>E z9_%=2j_M!|rHV&4OHP3tG%@97&Qjum2W)m3fxf*4>Dvo%Uxs1|%ej_p@fcB~1o&(8 zoxq{y_LDGK(~@kFm%u{Hpjc*%7}#GT4F!n>(PJRGU~x0`L(;cPGD`@SMs#^@x)_4? zm)7J5FuGJp{m{MbZ5zdLkJ{M6&^88g6Q5pfc?1#*ocEuiVo}ldi`t9`Z~Ih@Np*wL zyFNsPxEeUYEl;mq(ui~R@3izh+Fl;b(OeQcOQQ5~mWYbT6asat?%Ku*?THS-p^@0; z!K=;zQh*Ull?Z7pq3&jq9cf3K%^vF>QiZWl(a<+gZ|`w%l9Nv zEQZ#>F84(jAZKg}vkEd}YwEB>_pul}tk+8FBgU=x7$p1{5mHbk{jvA5BRxY>NjuT? zaOwaa+Zt^wQ(AQhy;gmbPmIcdRSm5?%fou{1Lr91#3<+U+mUQXiMM6F`Dt*)yQqN+ z&b}n`a_C+HDo<^5tK8&Ne!A}U6g=!L27v_@Crho`bwxkrEfpP9-2D^T(`4x*niQ#1 zTBk&`dQH0XOI4_*vh7hT>`N_}sK%$gJ>oWx5%0eRk_()-PrL0%Uu^CZUtoc!;BZ7? zA>98biMwRz26=?!ofpb4s79oN9#*2UWjPpqs6}+i`)`=lKOok7=#DAv9ohbeSSy28oT>%{sDsaO7^QA92rkfOwpsZY^WUVWh(dwo zCI~J#5j}QgpE7x|jX8NIY(tIadLg9m4Gf}dUsd7aArwIu2U-8Cfxc?E4z(O*OQcX>e^M!UIY95SqgdUlyFrd0%EuYQfy z)$?AVOrrKhZK?V^S7+v7t~tG5M1(@J(`@Zz82Sd}0MdrhaVi?%_Q(RUxzd+h&{);F z2RXMM5#Df5L=@zuT3BofQb4zKs&#h|#_-CDQ=tRbps`MHzu5Jc95Pc){eVaR$Wg!) z*_d#fVyLlkIw|VE8re;=*7bo#8)VN{f6yMuiY0|(2Hx% z$z$~3Tr&6`YD-7~2rR%WHL7S!lzU^fqB4@{)o$4!rb_H^3c>rmx>Xkdwydwo?b(m9 zill!Ap#=oWy?0kSfURZui&&PYO>0VOoKkdmaF{x+#2!KU6nQ{HA`cGtN8jHAsB6-M(Q3|QXedB7lq9EBz@pPR><%K|<-22L1>@rPb=K}-s zmk4sYFSp>-M767p@i9|(xr6v%zfj#O=?Ui(waCZs5bHr$0fhaS=ar9!zJfKKynca> zoE58Eq8&wTF(`3ft+z29dT`g-mZ~^Q8f3HEBPjn&G8b)eh^{IVB?I_P0?4t3Uozxm zU6Pq!XTf!ri(k&$sATh#po)mxA&sFa~X#W#DkMaG`mu(6sS%=zEuM$iQGRj+!)M!!nT6#^Js_{Z`>;Eb?eb0$gpI?)tCNGSey&HH;!8>O>oOZz9sSFz@_PK8u`(~@|0|Kuj&wc84O&wo#4)!G*V5s-M`LaTD#SJ>{|0mnJo;VkB0p14A z9pEdoz4tSuz9#O3jhZ3mFUhXH;kbA18L(Cc#EI%zaRQyzcFG4~)v6bo+qAcULonUs zl?o(-8|T~nz?ZD^Y}MYwj`&!NgFC47W9ghjJMm!G$|o}eYHa$ARseHhi)lMjr@Ge+1 zdxCC2!ueq$@3#eK$2z*uNq2rg^2(}ozAmo~bqWc(Qt#|Kh@YmRXcwsu?R0?}N}K6U zKZDgvo|;c?rAdXe?SiY3VH?bGL{hM&{)V8>5qPr?T`%Z1jijfYts<`Oj*ViFGej3A z-U^N|_XB%jI2wc%fXsqj{A+QB$Ij?G0 z%=CQY&I|tmhXs|tmC^Nli&tgUIfziAP#l=_Z|c`Tp4l|dSTrO3EL-mCPikW22uK3sY|_bv-{V?9q|g&-MQO z4_duGdi*9f6p(Gx5f+<&m2(EDF`*y-E=fp`Y&;0ay!vg6G618+o9aH)Rh)q8^`G0V zk~_U-@=WH$b0HnQYgf&+YLA#3g1I4>;}Yk%#Q!?p#zX$eTW}NepXj1DPnX28Ka!q>$~8WDwsQwRKkvqfyjdNg zDW>w^p%X;xv?RF&LpB%IH%7*Me{bvbZ?@?E&m09I7Ef5DP6!{Zm5zKzW@7GCT^}RW zHhWu8;bWBrBktIV7Q%DYuEE~LobAOA0*G|LA@H}D5BHLw$5oqFo3Z8}{I{8SE?*~2 z;Jy)kIR(BN*B4wM@GG-3w!wcciEDo;bs#+8Aiqce5mmK@*?sLm zEYEVLz|NMRV%mL(|5YTt3S>3dNT8JGJj{%6fc&kGBAyEm7-C(<&a&gc*{z>9oi3FE zuc*{kjgI`do?T`8+;0WfL${vl(@ z9nCC-hiY3{OPB|H&}V0=9Xl!ZIX#y1wBd)MCB3105wot$@m$DKLz*vo>oZ<@l;iVn zbiyB1vVvr@7%0HL9`$d}hEF-kUGKH>s3W_QsXJ!ic)qN}<(17q??2K>W{X*BbO%sy z)2Sj`7pP_VWxVW}bOK_|X!KBmtTmm|e|;vxU|Nlx9iv(0^WX`mxAIP-t;fwTt11{7 z`uBVihnTu^*5Va_^$9zzW;kGh!@F;1^>@U&jt4m`z?#!=|}9 zV{IXa$k_LQ6~Wxe2$>C^%Pf%1erC3S5CPM;W!}X~ygEX)zSi#|1B2Re#?11Yx>XHn z5kQ~@Vzv+Y_Zu#T9+|J|`~+cVdl!G8c`kj6@l6^l3l`b+hp4 zw>NJXom%D!xa(f7L1(RITi+TVDSPN`~6KM_HAu;`9);%RR(Y#h2;M{!}J zRs?4I{pO^-`E3i6)}Y)o~Qz@b5t5 ze(>5F#+&ZnF*fPXl&%8ifY|tQOWW3#E8I`bWg+NGp&&T`=`K}-zTx(5 zmjLzyDTfcLeHcEV$P3v+9MJ7QibLKN<^X)TF3rTFho1Jm4}jd3i0P1lGUiEh864qM znS2mBk!<>EE9iz_1Dipp0_cTb2TCHJY(O-6zqtu$^==~907sn(-;mi&!+=w8=&H4&rtD%0V0M25)I7f zK|lHqdZ@%o^F=Ns5GG@M0C~ju2%n!IDw$*er`HpAkyDy^>i|#@WQRZCj7IPB9k#FK zyzdw-=1m}7fde^gq*p;le!!AzlECNA$^}n!pZHn&DoxS|<&1-%j&h=RiA#W8Kzn^A zrr;!HHn{+}1;9^8K)e4u$&6m9!aC9P&D01!$Q;lDjX=WaG!J_~Rsrx3t|f{}x(=9< zFq#TS>(!(a7D_JC+#Lfl%6*v?-MT)fwbuX47k1#;><_FcWDgRXz+A9UMts+vc}Jxq zcijL2f$v#f2VUoxi79yEesdri1ou%B{3 zc2Z&$c;R0sOECTXJ($0u!C^XD>AsBS9b{WToC6^LQ(fI zF2aMNg#o)Uqk`^<-R|HkV>f_hz=zFF@Ek1@P5jLcUT{4v*^(ObFd3uqt0aYbLAM9Y z6yU!2H?ZJG82zMaqla}ElV~hxCs<;Ov3zGCeRKV#6%%E9?ky z@}zUb@c}JWD>0vD$=#O3K=m*{H213XdNH!(t%H@Kwtte2t3m>im?ME!+{y% zLQg-~`_reX=Q`K_xO08U+|!l+ADS|rOf5hz(f@eVXONJZZwG%k`{rd9N>XusjAld# z|4|eA4G~RyvY1)b@;lt^@fAgp3v)vTf7_IiRV(l3=Ef_FmL6GJ)Vwe^l=@y4OMZgO zjzkat=?h;%H3jM0}mb^3`AAo+{`as?rA9l1VBqrK=T)rMBdl4OLt(1SHZ{g~WR)tdE z(Vvtr#6`TwUhd+_kKrN~gfC_*vO}pPeK?oj?D$y!XSp(-h!U=}O65nn#OTU*}!dkPC@mp)Be1wZXiV1dY(>HQu zM;qp+Eo*(#>XZk@6BgAkDrQl+O#PU=EM9S>VKF1xdP<&*BP?W~t;O;voaRMFtd%V% z;}l05P>cvGEMLrBfy2&+-_B3cZ{Z>f;7!(2xrECp%3A=xm7k;6aS2EA7OFyA8Tw={ zfn}NxugcHV^SN#D;)O8A%95+N*dp(*T>x_A(&C}TaJbbW$HqH~yvtf6tUNi)h46x@ z%%lDDB?0xzcCT(K%wo$c<*i&>JPBuDMb4Kj%>BQjpX4l>>rVfZ&ufrDO>^DpfAS6G z|F>o&_`B23MOe(`usq8inU1h%vn5x2ZR#;yGd+Fy=a=^`EGxjL45sx2SFDeVB=1-2 z<c*lifhVMYMJj|%vCX&t+OE`S$s*9R{<727-_h6Ul9=5+T^}G8Wrmvx|#cH)Z-$r*(Meg_Xnks`f~H9c1;k!4jZ5N zT6?1+QV(fQw!p?Ubdv51{Gl{+RW+;pSR$bZdip2MXx#7!AXhaczP4MK7RwUtZ{L^I zgGfIULBof1%PNL6Z4P3G74nX9$e2kz)?uGg7D}ANw1*tnfGRvef4H(jy^^0-l~P zxgBOzLl%!uN1z|uH-OhD+n&D5_WnG=9iJ)Vs~#wgUR60h@9Nj$sUlQCx~D@}!8eIA zNu!e^Y)+i%S|SYRV1^8y&-kZ=Jd^#Fcb{9_U^zZdW`yF$ z*`+XDabY&$YLi84*2rnMR>f}6 z2|g+Fo^T7eq+6&h6}@WVj`dKn220H z(w{yc7gQl59d@jfk1uSX+=`T#lqVutBHN7#;&_YUq?w3(7cFs;%_{NZg$`TfApAPJ z4Vsjx-@#M%W$IrKvqN-ANen(FpQCjAI(Dl7-V3<~7^k}Wv?a|2VY)3(VtoxXdK_GpA z)LqSqtzeMcUUFtcmgt%TEqtA@+%)8szEHl!+(aPRwku0QohRKjbg`qUxO{Q?)f7mY z>FtlTe;Xcy6@6bKG%dp$hP<6&b}>WS$L5xd-xoY$%-aBooln)dj(Zi?5He?yE8Q9f zy8I8*PzG^cM`*!PF;wPdZ}EQ92y2>+=7QTTZ1);vEyY6)=}Glo(AoRx{Wp_D;?X6F z?9)trl8_ie{t#JA7%gBr1i7w+DmPzRoQb(8&f=lSURP;&xllyiPcn3%f+(7K3g8xy zJp~!MsL&f+hr2#Rfz_+KV+oMT1*g~tM9%v_QwimR-Zg>V#_U8UJ(V9J(*8}GMBWyS zgo;Lj*;h67y->%hYPi?z%F^{zn+^5QkX=IM47DOH(vDd23CsnG_URpotxQ1jq?Pf2 zU9zI2<)X!U11U36wjnebl#p3r&KpE@^ZIl0n9fLeZ&|%Q$WeNM^5`}yuc%@03K>bK z0a$-)gd3-`hr-`DE)zrUBa;K1aS&C18m^D>rs%#(C>(>_FO-aVEGN^-BQWMGT6uCF z8sqJ{B;)Wn1zUCNRomTVRS&I!p5b=(F=WE5K0i^}im=+O9ELJ8CMTV_ztJj{lsllQ zu;JL;MmtEYxGoEjDDWu*%F%Cyu2{kF|77pY!(d#Uy@`B>PZm;KLW-uNqZL@ z;q#_Z^izVFMOZzW^@CbInyvyS=FGBj%3y(ta9%}a zN(fzsM5~gz_qq*WM~rK@7HT}u%i?M0)xZTmCx*@#1N>`h0X7!JYNd`bhK2Uj_sXf% zr&w+_b0ALaRugCfm;F3l65(+#o>x=I!pg(3b2&1xvD_A38=)?t`3}V?7KV|`QMPF= zbv7~RT;y^I$L@XjtsnXYxNRDAj zr*CFx&8RVzxe#7m@Wlw3R!V4>KJfN8on&g%D;ZM)!*LtKMn#G?tBHA4rps)UZ+MFd zP>_dlWF1^*I?G%5t5&p1E*X8yWbsa+8}F?^SG>mgZ6SHbBIi6|u;ogw!;>^?Wo7Oh zly_bPN@R%!A}4L4Jf16BYFZ$IG#327BpAHcF?kFk_r!W`;{&Z^-$TDQEzqT+6dwr& z95Isjk1iNkE!BB8mytig4z{>V3 z0QpQr@GIzH_ICH>h64mvAMg9!qHnkX3lGn`Fhc!j?B{4jlBNPZ)u0BFithrMe~n;J z4QUL}&|k66UBVqYjoL%}5B2HgFDsF*Aei-`u0MFdQ%Y%04_;I&`I_!%DqZ-29Je$fatp9~l&2Nm&eH1vvwVWugnqT`o?G z{H8QjcFk4{0ZnPor*)p5{fo6;V_gHhJL|Axz2ECUK<`L0Y+x0Ly(G#1QoV8f%mA^yJg*MI?miTz(Me&t8mb?QH z@&BEiH;|0~_yNk}(E_c8UkO#9ukd%nS;1B^+KY$hwU}mkjZ?A6m=pt?ITs`Zd6V_5 z9;MD@|AmuAJq;-3k1>+T)&iYHUFJWm_Ub4Zo8im;CPEG}wE`FYiu!0HH<#;J2N*R^>KCM z#)QTG6r&@}IkKgq&yw#%*TOS!T3}lTSq{p5RBP8AGfj&Et#8r`$B5I5C0{-E-aZ5t zdG%EKTTQhR5|*ZA!gi7^ft<4-x`V9Rq2F;t_SQ|S73;NEr(ksoR?)pxbnm~9?lG|c zM_qpbm*1!6{9O$#Bge}V$_`qh;$ngm!7oa==BW2~D|wtMgleH6Zzwo?>Y`J;6{oj= zTQMUjbWvPmb-+5qK;WJ7uE@vIV&wYwK0GkhB{MLgZ7VMAE?h%QL%)}f?zL(B(q}gF z{SOMulsMIKA}q>4>S{2X_5V6v`jAQaT<7 z);(7V@;IRUD&C^5gN3`~WJquyam3benxOi7n1;v#No7*Cm^u|vNc8jxQu~ZyDOF{m!TjHrNV4l%YqP5z-Yy;ZTIPHSad z1n)F?EP1h6#o>?e`w{iFMkbzl=5H}_SyjDf+RP?Zv@2+ip0yePvzM{1o~}oV`;4V# zM!*hgJ7&&(wD}X*F?L?ZGUKnP%QXP=qGjpAC6@7f(qhW+UkmGhjp|}NTQ+%38791o zk5Sd0#B9{jZh`OLWU^z$5(L;tYNU7uv!b8&C?Wp)Mm@<3B}N~7*&g%Sz2-tvT*t{bbs zjg12DdYJ~8%OnRt^p{`KI&5xg69ji>i}DDH7|P6E{w7)WU9hVKB)Q4QcDXqOOth$ ztGmEnUxmC!&T}@^q{VHNb$1$g=Ra>~ahr9~;9n`!X?bubO=FHcnG+a)|9Y~(Kx}9y zKYsC%D+~73`*$}tVCRAHWBu(P2LyOUKOAQof>uAY;}KS4(O7A^u|eE1)*;~Dja)RB zu($j+C=ifmL%?1ecY@x5+ZxWE#T3>cISm`+G;=ZSVHf^jU7qv|48W{g7M-U%R?Ki~ zUBCan8MbOzt{RrBhUMyY^gpbD!2d57rQGE$LVNab_lv*lE6!8edZ_o?p9CG?C5Mzx z%;$$NwpvDw=1;4PAWn)H(k7Rya$8>gOc;npOMGMBUKpX4Q3CJ)s?VTnrb;B%JT6B0 zkDw4WgBK`PU@M!pEYquNC&jT+cCjjY?p(hWfXm*Pj1>w+wkvH4`2HiGG%(*7a^5Xs z| z0^eh7t5#0d>9%@q01+^Ig;t&#A`O%b<7i$F=bF zvUn`YJHa!erwlWQJw^pf&6`eHfhFSy!@LSQ5Hp|vFl&Kx0m*WSYZ+qBkCJjlPEOYp zyEl!?jn@Zz4bVY^{}GI7E3+;pdSdi?Y3+w6YQbIvtmB0=Va@@~X|usb2byzSLda=U zg}m-Q4+j?RXYdL&$he^hM7=?xJmMhe9jIJsyut;m4uvT%TbWnFY>QJ*kyDuonj-Uv zS5`9gLuQ*up-~PrYDFIUcpNq7!5mJhXT6J@HB1dSEda%WY05iEIEI=o;z%65!0rlD zM2QDcd~E}%15fgJGV3qUcz8aMW~vRIWF2QFTGKIHldu(w3~ZZa5{6i|)(wtm~4g3q{An7h-^>+I(zw{~ne{_BCJ-cuW%?OOZij!QM2 z&u%=NzwnoryJosubCS1yqgVN?r|%(jI%myvPRow5hTC49Goy#qOxZ#sqOnWRfMa{l zjiDGtMP=EFt3od)Rsd%}=|P&D4T}y2Zx7dlw}W-n@q_Z3GEob0k_Um*SbHIYb+ef> z(X)zo5*{rq{xE|<;VItJc2OwNzR}jvm>L8%8I~fIRRBYl=G4AP*6T7HFOixu!Ft`r zoUV*3(SsqPReXUzIP>7)QgB2maxh{7YcRq5f{IlD&yh;hR+y=IO5Y^2SsKTul{2ZG zVNq)9S1R&mn5jzQ7caEFIu}wROj|fVUZo0b^Kq zxKTx(pwXyVj}^I9QGm>hWucT0BGh>mQLVN-A2m-ch58E=lFLeKu_U8&QZd17wRtuX)^*pL7mvF&bh8E z%DrcyF0@Br7F9Y^#F_4^ew0Vsn0>IQD5$uV=dY$#_nx|HRGqdZZA;dWqb@$#)u6Dl z_vFle)16DzHlu~w>!SwFl1j+ zs(Gq;ruhT&j8l%;_Sud;c2V?9`qn_Bz@R{#K%c;ps@y8_5$Elu+^@L%xVyOrxL|i#6 zevDo>2;JA2+LhLoWpdvn-6Yi{)8v6khDn-9R@r@(G2E~-vn;hNv+O}x#o@UgI7cw6{Xm?^9gCLci{ zJ&`U*g#0h^@8ov!bMkNHXXU5m=jA8lu5t&txBQgcL+&j1l^^&13fKqi1`YsU1N(u! zfC;b@Fa-L9^Q2U7#}IstDV0a<9_xQ}ds8mZwIDWRk?~}spDtuLXRaGy1W~cEk zQF>-ZW*XA|wIi8e?ONUI>3!V$jJLhFm-k6;cWy7{T z%kJv?cu%fDt||O0_&)e<_yPFW@cr<;a1;1Wz^C5zu`zrP+!($KZq)gpEUhdn^nPf1 zXliJtYS-x3qx(nqj+%__95o!>GiuzKG526DV=iqw{-<<@lYx1WRN)W8zY6aQp9_B$ zJ`+9^BG}>VPuTWsLv|p08{3P0j2*}RBl{$K58I!;neEOtV@I(!vYpt*>=3qYAVpXv z%o9?CKMM1O4}|5yJD{!kNLVMlB_s$lgw?_%VW}`zNEY6x`iuOG++OY_KPh*YJIUd4 z8@ZdjnBBs3xN2PeV90@sZ^5?^#IEZKpY+<#9Ef^Q@8D=Rir+&Z;0c5|SFONpg#}e5 zRXM$m-o7)&q6b9;3jI_1UAhDP6@5Rwh5D;SEk21mC^2<7o{ul za^*mL8QQ7Ltt<%MR~272ER1)87uV0!$3=xaWkob!xp!6h$Cn>sWifQVwrW~^U-_qh zTUD8>0z8;q>z%%%Si|-p(m02Ac|!NV3aw%0xFGOuI>+17%oxWpha%h%Q93% zC$u#qqV?JAVEjzGLE1X)4*O{j7e}cQE0JOAtkx5|gtJ4zk&$U6g&JPd-Q%EBO2Z8( zuMFpS*Cz>Wc4)dKv<*KCfe*J1f8To+(Aj4pz5mUWJgsLq^ysdT?D)L%$I4vjZ2pbr zWGeHq9$8&6{z{@sv!ckn+Fz&#EM{du2fk<)>(&ghIQ^&5F_fS5OD@))SsN2OFq7TI z&PVBQj<%FMk1Y4`PVix2GtP@d?Dj_n+|Ro>$m1=NAj+U&@ZeSFi)qI~5?z#i2ZLq3 zZzCr?6gRC?D5-hK!M65xckEu4N_D*AMb+*Ow5t0qj~c@aA5YC^^lqW~7(=g|7@l{c zvOaxMW@$abDBn*e58JliG$*IVSo0p|SO^BLMT^$#n&9zU9dD^)fM#U8~8 zn`&|&nMA)d$)EX&)D!Jie^aE`-L8zlIM&4b7kqPpYq*(76%(3Ih<4!^)(yg_z;pGM z9ny>HDq8=#Tee%yXZS8~vfBODxY6FjI?~3Zu$<1?B<&Y@%&vLwwx{fMw~uL_m7sE zLFcoz)5dhETf?Py*4=ynfoA#Es$Iul#6WH(hp=|XBoFN#Y2X@w;}T3apWytMP=L@k zGcJXSQZ5t16pXQm@&0kdI`O`m(K75qOnXTCMl|GR1as`D63J}a%)|o4@Fo$j!--bM z%@Yn8``Tx2zkW4@YH%t+FS%nQ3+T(fo^d<7|S?bW=(g6PNsG^jN%b;NiO)bu@$Au0hCw>Sn{u98>I@ zoAd8dB7*d5k;m&pf&DwN{0D5nRBuZ(^W8gP13RRdlTFF%D6ThrBoRvIah1`SG zd81KqgaVBVdw#h*sv)DflPGl5-r=cF@xJaL?{o`kTFAA$rK08FjvOGxor@U~ZDy*` zb?~N7gCe)7hA<(Cwjq5x4+ywd*bzaxU)Bd>+>tD2-rjX2@B1TiVw^pX5;ZPwp_WNi z9o6E`NF$l%IoDfQ?COHZ4eL%18tYal%JMsXa0PA%>=Cb+*~<3p;+oHJ!PfgnB)z(u znK9KtW)M6<6Ew|y;#f{MBe#DlvXoC_Rqn3>^p6=tNJOhFwR1jbBQm5xrHy(7~2GnGJ>MB1r5b{57`e<5a$0(Xy*>zg=fP;wkN~OsW!H?y>6rCK30@ZDoT@n9tIKUVF`rJd zSVk1}e1vHH7B0+ahjR82JI{t{=ZuKIG+uiQb(2=6wASVG`p3DiyX?h-+~Lc2woNuU z_mmY$F)1yQ^#P60eZcsy0d#c3fo;iieOF%+h(SQ-diIl)m`^>&zJbYvKObxVYv`nn zg4acqsy?&T`+~9~#+&P!f_Q+Tv!A@X*ZqV(F5&8T)lL2w*^0I>Na?WLBb%PgN~4f? zdZB})1xY|%nB-4BW7i#AtV{5~R7O6M%)YjWz zL6gr#KYMJJsI_oH~1B~~1nt@ExWD;~7>Bgomo zeA?&<%c$>LU6)=Unes-sP#oc)@6^i>;^n+|@L-yK~?7T|>R^hWd_Jupd4sO`$oDyakd@gg zt?!H*KZ&xSltL(n8iL_e5AcbVlk1o`uWZD%i<2>?jqiak25Vg#&eB?Q-zdhvke`2k z5EF)h@wXyKaYE7y@uXy0vMGfhrI=OlBs|HE6u#&r!$0~(ZrF=Xx1mbvsqQfc#ho@O zG6E`Le67dENC0HVKo+Vccy4N5p2OltqO^ z)2=`@hZXNQGTPt|JwOc@3q$VJ=Dv^wT!aNM%q_ryvi{1hyrtY+)xO6e*0$c?cf|0@NOU& z#LeKGqq5|_EzsbFG=B-6na$L!7}((ndy3|{cyJPpqFxzWVwodBzY^8jstEcsx3mfe zp0N)N(0g3xsWTCU*2+eXU69}xSAc2L+k7hQ3m&X&IDlFF92D>osLjiN zUaY<)RY_1rOc$9Z664iCJ*HF>KRFcj^1^B7`n8aU-eQ=ugZpcF=F*_6XfQsI9ud^$ zKT3QU7pPH#2oy=aydsQeZ|LAF9eb^K{{70r1HjUrTc#rrxW^)edKy6IS3t80_a85_tFt|>o4*rGgm$i~P zlLMeZU~~kG!4N$r?HO%pxvdqLQW2aBO~m}s6}%IX%p z=dB~*77kM8%HTIRcX%nHO7q1~qUpnfIIN$SV(`%c!bIM-amDA$^A!LMS5r)0GRor@ zgK3&C3HT7|A1q+-n>oTxH?^v@q6$XdD-LA2>= zWnA|$a}MW~{O!`afhbt-HE#6q6nKqjLViqGX2`)tmO7%6+obYeIP_WGz@MlA*rZ=p zuH`zW&3d)7p$|jyCP1Ybf1Fb97ro@47X+F`p;0`$f$uoyYzDU?5HG*mPeyc@#S?!a zn;pIFYTnA&rE;8Ip$ zS~tgQuUW_4aU~D?{bb$w$Z_a)NHss=X$gj9)|^@xgLv>hqPtZ}mMx`+-SvgsHeM%h zDaVsAc}(77ie~|^F7-eU_JQ{YKcso&x;#bD19X?_N*nmeN4q+R6|lZU5KA%v3$Rb5 zaRqIVIeQfR1T#;8Wxe-KUrc@#vnah$PXZK7F-OH)5T;n{kn|Y~^LyS{`9k!3G0z&d z`j8NXIK#jX;pV0QQqg|!^#&uJ&-)tk@~HvAjm#T7S`$W(c~q0dE$7AatV@gOAc&TH z4u4X$lr}1QmeXuauAtoDt%VMI!a4WvWJX{%vWpY29o!B)1qv`2rapsq7t92zXuXo6 zTg>ua$&2Or(PZ7dxWzmn6bAs?;2e4$xI+jF~<18iKOw^Q5W) zRQ^^l65s_Qt4$}iDc*?jbbMRk94L@Mk_Pu~qVPF#q=I=8~UCz<7dwx&u z{0>X5&S03XC{{i1sp|0+EO&HfIUg=mlsa-rCW;`Q2W&AX1=D@P(mqJ?;=_VdVcGoP zLX`UH9I2~|yc}I;)WGr}!R;DW=yPUqj3HqdaAVVDb*@`R2QxT?6y>#$c@j)~Eph#` zV35!8Ygu4AygLeWD2g{X@k>k$^%sFN`=6uZ^@NJNIIGpFE%hLQdHnRL36Fm4H3!=D6oFoZt#wHa4{neXaP4uLs)tT_4ov8 zFCYE0*kfr;kg1xz(ef3$m=Dtzcf@Q3k0A?Sn!XLzWm-Nod<(td0fd z?fN^VlPuaybPiA*9z8aQ#EgQ!qQDGXAThP2JVrHUz|tJ+LGt6Iahv7ORaa2yqUK~8 zWbS^eAC*`FON{X?twJypsL6Wpel18>E4J2?L>64Pl!5Lo=YkX<7iGUDjpk* z_F>>irIU(_ZG&;cz!c9$19JzXD+DB2uwL@)#(@A9-$-pk0NEpe2IP zI7GZikMnQ}Y0L{IOv98*&Vo<0@n6aXK^uEJU-Ku63c-qBSEyvG}}zrak|=hI5rV7NHu`e95nF}sg@)W<-q z1Fkz9>$?~hC~QRZrAZ}vYB4I3!n$~}xQcpRg-WV;r3tG6iGgNOhU-VVrHtunAcbs! zfQ;oP=xXUp0!q2BdFpgPvL9ThuGZnBRBUz|7+YvlNDVL)56%Gh=&)z>EHNM&;m`7C z41%*l1z%ErffWYg_!LtQn*1VoQv?6|aS}=cSi=DT8(hUUZofB)KhR$(=u|bp!XUoS znTgGnem*q${Oc_d4vGa<&|VuX;C>R=GNc(Zzln^#EVK9DKAE@xlW>7tqGMU6R zf|^!S2BS_ZZzEs|4%4lsOZpzoQ3Lwaf?$j~?;f*FMX893M<^9XF)o8v6jK%Yw901n zXLSZvXW)PR3xcDD(7s3CKiwfzJl4sSV%ptoFLNxsneZm!0c%%2s%@HE(TDvub`nC?v>G9576Tr>V=CNU}VeXz>W5ChU*3zP39l?tpN0-t6 zNWVotNpGR=p%du-^eOsgI+b8V7^J(?`{-u$YWfbgB^$|J%Ra^4%f7;f(T(Y)^bk6i zu1n9Q!|9{+^K3n~FWb8C_v+2|J5=SjNEe9=oEyonQ6o(Xhb%tJ6-u=AvywtqR zya#z1d1-lBzW068eN%lieINK{_@?=0jolwhA4?s}9D6YB$aCP?^Bnc<^d0o=^&N3` zI0u|P&JkgUa6s529GP}Z2c|vK5oWg@uTQ|?aRdY&K>%xa5MX#1f%;fa<>4AsnAGcW z)xdbKu}O87rJXLp*)^arF&&6J@ZSj{bYM1mvUk?uuiI@9K|^zQ4W>!s>t>OIiQ&`VPpK~11Lp@z^sP-EyWs1eivYD)f!ypOz_ ze1QBlc|UnC*@V24Y)IZiHYV@t%o@uWOB>6wzHgmwoobzF{lGfII?XzZa-Wh;Nu^{` z9#AqUX_PF?{VjGd2bev~k!nYE0DHnZV(qXFSTnNJ$G((c?&?#RTlKK2w%2#&tEhBK z`&@#Ft7qWRtB%Zn!5GqT;r;h-={?(z_Lfib!mc5EHpCqTa4z_mWWXej|L(K#lIl#mo*>T+vgc>)y5PHl z8!izqA%FYfk}5MMO~%#Vdd6K$EgWY@KWWC;j=*EDoM6uX8q<)FXpP-RexgQ>_mm=U z77))_?vM*{fft*$w-|JVXrKDbs?l7fD$TyP9viGaJAQK`pjNcFM3Z35Ka61m8Z!2S6d zO=V_AH%r$6AQ2pF$0t7u+5(e>592u}4W9BPFoj>n^G*}3EoF#+v$Yp!Q%ywu_zd6J z(VyohDb&3M9Z2)Rp!TW6O49ao@li~Rw15DpXSv#@Rr0CPjx4;VZ+)Ot;{4=9Uv2?D zyCeH^vE{t%-2>d(TWM>9&z;5r-=HfQKV2SkByLEd;tR7_=i!(gVP2(~! z&5^Qf9!0`FTcA z)@@;Dfu|O9y5!G8rXs(EK{@xXWJkGg`E<8wT9+dC{LfTs*CA9w{u2vFmZ|#s^(F|s zDJDqh91x?gcPb3_RpS8i`+*CCF?7afGoB-U*)!J!XA(TmQ?VO|c3WY_-Plh_z=OzK z^5w1Git(Q-JqHiXmb(+@*(Kbj2DFC>CgO!Hdx7lvr#AD`B~q1*DQ}Cf_e7UaSUP7{ zqr}4pVzZ7_Tf__>I#N<~zIL{1P1Bq%#>HS7>!d#%3CzE&Z7`xx`NEwShYNW2*vI9U zeoG40ImNtYkL*2Ge6vO0+o(j*H!(h<2P8U<9C}tuUmua$=8S$ep=79{fag<#%~y%X zu;2d{rK3yN@pTwhS(zW8C~evy23?Ld=bt+Ax!7nUXWdbQ_{vu|@2VtvX4a?AP2LeU z@mPB{dFEzWWi~X56TNg2O7*-A&u$#pe~(hGv#sP2^_p z+1xKYj&etDGd@!s+@>1bGM;hc%G1N1V_)AJ#YYUx(El(lF4cF@QdY!-_g=MWvJn_j zsguTQngea&S<_UuzOyCvxNF3WFFV;Z^v@<}vgbK=rs^W_fb(aX)EB(Hs2sEE z2~|+#Y|TabcqM%(Kz-UIyW!k9wDOIxVQ0{-kiBJJlnovIWQ}N8*6dTf83v)ipUIG? zPpX~IT><8Q9b({7{i08@Lc0&1g;kxCx48H<73ZU3Yubv&Z)dO7#?`%}3Ev*qZR*C2 zoaY1Cm;v(*MUPTLDZ#_nXuda^U8#M91WE;FYfPZ}8pEsj7Me8oUHtrC-4kB}-3OIQ z<(bV)=A|p@?@+(ptWu|onyzX^|A}-SbETfGKYY{HIdt<}{7my@hb*(ib(lAl*cb5i zsiA!appVj9X`#jY45)qXM4f65T9P$&pFI~ZwGAXjA-BZ9 zqAr{(9e%Ny^IuU2iUqwK$lin+Xc=GdQ>S@|Qu$t)bxQbNClw@XB;2vmn#m!b0MSoq zPti|FG;5Luq3K$0I@Hh|65RjXa~YQkX!C?Vw@F=kZd$_67ymJ}Jn1QmsZ$XTD7MR4j#!p35>;NFj$ zHGl@6^(8!U88brDo+yf{Sm2n3$0j&aJQp2((Fkt^On(wk)4xm;{~LO$k0!CbtPu<3 zPXuj!S8$VhnasoM+Z0f#>cA{(c4)d)iUchTkkv@=#+oFG-KR3R@x7kI7Y@ zsw*b5`Mi)wUB@NVYLkcwL#)u?vL`{k1GEIh;y6LCphQc9wgks(*b}i=zeX$gKwN7K zh@CSo%jCTZ1pT9yL2`Fhvirzkf5yILREU&kOFL7_$sg8`^D&IF7owMv0zUbgYz^e> z?_L4wVzFEW&KsE>Q$b*s(@{HhcpqSL3&0~w^nzFVpRwc-U^wUVN|)t;(dv2hK2clz zkVSTg_d#$BO7}r!x7q%xB@Szar*g7>Ug8<9m94z&^^r26IEnZ8t{fLM@Zf4#ng21T#umfpe^q;tVX zmM6zky`B&7BpwtEU=NxaA1sMu=oQ1j)Uxd^s`>>$F&ls|(4-d#Hz0G%tOlj3$0V$u zi>ZB!;cTBJIa6_e{Qt@>Kt}7gLBpn*Uv2EtHUpi5i`@;V^Z9j)Bfy+Oo2U(ZZ%#g2q|YTzt+r>TamTNZCtfHR~yaf+7M zRqloG>|aNWd>bf(M= z<5FP?i{#AZvqp8+qH`?6Pf%NFn+(op?@+@@pM$`Ac?8qhHx_l-!*nRd-KCN?D4mU1 zWQ0egKkOXzcFQ_1kyb3ZbBODs%(M`#p<^9Z>Nl+sb?=){N20GzvQB#eT|u|^xbY}o zce8#;;;(?jPomt0Q#Dgy8SN6KwpUxj+nx%DzugkjVw z@5?xluL{@vwI0@_yLySB20s+7>WRx)f-%KcrK;kep#e({Z@tJ0dF2M@;!>9Rv2+PK zu-*x`Xrb=z;OkYgE{v01$84~)fmkZ&65j%>DB*dUk8j%X9xyR}OKVw_7&IyC2Ygym zu#BujY?*55wkkdO*eOQ&zXO?Tkn4Wh5s-~`Qn&z03akvL$x{Pyg*^B5RKG&rjpg|P z)yjltu$5$5argsuy#;oobL4R02Q^tUtmL1ThZVyxiGuQha>xqP6U=R??7r6Ddon@F0}EY*x-jn5TgH0;_m&YAoUy6)Ias{nB10D)2=L z0Yg)5KZ!D0%=x0ZN{wXO6ANcHKg}zMc3HArq2EF~ka?~DCK-N&EjU}J*k8@GLjh`2yd5k1aK<`k@{A@{av#KNA$a6qR}sw6mWz6=ZX2KGeGR%HnXM5I+EVVW zCc>-9&Z;`Ry6s!t_WdJo`<4=6^QdqBuCK%w1>62*!B!LsH!HVsmkQ=~U&gMReQvui zd5$RsIqCSplbnoJ;9%JC5OK_0r(0S{9#CDc8^Ej#&(iIAas-DnC?4h;H=hU``D!yfuuj*JXHCs_Ln|wG#m(gD}p9x8!JW1dhl0htaD!U z+%+F2SPAHPRIo-v=r;DPDCAclGp_Yo6G@4uS7lm?|h-W)}^F!Y_dc zx#UBf5I83h(Z;j|pNlS!3dUgdWr_n}-Y)q{m77yC z3Hk=iHN_+$B#@N`NT&0&7t7z{{6Mmv36t=7KT)xuInl-|1fLjJqDL*zu;nIlg=U(t zE`z*$CY2nfL>LKBl?weHEH|&vJfI4}6r&Iqcvie8N=7fnhMODB1|AY$|aOELXcdlN*=s&0z2*rQU;9@Pn(RL1Qga2zkf^3-U$)9Xi$D_!TzqKgSG6B<#HCu#<7lA<&#rsYHHBGKeMue*kmhfAwN zC(A+8ARCCmjzu?KplT-swAJK^#YzP(WP_TX0^m`zLAylLmjJx& z(*`O6l+qwHpCeUtC6~JNYcGW;=wJW}fey1nRLxUZ%R$|Qn46WGl8fUqh?BHk7)mg; zCb=_A1okT6uoM%#C*X}G4W^Bjm5T*6gSh1u21XQLagLOEM;h~z$Hc%&^wfH-pkU%B zl1je8_$ni2!8ipil9UQ;SPLyt4Q$mXr#wda2E)Oyv91t>jDn%D7^wx)7&=v}Ql><; z;BZ(||2A_XmM9enc!X(;K?od+>6C4xV#Ey4CdlP?EBe(67^o2_m$VeLzd#ClCgIAw zCBxvMGGX;+bp}>v;D7xLuy~Ku9ot>2=du6QGQUMLLq&YlpgQdTO4V}L=rJ3D8QzRg zjjtv|;iCv5yohj-zLtJ1@a9>gJ+AjpIp{b%!6)I92u^q>LO1;ieVV?-HMfdz1wAqF9)NolQ)5Icwi#2(_~?s1^e8InZ3q_#(Rv7qq2~8ukD}&f5H^rm%iW{R+w_uSd`&$-y_{4 z)g#m6fk%c%n%v&Z6&v`&k?*$eE;ZShdE4$Ef7#2@v#!{)>{*W5cG?cw_S%kFc3BQt z_F0aecIwQs^ec{@_MVOmJBC99o`dHQbnrTaYrIgsp0jxlyI4l(vIju8YD z9z{^#*A{+7a6R+bVAK@)6?7kTH}nAXYv_LHUceK19(e-kigZAFBTpeckj_Y7OgbhNlZknN$-tyxvQ+m~>8h(p{(sn*qKy8ft3{xFpkZKOVMDKt z_o*2}V^Wn;D4ll6&Mxx{K&W4 z_m}Q0H7wmzYFxUj)Tq>;)O2H-`Tg6DyYSwI4^D9QueK-nr|-_d3gd!`o7ejvixxyO zL)qi=xNUVahBAARtXy_p@%Vhfc}9IvpzJ(}Eqh$=5!DwZm>D`vW8LoD+9zHA^GEFo zKwSLW*qf`ZEdJ>`c_>4aw340cq2IaxUEe0whx)>^*|L%6+n#L4{^8(aD~k_m!Jkt8 zNn!fczAnGFwZT!7)Q|c|f)kt4|DJ*#g~eq-1a*)bjwZ@ExF2PXLoT$xd-2DJ)VjW& zFRj%~vT!C+`00?Cop0-3C@s6U6qPM~jeE2T=E zpj(9vQTAPRbJ@F9s?E@>Pd3Hmro|s6o5x$o%Ym&W&!)LCkK0USiX3cR!ZuR|xx#l1J|tsbewU)% zTBpp-s$TuqZzM(V9>+37&4dE1s&V>yci)<3I@2RzBjczEaG_c|jl)##-4_4u{lp{O zNb2dldM7@gl=&%hBr<6HaUBogEX7Z?gTf zme+U_zBy=DBI$gs!!ua1<3eAusk@MuvQcq6ghYXQ`9G;{(rd;!(|l_7&mx^U^B0KK zcM~XR%%Q5Mp8Sy+u#Ug2=_4Vxr~DEiBX!z`j!fUcBUs_%^Ii-B8K$>&I;`8$O5~g{ zk*Ls5YZL^kC%$w$b@Q0MX{M^`EDQ8%?dtM}GJ&>KrYy%lEZsJa#U~@g~8!TVVZt-#6jyoF3LX zYEBh&F6G@D-l@$jwe-q@65DUAvP5-$bA5}i2bjg$#B*V>4@u&>Th*sb;CvtqS-Ry^XZ(S>sv^Y0 z=j>a7ZJ+S=HsJd9cc2PtV%rfBX|G|=OepcJ*75RT#ok00bj>9-l|eh*I!Zs-;YBss zfdq2dvpC4@a+qX;U%jbMBSl?u`Q=1Z%f>cSoi5)7xHmRLa#vK$lbdZ*&b)3njxhhU zO;;8&R1;AAoINvYL*3(Hm zZ%$o@6Wc5-=9BCa%ITwl*Ju-KP;=RTM(Ca&d7E@GAL#PzvnCD1Y-Ukse<-@|3ppZ0 z6~SYHB)1-2ZItCBn55=iZBWbsJ?vGq>$!l*32KeEVafb&f?M#XRfXrMHl&Q^NqutG z`H0!O;}03pE?V$k+y-i=XYiu&>@6`!@e6#AqN}AreoxDNQ^V%ySZe)*f~(bNa?FcL zG=Oe8fpZ*lT$jxK+*grrb1u54?=Fpa?QDer$lj8Gp&M+-9W#eVKY7KZk(1wVjvnQ9 zUrsz8blcdP0z;3__m9_%VjRRHmq#ZOVIw&kW`+u+s>aSHJ?(<(%9v6mpkj^<%43`E zB^4E`aO&cmqVAmP!uIOoz)~nM*8ZNSM)j2fMd+=CA}a!VC2nG zhXUzDK~UPVUl6VCCtd!xegWk5;bLd>JFxfdV$ph9GH*bcyqo&Asg#GSH} zD|iX38|}V~m+(akl7gfZ5W!fgJPD%KhkZE6_bYP1?@Wz`E16*_c2(qoT?T+|+2qwq zTDaD0-XoCI){Dn(5l2LaM1ZgpN>T~p9{-Kb1&fUcZ~%zkU!7&Onx0>RyAl2CnEIRm zbL&O8>HtT>Q+-^;Qt`Q{V7V@*26S4cj>h(5+xE{rs`#k1{}a};6@#K7i18LVq3DtN z)pb7#;t#UpX0NsG`(E37C%-|I8r3RtUxFR6C4oTz>bq%f z$=v~fxx66?ybhOMu(d0hdN+;IwB$&U#C#yf_gac02t_gOdSv;WUnqcGUsCo*S*jP+ zd>s{k+*@)k_eTsv>JRE*kU_?*O<8pwPQ-0b{sR)b{%uX9szheR0)&GRAb|uLG-o z65{}1l`Y5jJp0}YHtRiZ&rNdxnT9fxw*p9~Jx6*o7f0=8L&5+El9Sm!EA3H8Cbdzm_v3nbDQcH-3s2q>!cH(tX&52{sV}2 z`Ec38V%frDnceW;^2W&I{D|q=81R?B&s=De0nc^d6P}85Ulvr5ycM7hike-Bi6wcixR=bo`2Hg> z$ZLl$z`_ZwLoXTBdNIi02o7cW6U1bA(T|GrR>}VF9KA-v z=m;LtG^l-%5iI^%BY)#tr8o5Qhx7pdn>z|1Qn5j#m0MPQbF9SH50W5UCg7+_tj~AL$#0Fuf^*fD8!YE|6^p z07>{Mwjp4&^jfDLU3v9SBQA75OJ)pj*r|FCuluSTkjLLQ!rY!$!8v^eY8b#*h&*F) z7obariO<9--#9mc`T0Bl$;1mFthfOL;^HOmrm>(-rR8f=24>lNkX$9`t&P`kgBTe; z5qxBH!N#k22sF9S)yhAV?&2qJ{HJF4thVUbxUjMp&A$t@L5s$4GSwQ6?rlV0^KAc) zK~d1+cU-Q@9Re7)j#B2HC8L76n<{f}7Ch+OKUjIaW?;N#QV7fzoV>hPD#rwBLZWb~ zx)x2`79WLBx)RPWOU4so6~LaJxc^Ic9^eO{e8B0%Hti%PB|Tqo@c7&;-!!AN|?Hg8#2QytW=iv!hsGRq%ef`}u`Nc$ggkuwhQN z=}Fg9pBw@_y^_q@A(8i`yPR$v2kg9nMwqzV2>&}niM9<+oAZ-uhm{Ag9IA z=g2*TOfS%mqCE~uJeW4 z9}hRrz6(7VWex$n2}Y+0>2>1nelh>79!~C@Jk-(#Ls6@xZ2>bJ@i0~gSL(A!W;tsNj<(5_F6q&6#|ITYg=LF z96;EqK z>>ErTm|gp>u{9u(&soHj!wNan1-k-Dp82q}?k$YHnnztTp^Y1w5vrPfN(r|Nre8iO3Kxv%Y&zt3b0ajyNvrt4|=yln`0~i zaJcjl5E2b)OjHN-?VOAANwN1b0FipM@m)4UR9fUC;}ihEaXoS#lcDDY=tkP9hFa|# zX2ArYW?%dZ&O}Gzi7Wvof)*4?1-?U-QD#LQ^C$p_E?T$bSa4cAFUMmn+eIe88w<+! zs2|DDtt3@DHO-17kJTFxh)e*#&_~ov4L~)>Z&bC6{$w zPb$#6JaBVH5>LOP(EwA4?1EX-1)qvRms-vfAn!`?oW3mBs)D+;rd*NrLRCbi#Gx0A zFb3r99KVIKqCoUX8R;L-cuaL52r$!R#n7zKZTiA|Gb9oQMeE9^seVA7VjS3y0fPyE zI5Iko9IO>EC+Zd~CMfqpFrYfs4p%!;HxQM7pnm}V7nns#N+2W(utCs|3QvpzUh91j zMx)GiDSEPj^buiHJ?`V6{FJBELs-mJq|c+ z^U1RdHWu&%9r&r1ybfZgm4~EAz~_X?wFy;hhUy7WV0xnifR+=8d&Tk8fuBu?Me8tV zE&00naO;wb!A(_BeS$`@m+Poj4nlRn7eS9bM1eYyZXOgWR>{w(!7_Oc@HzM^fR%!L zIx(wfp3GS;|13vfIRgLVBY>lZY3(Bq_C3&z^To&K_O^iSWBZQJWlxDi#4WPDE#58F z3BEmZEzv#Vy$MxQ75s_52H*p^)dgQvR76%j7KQ4l4ZR1>SvO>_&t9K59X9`48uo+9#yS1h`mgt6`*ZuV`g8iT`|~b)njUN`TXU@H`oXU@Zij74$2wpg zavic9at?SEcoujC9!y~5GCqm0j);sfjR=gmTVGU9-s8Q|PVlK6&PoML=Ed4BSP$%A)-Z*Df(U>u%2Jh^i+Y7)wD zX0!}~wsm2&ow(V zTQ#Jzmr1NC*U^tOD@)bsXVImRQfCTD}9wmltD^&<#A=@DSJceZ8`(avA<>`Tuv41&-RF%intjZ&2}L?^-~Pzg)9ycJV*J>|%D=>|)vM;{S!`gx>7p1gG2Z$Txq5lX9z? zSe>IK96^DZagT3L52j3YSY;cSJoB75;o~*0#>f?z7B@pew`6)KDn35DNu>*FeBN`& zw2FosP@f$x3?QaU+|B4ZvbKHrHY{lP@o;$Ww;EI4E5pZWInjgFHSIrToElT3A7xm= z3R6(hQhMUhfMn>n<&n9(23vOW=(DjuN*Y&)jc?|Co1=lBn;%X0JjE4yS%n=T_(TVe zHb`)mg(v1OG;M10MA>;av6BYrvtvCyuwwA04K8CZc7?H4@1W>TWOke)ReX_Jm+9i% z!f)t)RBibF7LI^~bc%Z5HyK!|yQbb7*C}Z^g7Ha2MYq>H?6t9`K%1UBylg8c91QN# z-Bdeis^V*}p70_?A;X(NNic-$NsfEF9cjG}Vy|anP>rus%Ag*-pHz9}<&R&4U`@JJ ze)mGSa0aRZw^NP0bLAA3$nPoO+N{6Dx`0wX+Mzr2*qCKEt$A*~j$`STc=7_GeHd$w zLu~S778F#^NCk4~9rjK$R08K!y}RQo(qaBFH_5femp4nHH#GAG>o~aB2hpc0vy%?X z(30%2UJJ;rO)9g1<};$UHXkwJ2-G?HPR(;m%q<)n6+0(^4yRsJ1p(6MN$U@LH+dvO2MAkDjK9{9B7-3M(D zo1(&8%%>5U@mG)P@+u;mh23p!f8@t83LMqWic@#OdeY~(R?H|$;00mn)Gs2aoICrn z*V_A>ldg3J z=zIN6r09OEO}p$9Ox@WM)o=uxNyJ_2kFs|l3m<~o3eAIk;4Kw9+!BQQBaVNqp9tgSAJn^U;#~_M0ZAsmKe-WC#OwR{DWym&J8&H+d7l zD(QYX-ey@nuN7H~Hl95U-k$F~OCFmG9s1ZKT0A`C@h$1aQ2!&ZT(b|)Nt5Jl94}k+ zxvvLabTM~OFr#5>+rIEg^ek3PGR6u7F6(e8SJ(~R!&9fSo1(o4MfVUF%3eN>f6?c+ zMO9SeHy|p|uH*OwJ2}K_n_jCZovgfSZGhKOCVO%OK+pXF>$`9~7|UA7rw%K|~mJY9^Vssi(7? znwLV6N8?3?OurOe|K%&EPLj&9D>;-sE%!tXlfCI&wu*WE$M(*h!H$~*iv2!>xOPkaXpn5 zPpwU|QjzlvE!F(Yp?OugS&RF?g?yE%@MLWWZ`Ly*=8iObtbm-cN^DSgtvLRvL5H1P z{SGaqBebIc^6LsmSk8&N5JAbw+e!r!nF!WADy2PbcvK7Zv+^iPJlH&oEnwL48&{G~ zTyQeiCEPR%Acx#oy8}n9(|rMMoj(v1yI<6gm(KY z&98D~G zGy!@T1bX@C#3@Zkdup~1rXJuUooou3*Z%4rY2u2Sof*0@^eT~Wm!BdP<6kybPL!JL zoZWw~?-cA8Q)8n2hI?~o@v4ER99>V6+g7)J!GkqT5)Nc?-cx?MDW|&{R3pECs=(gz zAi?!^*nsv%P){f!#pzBl8SbS0Y3&XSMpvHH2rk|gQZiMcz^<=IWW&Mv_oLNcBt;P?W>Ei zlAteK0-Hqr`bHB|1FO-@FX@_tRa%us1!@mT0g&VQLFc)F=f&6mBuUObNz*hu0R#F| zqX71Y2CGoST_G@KhRBnPjs=69X(s>};*bf+;|07179lNIer5V=vSff|GzK#?f18Vj zFM?PQ7NEL@>a=Lc_yP$1@9|PW1*|%MpCwq7JlbE+B#;*rckD;bX#59ggAHBjz(SqF zmD8$k__WaHuOaDgz~qv@Q)PY1d6o*2kkhL}7D*1tixQoQuO%@Lq-&<>q;6KPoH`VCUtR%7Qohf5t>mtk0zHa(t>`Zgd#hYImvr}4am38u!N&jDuR!kI zu6K0+s{mLS>w;*id0|v)b}9f%**HzUmM;G@fXYx0pc+?g2N3iN>Zy_6hX1ohYk?T; zZ;XO#*lIl_NLZaJa=!za=!q_X*b069$sN!fDIiyy3SfeWU6ZOMQBD@S>>njs|1(ui z)QCU9$M-D@>P#ZX?SPm_tZUXl#%5&=59`3BZbI)6arbuQ>Vu8kp} z&U$`9oi$os%L%jBy&rL&a{E4+WO$X$8eIw0;D4pe3P}=?IWM%DW9Qj4mVKR}{&aq_ z#7(KXFDOs>m9WGw3DMk5f9qA?{P|)HZuPp;9CqOU6P4iQ2b{nq1`#RprRtlhOqHMl z7y{7)N9*BquPf*^#;h*c1``~g(WNBb125eTKn3k{Ve-R@*IK>g$fs|RbQB(OkIL2I z!C%m5-SoS<6|l$v2tU5@kY1PmZvc#x$rslTiA>^F>uDvcGw^^gtMq5;$0wtsK=BJz zo?*FeL>#!3!h#<8lJG1d?$;ICWpvImI%gT3^FJP)WAfh;Jik1vzrR28MSoDwu5Nh0 zpIbf_8H+kQR#84?)yHk0h!aG{7R3hfD{hU)v2JlArvmzXGdBzp^~R;u++UCSzB}|S zE8(fgQ#PI!%>qeUk!`<02`|t6GJ;C-}-WEuU*Sbivjn=tAaWZ8uYT*Pfwx@v4Rn&RV&IKWShs-x}^#wgN z0M1@BGgwve3-I5$%xq_?-%}({w82y)f1#?Y!L>Vfuhl;QEcC@6s#^vY5aME>i+C;V zuhmNU$_5~MdfK5SUi-`k#<%Uxn4=BezQ5#A(64>-DmPLuGmvr7UNyZhxwa zIPL(icH^RRSFg|$qPM2dtJM3`dM@ZE0_AUEh%eEt3U?0aEpg#OgXu=Db$^ z+AHH(SFOU;Svi#(754)9=OZ+s%2qBl7+(FW(XhCU~g zu6CF2bp?DfkZ!_cpdI!s!X0LjNXG0<;zKl7&_=;N|sZ$Ns0}l9%#b zvt;Stl|(bAT5oo8i=%$Q&8$8;!~kaD{FY;_i8kO=u1|T(tOZu|FHPK<$PIw02Jlq# zC;qs`B49r0&>)eL1{P<&A(UL16Q5mj?fNNXb*TP%xc?`)_2>en_Sk}H7je`{>u*?{ zcK?f_d%!c?{L$wiYS|JU`mKeVWBZNe257kfT5f=r9|e~m1^+GXgjyqYWIz3u(ZpZE z@3sFL6~sg!7QWG}U{_UcO}xDl!!EisE&9P+_Q9j)8dF^WE5nWyPG&SiCTwnZX}ihE z=dgWs)CI)0gYaGm;`+pbygViJyKU*F^@^kEm~5Y0mp~Z)Q5SRq4X_f(z(D|!dI0I2 zZaKg2G=E-C6$rz?3X75&i}m{TwIW(aPq%qAUzJd02}rdAs`UfPGpZW!UFx7ksHSFA zoZ(tJJ0P;=(_R(%Sqqr*mxL0JJSkv7Fl#C7Yw4Gr`V*KN_8QF1u>>&UQZ9@?4{#0= z*@+m1ioS@Q22lPBCIeBL8qu(}9qn8*?Ka&kx)ajb-PvuW0$QyCri?T7$J(deZ5g;v zE_@dOAl4TI>GcZyT9^f}92t3fpc=QxNT@>r#s%6$=wR~jqLe%^|L?hQU{Vyj6SNG0 ziejgAd&LjsG--wv{vd!e;0Bln$RhQ#&cvAv6Cjot3&56uNs~Y4c-VmK-ss$KLkHir z141GaZNlJyzFlYWkeZ#HUGwLvrTOucM9 z2mq1m!R7LA)vfcK{TLlEpj@nZr8S+a1z~mh{KtO zjlv2cgt5l8Ycc3Vz0JX|oEbfBq`Uu$<^Yg)CN5837PD+lVbOpS2C7)7iqQfwa*05o zrOYk57IXkf;WpYP?a?+Q!*n5{E@nm&;4RVCuJMlqOe7$oI>p4J0|o=4t9^B7;N^n9z<}};e^4Rw#~^?# z#*iBNVQ!7m>6AL1kT#rQGA(~KGd<8>2Ug#ERLG#H6?#s}XjDoG?VJl_kvy{ClmU6; zl)>$MO$Xr&)=E2+rx!{zgkwd;V9=ulEfg7w;{jw{+919st`*m5lnLtOhqOEaQ}1@6 z=M895plZ~|dIEudDSG#i;T zEJt8D0{`P9pazL?{_Yj)wRV5WHYv!Fp#MNFdI0!9E+}v>+D@gbDswm%c&DiTMtyT{ z@YJUX*Ihh|%50DMMf`6;%Aa;4|4Vc+<6wk)#P#~D`r7&{^>_9J3>_XiIkc}LyYc$t zgMprbUc(264~3KoN(E&xj0+J45grj%5n&M~5q=Sl5lIn0)u+{$*XPwY)o0cd>T&hd z>)A%vjIJADjdG2$jB<>!w;oFK@;K;mNXpnexh>*ee0}_{@pt13<6HOaclk=0w#V1y zsEenIzsvrcTZ?ua$@9)Wd$8O~cfI@C*+XH6!aT#gblC2!?k?e15xMcj@s#)<~5WR`P#IJ})iJnA%WoG=<_{#W#cv5^? ze0h9ceA6C3q6g8B=yUGc%40j$l_fJYjCUCq8SacW#wJDyBZPq|YcI1ZwJIwwEiVfz z4J+f8a?4CgP0I30^UC~6{mPi7%reJP$Fio94>1@`K%5-udYAhaY_Pcyp=j@~3+O ze)#0iz|i*7Z7UgvM!$6Y{9uLA?_ZXLmpFK@X0G^n-G}7g ze?G)%+UGsnUp;0^m1tf;nRPqmLkZp0C~t2K)=I_;|1@Ho^Kr&O#&_7yB0hKaVtmUu z%s9!|$2f-F-hn;i72p}*HFPj`a|^Z+n~0TRjj(Qj`vj#lyR@$KOUC|V+a`B72QrQ@ zf`DKiXLvCJk8R$bd;OsD;Gr$=T*>;*6Y2@^%rB*mnRWcHjxRgln=fwW9_r5ik1ZFs zxNmmf!u7JfwkrGNwhgu|2Y2s_tsVQQqYRsc)m=NewdLgQL%RcFQ_uW5cBtdK4yTR_ z9lqGgEzE=auq&`vu$)8M1M}6(b$ol+w(>~Z}FXT?SNs2w`6ZlIw-t; zIQMYY$d>#a`CD58=a3yB>?C#rHXrMc9mG0gAqRyy&vHhrTA;gq<~%yy#Tslm)sfQi zw+@Ww`a|nI*LzuDcV_L(aod^v^}WM$x5vKdi0xR}F^bK>LOhGFxo!8~qNqF5Ft)v; z2OGATcBucL_H#w$nW{184wnuwR*GGXHQ8*sWgYhC%{yi{7i_jW&SP)j1oARmR zo4Ks9kPa&>8H>XDY{7ZTyuf)WpDB*Tei{2s?7_L5v6`{S4o9pmE5+LRbA|UDdF=L> zLB|&znONLrh*yS})>rxn>J|ITW?$@9XJ4$X^O0kLvBc7h((2OUu~nWZ&v!lDJMysy zu#d5;JlA&kV{c*4VkfZcvANj8p1GR?8HX4Fj04zB9Xa(m_1X1#@z=&YT@KA{DcW4L z;d9IQZ+sV(d?t8Ds5z$h?8>g{Au9 z^1*+@7|;kq5V4`pC)>81v?w;~HVo}D3iUDK9J8S8GOH*usW7duJe+S~w6}0Ww_*Rv zvj&du9S^(g{-~~V2kwU#E8Kpxn)#37Kh|FOn~!?-gXJvG|2MC695fv>VN$&7n?ucR zZq9`!Pj8w(EjFW+xuu+P`qtI%TgPu9H+F-}&NGujoI1dNBU7<7MkN{nOvb#%65BUmCXIhm!F`yz>6_ z$(P@><-ctGzN+>6#F|`nK>)DXi-Ojx22)_|Mq);m)idSvSN*+DbR8?fbx>@+TCABJ{X2 zI0W9&Z&^|myY&V$GqI*~<<;riYl5b4oi`{*Mj=Ii)Z3h;Q!Ub)MVRo2>Z17jC+*>p zaKg;v6TuwyZt;b>ti}`xZKA;32F5LV+#MkE-=k^?nP6D7c#3f6cb!SI+yG-ELSTE! zJT{3kUmjk2vp5t!kN%E>tkj zM}S2AUzMHJ6q!#4xhe;)H`NIq*XAdlRb>P-Pg&k*%kul1P$dHIW;aGke%qe>B0s#p zS@e5oQqac9Xxn?8!V37XU32%uN!1laL`N2{i6*U!NVUhxL>13E1viN$T|b5yf89xm zoHX~pa7EGPUL@E+MhrDPQZ(Luac#prd|DF<1tZN|MP_$XwXd`e3Bs!rgTa$*_E}Yl zds^zV?lm^kKNV>Ns!guLaLTz;ow9g&hK=;_(}$y5geqNcSWSAdD@$G7T;22=s=Dcd zAmLVclB4Su!XN!QN-gT#8B_sf8;2izKc+6JOYEuANwaPxJ_xhBNghTZG)8-+Lf&rL zT~y7(;O2%L8?sfKtrMG>Sjl=)DW5!KFBOA5Rn%mkD}te>&F8=W$^`!XxqL_XgX~+} zv8+VFjAPf;VgBHDn}Xp}bqw+vTtf|#K(UL^5!oXTwF7s`H^{Wf>?0*E$>r-fVa5K5 zsUYI%G=G*GQ@g8_YF_n}o90Hi1JJ#T|8- zn#OEUIf^J4ios%`VrrS^x6-+MKS}$x1FM0QB$>_xmKWco&hE``5 zg@hf3hvHyE_iv%;&jO|D&+OsNc@jmtjo`G&vva4rDU>EhHQ_zBAsiaAG5V9RIx%#w zT@(Bivv=xjpdDu|tfo9mU~06bh{~#TZsgZHc8v+mTWxHTzAo{D8SVkCuNP}ao(VTV zqfE_2@U+&M7)kfVv}Y{l9&e>uQN>8n%bdQkGA&W)5u|gKMLwlzp?tZ z4p-YgnN93=J@XjS>N2VA$mcwdR-Pja)|9s-P2Kodgva~c>^E*`;xoLaQPzmz+&0AU z`Rtu=X=vrM9}9^`cuCrUsfNbcQn z_nx%~*A(UNjSXUGVMOXPc9SZmuA&|XBK7#xH4tDSq7+gs zbI7aCj?)O5)Ev)^uTPE(-m__ObAL4l;+(WzVq(Qp-hj48AJkv6+MLB ze@CUQ|K$q@!BsQ+&_BYSKhwa-tb&4@4ad`-#vKF;;eD#dln76* zF7dD{>2?^T2Yu#`VS00&U(D09_D$XxQqnOZ>!kcf{`_|yNM2Ty9Y*bXYw&KZCzHyw z9)M+7~ z5oP`H!itYf%gq$qI>D~AwCetsN=<)X_A0Zxrj3LSIgrA7+zBd<5F^X1$hz^1=F<-< z{R&#rT#$9DM)CSk?_Kt6KHNYaeacQ-Q2;xB?2J|ST`o(KCx6zv*V>UN(uHyc=VG5u zYX`UMxR8#tiiiH^f_c5IQ${YOqV4y4*M4-c!{^8yv3K;ikRsETO%Ur{cR(#S?5UZP zHX0}8I7Eb|yLx0cJq`clcF(AK;KD-;;a3l}qL#V0>-nhm3lkd)zvL#Picbp82cvA0 zGAuGRx?j5TZp>uu&k=$6g@4PBd$mH#=)U}X;6fC8W-g=Fc?~8bxF9$axT>l>OLt

ooZBUZ_hYMGsaBgUl_Rt zx-cq1TI?$KP3aPvfrNxEa;kqKLnd=w-87Wjhd%xGBLb|ydSYeDr%NI|@wdgN$ z7?r{e=tQHBZ5&<;Op+se`SX|44WzUy@@H2Rh}8&;_cV9z%Y-<|g#(Y>fYqe^pxtT3 z%y!bR^ttNA$&!Od7B7RDTDqiegH&q;vG;B5+Q#LiHPtTLACKGfU+62XPDG8o^YF ze7FJa%xV2qhp@O1Px;{odt>|GF4XM3*9)B*^1Tgut~#zD&pes%!cOSZ0j9)NSyN3e z=O!HnLS0#$y%*H}iep%-Kv)(wJX)BnFKy})iXQ28(&rTrO-m0Bu6yq>N8F5jMh#OdlpCERuP%(E z0Wl+QPXt$;r&Y>p&e5#Z$YZpHtZpm>{VGU)7({gyQu~qPBgV6PNkF`*jT_NEp5N86 zt;1l88JNaAB)$1F4OR$@TC+$0Pt^YQr0NDL4%lNC`jnN5z(WjBhy?7k6#Wq7?`Ud9GxLiW&=q~< zo1ri+%x%Yd^b-~#$Lf$yA6lq8#fm>|9 zht+G2erd1z9{$}sk1|dpr+@!$4My&L)?DasFGhiMZd6Q)sX7)X6RY^$P0kEZToa_3LHC-QUQJy^%qD;Hc#J%* z=?#e`KU>Rh422}alOe!g+u`DLVy(EC|3yo%Y`rddSF+FDCn_4Un|XU^_`LoM+#RkW zARb4evSL76nWO+YXRf7}Sy`E&vk^g{SA}u%CfDlrF+{!<&MU#@vg!`-T2Y7$X8E;& zn80w;Rg_?bcIRkiF_*kremt{7$`DUE4wqUOEgbCj5n_TJ3BJJlyTgj9#Qv@!j_}Dr z{ZO3yL&(r`#$wR%$6U~u_O5`tA#y^0dAaKib)kwr7Poz2w0tW$o_IhIrC(dVh}Ie% zTxpTbA2*BB@j2g5o)W?m%+nT6ET6(?QHzt9bEZLqN*J_8;M8mtT&!+^+`;uc9#f8i zhes;?i$h>ScXlA#2bf5wTZ^$T=Jm=%w5b4N`AJ~D_nMTBPHmu`iO6pSux1?MaCKg> zAGKtw^!&m>c_i~z@tF3y{K0Sf>BQ+5H%_Mahs2`hoSkosJNMs`Ke!J57HKLh*H3~r z_0^w-RA}#3A9WoDG88wVV zYp@Wu=%-73_NslTm79`Yv6Qt~!JNg}X1`$sYox8CxM{rjnLggx@^7Ry{Nz9pwK|X1 z;@79yH4>)}8T;&f($j-P!qD7ZYjTCmVMHHD64wKz<9s~RRbhA^e$^ry+xc-mc-e(hegN=^vHa&ZpCMA($4&}TA?@TM}#7K=}dv{##! zQo;}|#(nICla*(DfP>fDRiCgowf;70Awu0wT3fIXwp6k`4Np*?$gDu5n?Uw`z~TZo zOhWq5;A3E$+RHy96fi5=fl7}fb&zj9Vjh`%Tj@aAp`nXGOX=1Ss->?N1(a*&pY95F zAR6+@>9a)mH$YwhLj`Wpajy~r^&axB;v8M)Ry zfA?1A+V&1=!&YgH=i#uYH%|Jf!>^&KZA&How%T!5zF6*DmYcWb{&l%`U+&%iPxS5@ z#~FfHUNHSV#vhZ6iBEZ1F}uH0Rc_0-6J)}7xlm^3)wNl5%2C4;GLQC&Uu|53Gc6%; zezNwa;DoU!vN#X#!Skgp;;G;s4NIKm6aK!3sCEK=CvPYf`o7uYI9@8TS?M6Vi4Jkt8JaT>u3?sPvhsJuO~cJ1MlU5b5Pg8 zBk~(ohbKLp(LG+pbrVjaN7=Sl7qV)6;kjst%pDM_JDVX- z0?4;(a)EW11?z&JqQ#0a@okLVxN;mkp$c#0c==QB1}d@44*5|Hayx1}|GGkqx-P0y zBF4F2X$swJGWZ;?WQ*&{C|ZYp+PF#iG5s71G408^P_3Foz!@sE+r`%rn&JwKU1;$! z0_V6?v;|@!)*J^;R2T1{wYBQzC>^2>C8|&kWm_vyH#qH5RoSy_PMQLBk_w$u?GjOe zxDA?Imv=@0UyO-f_ArnwD&){AQQp01p%>bcD+eZure05&Lw%bOhnrJqzsH#($DMG@ z47l$m_&6^{(Teb+wC2$WW=R;Ig8T5X1bsF1-_aas3J!L}`4v-vYO>PJjs$PW z6IHrl{OsCU^ZDxFg(@4tr9j7%i3sOz&183FAuXpMxM2dG-vlJjO6lc&kb5-47^Lmq z5(-V$g@HAm`m`mDb~d-_cKZOeEJU(m2H{IdANETsB!neF6S%(BaTJU`_!(Z7PZfDU z9$M_m*wPXpYe3)^>f^3ochXCgpSQCpwpI9!L6x&DKSwOkBEh2A6je{OTfHmr23S4% zW8P|c4y|dSz)Y!+!;Ta~|6G>VZ_1M6_a>=9&bFu>3t0s!mJ1%&aK)~tF7pE^IMoeW z357D=`c!dUY_00z>&pTfN)0XAjk&|55@I#6P$1Si3NYz~>eIi{zhC_N44}+7o@no0 zS{9IA^){(8`8%M%fI5X310|z1=W6m4ULa~F5d-g|=Bd?Ac!Xt_1i77`tBewr+(Y{} z0hOr$rj-Zoqjjjk(-WPdoD{Tz+-shfEJkj@LD6yB2(?Le7GEqRngm$YM~9Q?O8n}5rk5R7!pw#? z*GGh-dC~ay@Ll@~(i%U2Z5Y}SFQ@n8Vf*ACz)Xj1>rvsS!jFch(KGNx`&`o!R|qx| z!@`^BLi}^O1s>#%OqCh2tm}PlyNWdX%2JzfW z-bNtD_(Qy8JU1=)!+`ZJdnFI(5%?Z=yOm@o36{PAKaS749K1eY4I-Jh4)1rldsWdo z7uWjT@!sLSJYT$DTJTx~jF*6Crgg6-e-}*b(W_V=D-(IH_yTv%YLuNsLpRCdL@!)iN{>lt|M-ei0DOl z-C^xK_tM1WiPZiTRi*)U_2}@6^vn44eWqy+D_|R^EXe>A{v%#>$!Mjy84(^H!@EYuie$D@CN z4{Jq0@oLb{`aO~v`lfyDY2F`zc1v#4OWdt4TdYMy@nZ16dn-Yo zNgmRHnlS=xk@V7g_W2m6rxmY3oaOoB!_tb^x;RTl=-ct!)M87-SK-sd8)TD zd7I=F-D_XirQEdOjR^nn{k(C!-=*LUF81*!c+>Q!bZ0zJDyhNdh$FlPJTo=e(#277 znVyVyysX5F-Qg?A&XQdEdw4^4_-e9^L`HYrCt4NsF~X0>!e4b4tspzbd-53cP`o^q zZK?c{cM)%#YHN&gir>%MgtvCLTphGd`5kW*U+K=gRJ|JYxukec(0so-_W zc%D6;b{TAg?JiS3R>tzI@qBmaYLtycN_X7HS%tEWcjrB(AHl=ZC~Jw%5*A&) zk7KO)RMJMz+y^mYeWdi{mE!~5A%+?^$yK_YJ9VYThsVToE}>18K|DDga!LD%@;I-N zj&sMXU|A?5d2l>6O}j=J#)INr_vwsS83U4O~J5N0v-d3|VjC7tCSnAW(0wW%H!PNK{3srdbP^L?(D6ITk%hPKp4 z@`~}SeWbJl?#N5BRf2UxuJHsq7>{xXD!_aQ68|Od4?1d}L8_4v%dy@m{veM|KZW0d z55Ar6`am}WQ5);pmVf>GhCtV+<*HFk#BW*Jd)hfIZ~lSr=+#`gPf$N}$ddDYe9Gvo zB4ksTwB6+HbY@1^hQK|Etbx`X&FrJfjFaO3I8%JXYT`=r>BV0kSuC;v|cak}`D$MPNG-+oVsp`(~?&Z+E%t+)SD zNgQZjpWbyvSI~49LXlQ%B5@}C2vmwZf7p&6RL8qpOw%RdbdGspf@(q4$>Ia^SLLyf(F zi&*E+&pU!C2);>okxT4-|Jvb^9WY9-u_h9-2j-IOM#xVv&+Xbh~_{9LxDXz=b} zOc82`1g#l-Q3?e`&CsER&a<6OLa;!^N5y*E6T9F;RP^k%i7Y$q*Q$Eg^+wegnjgZN zl1P{+qoIx=W_gvT7V2)&?F9uTUIp0gle0b)y;j9RbuHTNrF(sXP`67X8 z39G&^;4270C(7^hy-AJoi~`Z{NtCpGNI??0ekWsU3>-0&YXllgN_)KsD*53nQRXI- zCNBqNAgy>9dd*VgE)rsL2nqM$7G*Ou>=o1@%(mLN}u=<2sbo&ATDcB^m6$*XT>`GIRS{hMT;jKqZ5 zf6@{y_oE+sO) zMwvFNb~?>x`+s)=hMlG zWI;J?;Pm;v2WdtYh{3en=@ynvkBr*=G!mKFn%I6HWJ!khSSb2smT3e>C;7T2sl$K* z+;xvr#G%}4(Y?q>z4|_&3-2wsXV_TWsErPb?wgL&w%!*w3J0P{gXP2akT|64+yp_+ zCaIoxRrsV(G_%PhVKt`;o+v{ZB7@o@A@ zQEU?MJo-)5lB_@nFP{Ot|4fe8kRMgVyB|hv3P1mpWmh!+xN1{RKW1u^W7LUafcitA zQ7}E2whEbyrmjc*VJr`?Mm%fmvb30^%O2t^-JBFNAOkLY^b{ECGPMeff6hiJ_O!O-Axh0Ib8w?NXz1y;fq@ryMg*K_BeLv)Q}zvmsnXmuj? z;OuN!Wk$fBIh~p#xzm56F1cispV7T%GA4~?UeMIP;@4?s_?Z&XG$S-^6SQ@*g5AMy zJ=4XtY@K*F5{!A7TVSLTaBKwt`UI6`W3n?@Ibuyzeyxb^wR$=)6fZ!DlJBSb2DRb(U5BYHxM z=SjqAtEHpyI3b3OVe8Q`BBHH{PpP=*pm>$MkZ`o2X#qt7K28%i@W6Hm!PQ(dR2!*0 zy+axqjS^txzmk)l0+=iQnqiEvrb#XAIxfQz_RAH#M_2PwP`ISSdT3a2if!Nj!`_?5 zHF@WI|F3!Vo>GgDX`PCSkZC(gtqWBYWJz*L+A>xpwN|OJq}C!RLX-e8`*aW(t3nkk z%9f@Su(CvD2@sN$qEf(+h(Z)dq5?@EBoPvFk^TQm2x#X%OXoi4Jm`OTWI%-Ey1rNP zjlAEVPs1`NN(6osF&*nM2Pu(8&2?v3q*7>>!f}u&6X)-@*E4=_5s_d1x6V7U!LKUAy!Pa&)=)>fH^wiG~1eh{i0){H7=NM;<|iU$87>ku*b zc|PSoBC*DV*AGK?F8q?jTJ=NDVbpHG`LtwI?Z%|i`&+$`sS|qu0y?oXj zRb-12lbRIZP{`6_E6mqZg%>7`EB3g6v2q0w7BEy*rAo)c{K&i(z;%J|Pn`0->YIUJ zT!FCSZpBd5IBl5|=!(bHSwiExO7W~Ypi60g&p~7DlnCzwgh}3M6{_S74w^}RXYM9V z9307simFr$j!*UMl*N{?4u@^`>v9rVCpaW=#Y&k)Tyg*2D&iqi$_==r( zz{%_`-fcc6o(0ef@OL1^;?m@FlD#(apWw32Sro^1)aGal{~}F&gUj(dk10A1a5m_; zM7P}LaU;a|T^|6mj*B_wnNyb?bdw)cH?DJXi-2rIRgn)l4_^US@W|TBWoJ-$AGok8 z(im}?d1S6Ta^jgo_IVBsZ%-il8d7mO>6C*JE3JKB(|!##_HYZB0zhO{>H)JjH%w?3 z0huB_kCLV52AEpbP=9tIenFT{#y#BIE_}^2!a4Mzy($Ue#O%NmNdU5=Vj>t@b_6#| zZRx|b2QJbpckBy--bt`IfHd6*;uAzzb(b&)T<)`E)~Zumb}Y|V5jI{Jld|M10JkUI z>zBPLen)MTCH{@k+vTq7=1IhoF(@=<4$qt)`pe6fEI=s>P|E)Vl(K;MEg*jX@`#_9 z{(?P(&Og^+SNW^_wRySzF8{WMlm7@)i`T%$2s6rhl#*2x7F#V>Yxt z8d%2pT%d@IwH_>3r>SW&eHTaZy|8hfO}^QZ{17t-%`3obp}p0yycBS6(S3JFm#m!v zqJeJ991Wf@PUJ0>C^R^N1E#O-x3Vc4rl-e`Na)>Wq4+!VkX4=Q&=TaGijQMY&#gqI zQBa?YNv>i=@lf3}R~5v$*S=PhAn=DA7xVo*U+!3J>kRDBO&^0Fr6F81#9ba}r+{^e z2-eLcGx2sI5xhoAt)8n<3gl*EOqUbumUmLG%tQF*fLUys6x=anHwAr4xF3!0R=sw2WBLiS} z|C};m^XNGytbcyj>(gG@hFIxb@5crje(-BXjdI_);&d$>502+=WDA)IvUt`(4rTqB#dJ{6vA-6pK90Ly;9(6ZWL z(PT6_-!)Kwo_zy>G!CVtvJOqA?>0S~bTD=@@!VoQzrX@dgqS-$_1v?rTE^RaV#9k}^X#hK3Oo5!bc zHQNwJU%ltl3{We+muZi=-ij8_PkX}WI~F#V|E>zRa0C~Q;KJ^{u)8~U_gO7=b?5_Z zKUNC&!8RSC+?}%h&~4go9`-(yiwi^2kB(H^X8!h#_3riHgQs&tDVhgEfPJWm@(I3S z9Xhu=aOGi5Yl}FI4pByy6t>+ZR~rFTJ>%o$8@3YCs#37MR<3;NHEylcx*mY_RMr`C z$LYouJq3y}@MT}!5*!J!PoXR){1g|YpYsg( z@iD>L6mDw@Ajg|gtT5HtoQu=CmU6fQiUyx6$Hc?SVGu&TXt!xsu1tSkQ61c49yit5 zR&zX{BJ)$my}Z)iJV5x{jroE5J~r&?U!yB<7z-I4qpjVhc*(~ANe^j8#anNrAhkCl zH4(v@9?lTfh+OCRIskJku=2Hlwh!K5eM{zV#+o?XisFc1>Ej9D5mXgpAQ^`rFTG(8 z7E;D8=-%qUK2$ZgQ%4;vHc%FRiK{o$;(1PXJq0T{9&As=exuu<*f)ERHc;PEV8KlV z?*fSs190*402hxNHOJdWEgy1i+zSqiU>+AQxFMq;jbY=m?aj32JR`DD4Sn42ue!(^ zaRB@FxDT;iA@Mc}eAHYGq+a4jC_l{UVPGeex|Jgb@i!cIJHa$dgDYgmrS!`vyGEr5 z2jo29$@!Jw`(CcRuYx|Yk~@XBD0YC2BlwkFI-m`-;-<(qfFEEjfIg(b<JqT0ib*$^7F>DDZEF zhbk#N)C0xsxZoZ|!*CJuezzjZA(zA;S<-L(wWGSV9TM5SLr=Yl5{Gq?euRIwra&xE ze525UY3ihhv7>H}n%@@6`nrz?(N^=FQ??cIlSL6r!9T#$TXb(L@S_z`HO6T9dT{bi z+0HNNC0-at=kf8|9!5#+mofBS$7z$yY}x>v2!}vojw8evqGwOK#?a@ZP(74HNr>O5 zinR>D>K6?6=A_yj^QYPV1HtMzYv8~;=L2X2u_t8T?C#;LWRZV-Hz~K9G)m#vONE@3 zGF!bjObtF!4zr(l6b~jux86C3w^@_>@)c3fjise)ZY+(^p81m+mAjoWdse~3He-l( zqb2gvW79>g0x8#BVc{H-%cuqTCLxZohzKuijzE@~6#PWl`~^LtunzWxHo@TwIU}7MzL&rR2A7)ClPbHoT4ZW|Cd} zD;JO3wcbVS8z|n1zH@Kt6xGH!$gcR0Sn)!5OLXG9iFG15bA5PpcH_$uM16Wyobyeana@+ z&zF3bwP2LrirSdw?4>!Li+omutgHX?efKTQc67wa+{HexwP2MtrrXI}PoE_rf%+ds zpDN#F4u``$5r5W;M1U(t!(fOv(kRN4qBqfbFC#+IK38sJc4Yf{ktoXVMN7glixF$| zBck2m7!SmHyM+-elUM0mTjZCkT7qcQgxbLLxSL#Kl!~H$E ze@^=#k*F+Z7G&)`8RtuSU-=C)0G*K4<5{z~WXb;5NPCohO!3J$Z&Ivsnu$l_UMle- z9VEcAdtlr(dbub%-0S3q#oR6WM$sTz`w}-GZI?2ZNyypY$@SAqMXQ+G(FIvjFA`n( z0uz>v{9{X){;}vaX5Pu##njDddz1&6c{#P7u}fM4^`}IB=qmIQX5i+RV!ceCDDN?a zCz0N!7^RNsf0DP%^s%y=$vRmJqi)eRh{R~wOVj{;jmU%PjSkA<`IsV=auES7bEl&8 znIae(^|Ezq+IFQMlaZa`VO^)!iTuO)i>+_#QA}(OZHXyaNnpxyY=1BvRAQKe;dD>y zsx-2a#pGw%d`%Hb29utpdC3}@mY_uCqin&_)>qV_z}-McE~aKu2XEmzqu}$3o5Gc+nf_V4m8K7sCQ%d`>rO@J&x(T3 zvX>=)XxW&yTRF(oo=owi2I;vX2#tEl`i{O#l!s=Yq`hi7tn3#F!%;Bn8~Sb$E8M=s zl&%b7%Cc!*CaMy~L}lCl*s@-)5MjgFIkpw1Xr-1(KWSTTic;o^_-I15Ex`DMv83>S z*H}uFiD*o6jsfQLhmd#D7@~KXnP|vky{FHsA)C`aP26_H>xew}w~kORI1v^hy_|F=)u}{U)Qsstjl6YN6l|pBiDC~mKedP4L z8DjaBdwyJrUHi5-BBAMSytP2s`jD6a!L&k7srARo z&0?#~VyM!#eg_ro31SagD5qD!sW&r=mgGKYX@0uO4Yt%a=A}lVHsOx-?8({emY4j% zX%-xsyem9h`#^u73QmPoB)H}0mP}!_EJ;k4pAy{086!Y@;2`N~zFfskf$L+%i~*uRWQ@|&GQGfl2C1&JrQm_|H8!y=$zMP98zDy)e#0$u1!@&uw{G< zpF7>J2=tN|HLIo}Sd01XzoK4u256(`Vtu=dp8 z-7q0kDR8;?^2`p}!2weq{qen`yXo3btrrn|v3bOg?QMB_6EP@`o4g=DiILqmKhV)6 zHhsN+gI9@JXFyc?id#=i3-0mk=3B5MnjjOaJ&xbeC@7L$?|bwu7UcDx7;rCQBc>ln zyHid?nYGb@@Cl0is-b7SuPCgJdaqt91fXuA50-T6=F|&;$p-$xsx~2OsM$KRL6}IR zJWU|EwNgsjP#*-Y9HQM554=g|eE(2U!T|0?vp+c%FC@gZj4)J|8_-vTnxGRF5A9)X z>-6gaUpHGBTk=3L=r*M2H4yA|w?6A0-Kn#QEvW+#c7~S_1hYwfm%vKxV^SFG(|5M& zMAQLnP~MRY^JuaBR#4vJ9`4kX;evbMvQ|3BZF8tUvAbTYHUAh;i1RNN`_AM?7p$Xe z7=uA})!$a}&60p3LbWYiJIcC*YSMPFlOp2ivE-+G#@!*yO)>Bw5>m+f>jYSMKl1TL zhK6d+m9f2t06@?c*7K^4pB#DD&^}=t=@6CB@QT zG?-jVPt-=?mep6{CV!Tqq@mJQ<4;K>OmySqBOMOK=UUrvM;Hjja)A`KjA5`AP<6Wm zK`ASf__pd$5{p9`Y5OwyXm$^{*a>xc0_;R(K$6P)xHBW#DPZQ?)~hxkco!hqGT zdb6~*I>h=eSDa!xEA1|@H*?MK4M*U|Zq^l6(iOkw*06Y2B@m~eJ}QMwBr7~HM{Ao+ zHxH|DX3g%Rf>y(q>NYJ;L5f%8}muUF8-GmO|0$D(PIRY-;xaLdqeKrw&yH zqM9bP_cZp?k^A~B6BnZkS`+fc1{$QtHVzt~MA!p^{?I)cymwH20F}%lbuU*(!ziDV z)q@OiYszWA$Y}fRN>uB^C$Gy&w(8av1u|sBoHm}&)^^hhqov^c#Hjn`Hh0Bh-xGq0 z%*s3bR>^cpMBp&X{Byk)RtjB^{~eXwYo0oUQc5OO=EWQKDgKsm*8_7{uC;tA)w0bQ zV5r29oG!PX_4%Zc#5$OXEfvVbGXq7_6?Y3+BYtloCl#rSJ?X1fWroNw)-wUGF z5!71lPr8->gKksUFzA`G^JK!c{;tk z7%kM6updVZ;tbM$PRStoQF63r zMyn9v$uY|@$Bq@(rtGkS_vo%*<7pYtFCsLyEAs7_t+*!``ltNh zgE<~A!ebjFIdQrnxjc{0g4DO@NrZtENn34Y>_eONNm&s_sv88)L0{Y4i}LBy##6Ac zqwi}={H@96jvxV+`Xr*XFDBLMMJnQ_G1|Pb+^bYkGvkQRJf2_JL7yOke|jp(R~D0I zrRCUWMnXwe342*FLHoX%EIje9j*Bh1XTD!#zaMu@4IkLqNRQyC8NSkb9K&WgVSh!K zBrlzuT7~oGRzsQjwg;tCcR3B1eWd+54FgdNRbXNsmO`!U#wacCJB0PF@CcHx!BH`IS^V9qP%=^N&rx|J^ye!d`)CxLqvp_Q4X;Dz@c0eudTA5c!!S)YY4sP_PA4 zu$~=8=$!Qo{D)F4g+gr8XhnN&TOwgV8+Qu-(MQUXydYVosyVh|ktznA4o z6Q?(yc%wx7USUGha~YnONu|IWXa@0DAeq%#*ts8BU?L$tPtL2H_BwV+?||+y^8jp~>Cb@aE^2Y<5+Fu)4wIh; zQyjX4aKzf;E0@Y=Q&JJJr-yP5(7(cPrY;Vv;&%RkF;%SrP}<#_vuU^HuU!$Ut=sYc zNMUt9$6x{Odp^vQuWH!=9ULuPGyIIjb5b`=QBh)*+&9udNq-K)iWy!H67Z*|f}yJ| za93>V;5vKcz7nXIyvG%M8)5ewH>0%4R-oOh6qLZWSae*w*RuX7-0RRYg$ z>Xo@b+|uvqSj>I#@wwyXfb-@ifBssUg|MDp8yak7w>n~mE0fNd-#JYtBBv+r@*--(sU9&Dk0F`}Eq|C^$IJvomq=&;9QJb3k_MqY5&=(3$fZ zU9zm}0yxBz1*a}}7x-3KclhM1PQiCB;8_3isPlp~u|Uo(kaPd5mz9 z@Ty&bt#74|->~)CW?n2C_RbT+y@RVzywBq4X-9h=d^haL2zADes6gBr&}K)56Qp-l zJj=tpFy2EG_|Jx!rY9HjGq5`2h5QZblH-GabFjbLS;k%3Y`BlK&ux^m3kQ(zd2g1V zFfmxu@OfQk&!ro)S^S``AVPaU7;>GcF2cm&H+aPzOFk2OU5>A;wO!x76xtg@-Q|yj zbzO43P8AS8mo)x+=4`k02@1Aad+If!ohW>U#|KP@J=a7KrLsjk7Tfd&4Id{2Y3lxGp#v8j0S7iK|z>ftD(8&Gw|?2%7%Dn&4O+O`HA>SiF4uCSueY` zw&H^ty48E#zq8W1V&%aS{fX(JqcR}Xo?ME8{HE~^S#>C(h`gm6UoDuo1RWg+<3iUX z(^&z0?gG*G+Fwj)N*H*S;}6;w2>0jO7t+)N4)1X)Zkc0Cr}Q`;T&vf>>>BG<7uGSL zkGs^HaLII_OwmcTIlgkZ^-A zK?#~U_fdnltafXTN{9JbM=p>zxrA7WBBk`15{n#*t(;$2xG@F$ET0xVXHdu0#!A~k ze?T;=h`jZouJIQ}EKc=*ZLy2(k;}J1d!kFz%Ygwo8_f?{u4wJ^W0-2!xdhf=(39c1 zi7$Q}mwZoGvtB!-mjq((X>dfg9M~g$z4#_jIsnXs7XGND5D@skXr`Fo)r$}EPzg_u zCX?)qWBa0dsI#a)@v2SnJRzTO%cZ$=o2zTn6>Tqe7I=bfZ$L3BiuUZ~kiVu3oB`*I z)1fLAKE-_pB+9r4s;^ihj#cw%(;$`iEaMS zFC6B%S%2=R-LW!8i4VHD0@wV^c2|6#5as9u=}e&I1CM2{@nUX?Wl&OrnxpDwp7UXy zW9M36LFL4JV+9bL!#YUMKPxx{X{Yqu(wq~lV>!=m+vcG6cx6$%_CVthgFkjkHHACra|1nz}-uP8q6iDs|#6B_<9^ z(e%@|(cLDZq&M$IsuW-QYXHDrGlDn-AB)tWSa%l8R1w2 z2DBGwZrd%}iq@?$zXhWE9gIJnGn^`Q_yYh9PbnE06R=+HF&br~w& zRZUP1sCqzo+qff&&z;fyrLx_l#7m%=AjpW=$IsK*y2-2fFdayg=U*U#ZUTOu3IY`b za%Y~s`EOIS;Pn!e5kc8c@^{n>JO&1VfZC}Yw$Rq7fOK#pl9$(YT(cV~gHPZOvQ{c| z1Cu-9<4$=1Q2mT3ZC3#_W&CDXqdgojbJppSNJ1K9%KmN;b>F2UfT0^fl@6{Qo}xO% zf+lZNzRWpg#pZlaTYz!OC%^#EHK*RI0#mBH2u|L6^v*?NVO^8V4q5Z}6Kl2hJC&%z zj?)GlmIBZ#vBs&tgl;gepFQaW^LVaqdO}mBY3)kyfiU3UcOM@^2iw7UKtbjAIW)&v zb2eRo`Izo9ZfMnZtx?40iV=^#-O>@yP31}GLhvz;K5#ZC1AL-4haUFy2AB||bHE70 z;+g7B-P-5Ik~Tk<7W3P4CzZDrfcmp%6-?|1k(Q&1L+s;^Fp7-FY81U`~U}K`A-CJJcrbsGvfOxEVCVt)`hE1_IuQL z?MKzWo7gV;fO!^ObTZIG0F$h*|5N?GM4f0Ub3HmbOXE8bQomE_%S=WOW;MdrFJAw$ z`VRve>vtzAL>ri~XlSu|=>TH$ih(!l6B1b>DHHS*EJ{~!u}Ue^igr8cxmeAXq&N=Dn;j?uAtK(UPqQ}fP!y~eDUzYf%*+nmeN95!#^6@5pmbgy;t>`$@Eywp& zQljzzvz(cel3n3e@S!@o#&1~rhTkj#PrK9{Ua$=xrSMpUAmM+ zN(|GV5e1-yFC#+r#iCB8B0F*s!e4I@xrcX$S7yg83t68=Qohgh%87e5g^WNI2zrQ8}|aTAw`Yq2Vvb`kIa@ zzY;}+M`a^dw7jjqF6ux><Yh$qoSyAaSm^VX}?m%#GT|VH|AM1-Keh=B{NYuwpUvada+0rPKR0F z)C)!YaKcHOPfLKF(FwN~i-tt4q7Y^}T9wrgOa8;aJN3H~hefNIJJFJ|;s1d`i(i3< zhv($vzT~q!|Y+@SO2#Ds;COhTEq?1w}@0|MYi-0A^!T?AEDA{%AL%h9OzTbFG-fQtkJ7Qgz%sop0_DZX=cJt(w3QaDLa_a;iyH{&1p%>1STVg_KGQ9 z8O5|OvWDo3L^w3x-HO!bh~Q}Y%aWxgywaaJh^Bj3*Xtn>D%`%yyf2qj`qw*4bT`8`VqAax^yb1q{N}F-BWn_+ z&~o};III7azb(A3x-GAn$nyu4kcEDKdVV18 z&XxuYDR0=i6QmQZbvf)^VSo2F@x}J3!2JnFOSZ7@o?H7w_HS+}$HR4oj@XAhdqvUU z?xJ)z*$D}0yRYTOaN+wx$)RJNyD%#($(1_`lX*tA6hg8wx*k5Z-}~|u|24sB2vOZ^ zY=$NIBdn#FbPBzrEtA$nH%niDiz(AxsD`2;t*EWCc8J^j=oWtrF_r}-rdaM*Uc4`Z zqv(0^pIDOnv?t&Cyoe_Z9Xoa_0F%t;C9>q}1z%>elgZjP7BZC~PiA2stP8cabKR^* zwnfwJ6;Y^`=0hCz&akP8bT_u;@sDjYOAo5oqfSujCW-t+8DD71+%C7*K_-TsOVD3F zT@+719vQG!uiM=7ww7B*--Y35dO0yu*!tP&wn><@!l-)1 zdkdV@6)1S9@(op+qcrxf<$51&!^jJ4+?EH#$yYK<2H*{T{K;<*jT|gmoB~P?dP;Po z($3J8qIKlSKtVHue{8_WKctHYf*xA>mJo{h(nn8a8!3F`kfO~8J5+-`$dOp=ZFf5Tt4;0o`HML%4a z?qx5JHB^v2Bkd1r1st07w3HDUbAMgalkX0jWomS2klpqKI{@=-l|1?#1CW5XgpEhx z)Nz~OXFu4o#mEN5qvDd>qf&jAM`1f`g?>Yb;k)*6Ey~ioy-3y*%(6aRqDiD1gaSAF zLu02p*IV0rrkeA`)2q{(fD~_ie4hg$ zXb)_*MU283XwP6^bS>r4BRq?*c%&Ft!@M%!JO^Y9Q+#x^g{|yiHe5=8^Z?Yabj}9J1&XMP7k$y5@wL7wQRboj5=32EqB|T=N5);9_M;;LJ!mp&m{$Xasv8S!MGRLcuie zdiyINHGV9~_g)OUBYL$jUO&`!c6+X`o<3vB=was_Z8ZP5H?siOs(zHkvc0R#=4%Eh z2e0kc&B$tl)D{BmHd1h}R9|t58+YvIM|Zt_P?zjalLoNbbV^eKd?7Gq z>rC%-L?D)v_I>jgX54g&^nk72%4&LDqnh0AS4wT$mD%2i{Nfz~i`AHL6cWj|}aSgF}OGXvKIU*bz13k?3H;-ehfPA(rue^_tjo2+?B7fT3=n|_TpGTAt8b>krA zchzHL(5PX(zTPc|dnl((MmKj$NrOklVWZ8QpwPyMohrK`?&0Q8&GcOZ)ZxEuBF*BP zIa-DvEFlEnAWGaamRmkAyWU1)4AmIh3IY!|+BGeLi>R=K0h@j5%5YJ8O4HFN?);$> zPcP^U%UhwF?0WVDo-Yq#3>-c7^B{X!HT07Vl3r795JY>%JqQ~Oh7~v33bZuFA$DGD zpY4%u5N4aY;PE0XRQ`8^#B1GQDW?xq1;Bi_&P*>c9F8;32#S(3OzGN|_Q=ueD zRi>7aw+!c2Ju+x_%G}^~1%U)Z9W%>M z6=?e`($`WGgn|V)2gWxXCdE-4EFn(u#)N9hCMp3AiB+xo4K>J)%<8s7(KMk zXzds5!PyeE7b4WbYSTYcSa73)gSAS_NPaBc`7I0_O8z@nU%}v!oQfzwO87L!T>GNz z1$b|tAINY;?z&B2BsjCj-8WFC3fZ?dRxvia5?ndM5XE$kYg2nw8#NLwfwc@jStw)} zZ~v{aiZJZUJDQhr=e_0m^J%1(*jU5dM}X}`X;Io{fO-*P3r_p>8Y90aW|MowSLT?H z9~8#@X$&u};r;Er)O%k7pvSa@&CBR~$iRk;#*C|`|E@+69^eB94s_WZ#RhQhiOKWu z)u6cqJ^0LcchrP8_;1Li zNlBlI^8v5~N*#s^(ike=NNz%Zu8{n*b>E=ybitj?tHtc0sUh%q{yU1nGajqHg2CX7 z{RkceV6s5Mr6Xmw_KblT(C3!mXHK$!jPmyh9QHF3w z^YdZ!q5J0^WMvXNWctrgf`I(Egh3kt*|Bj*g+r&1bDutnCp#+zgq;sd54J<$LC9wi z4rMQ93@shJ@1(6k9|SNO55%;Pl$2A9m-Td7k*dDs-8n-+f z<~~G!_UPsjoabm9LdkGFIX#`k1M&qxv-*sMx9ifrYmEkhvU!fG0?=5HAQ_*kKCMu^ z=R~%g3do#H5!?VH)D(%Yg1hAkd)-?JX~*@|YZb^@geq*0UH3eBRQI_H9~`Cb6}RX< zThd7fh!$z4_SDFFKY(bx;Yxdji=An&Rr&t=OQez|fYHihTsrJDT~JlXS&IS8XI)Fe zoL{jR5GH#DG=qXDF zdG9mxe`VKg+R$TshvUBu9bpsa5Q?)D%bie?qyS|;*AWzrrM|ZW9#x#$t4TsKQhUp} z0V&ZUz@t;gwxGr(Ps`dlrf(@hY9=vFsTh)LcDKrb63H1-7 zEsC}g#Pv78`ui-Vo?21*oJpW{v=k^l#1%Px}-y)RAG)>L|q<^M|>u5e{t82G- zW`Pzvq`{yhI_J$dXH9}6?KeLl?l{&d=xcLbBO{^uvZG&2qWx?NuR-Sv%O_^*#vIj4 z4YU1JJk&g`e#gq9XO@G0s=oBu{($(>t9B!_#f61~uA{Rbq6QtvzlJXW#_adBNJnXv z`6%8|zSYnSdlsyZ{KUH4~Zt1w}(PC&v(Q~_AoHe)5Wlzjy(o5Yl!Ow6R zkx+QF1Nv5seC8-bkh+6@0oT5QK~}jv|0B(HEgz^|vMb}6Hep{U#dT@l`o#{LPuVZ2 zS~I&p8UtzO9;_T3O&fVaoo=*!%cW4A$JblGOb!C?=Sz+dAQ_ z+v;77dRRe#9)bQXP3B0+07ohz^mT_dGxjiYoc@5Z*t|~?;OekQeo-r~p!?U4$*yHO zWZAk5rwD^!(ejL-lgu%(t|E(a2KU)yxQ0(F08WFcbZiv)iL34`9|1m?yIk)$PT1V6 z&mTZX7h7EzjQo#PZ2E~Kpxpy$yn~wpj`)$;x6AV#7WqQ;YUs#PY*6`glWwjvQ9Yb` zF<8v!6j|;>G8k9?=bja~|3Tf>#2%yZjLVNLyt%^BO>l#~uT}O(=k{8&b8$QlV$Ikte&nBzjJx6mXEsn_Q{3U)_wwxTCZiaRP*amyf)>`?O|QV2R-9?{+~XC`xFH_~W+r<943Luu39J+JL>Q8kfi% zMF3D&6QMP+>M_L`AJ6)9XmN*{I^iz1sI5ncx|{UmEnz(4DFFYgMvNTbB-}1I?l+N+ zW>x1q0_mH~(YL?RX#iH9F`Zg$)^*iMPbBU`v7$|m)&YQ-$Cy4b>@ppb9V~sP96ma; z)1TN=INZvrMX7K%29NQZVfJ9F(qSJ+W*Df$`K#icH3tCFc2sZma6rM4f!{G| zZuSSM^Qt{OtO47@a!3Wid7y2;5bcb$2eN=!pvV`NfffT*f%%fQ>>)tDfiN+)JK13oZbyf`;ki7g%EMIMI3?`Dz#pm<8t5d_NTg8V-o8xsrlS7*}sWr?%b6qbFcL zlKNMcoBjAXnC;q8hbGX9?P>v!iCPSr3alm|9CSjvf_tc43UekkW#XG*N?a&!49%s` zMw`5`w?}p3vBs%5BtLb?ib#Hue@?6h(-2*0@RpE>5=tD=2FeH+mJU2QK8DAdvcx)? zArhv@u95Ssa=E`U-AhEeag6P&i8 zv;M=t3jtjXxg0G>+2=pkE5R$3Q2R=vntO~kXVt~?#S-E?Fdws~3h-Xyf@N7yY%u7- zm~3%dSuby@j_P5XwMOu<8pD|sX}~8;PxIX30*9iIGUsVev8-ThA<|q!;2cZV{8-Yl z^Cy)s+NA%%c~<=$m%;^zA=I8+UW+r*GzXD+CeuaG%doqLH4(&ej0N=@>&0}3L1O21 zO?Oklv9_A$dN-W6z~_l`UjfF9=4TAvg^B=2Kvy=a7OVk7YYnTW`x^{Ed6Qe?Mcfo@=#wN|KmKr9ikP#o5Q< z0$)mAHt<&c@x**l1=GmnWGy`z=$`C5@J4-nVwFgOj?QUZBzdzwIgukWFteENGkwBC zvol^+uN+uapQJ2hO2RogAE3!64U5&U4TPq#L`bHl6>^t^*3*@L6@AGxpY(j$XJrd6 z@woD95t9iT2*SA^EeQ!rdr$eZXcyC*rGb}*cy zE2KlpeawWM(j_6m`XSL8rXs5pPI@>JC#pGuQ1<4M`UTd zTIkB(G0V^$Cv!az?`&R5`bgO%%0$B!AvWu;iF!rt=)9MF;32Q;e-@R6=XsVaBkfRL zX95iY9qG>C`w($SYsx9O>nx*^yFtDzkn0Q)rky#T?fF=73tWL`l zH8R!VgE>pHmu6+Wq+T|#rXHVIAgX1OvoqY)zAHS`uM7m#Mr6BnC?+RK%tM-&7J|rs`$6PU;7uSDtNQW-Q@5$xTsG_; zb=H5s6qUY@ewa?lygPZeLy>p3xDSrKb^{d>m~iy&gY<*Cy@n1iR!FK~$=5_jB>h}K z{k9V~?F$tazx6|4-(pi(sAAH$=YHjH3pZy-9<+jX$XivpF}nZr#d8C3iukJUi z8=?~83CX!3)ZRg>z4pdke71LNKiznzsDxm68u*plakL#WVWl6a&eYV&14B~Zw|FbX zVoEuo-2xfnrtWntzL-orh_iRBygy{Q!8dqDBiKaoMphD3$ceABuY&oU8RAK9Lhwxk z&b9(Mo53#^^FDvKt}WRN>kl1`=1#(J(h1WKMP!IK@>h8Rj@Qa13dunnlogqYwI0nk zB%NlL7_5^yG54eEAIb%*^AM>|Cpvgcs41QP(0mspa^tABQo$0Mwb`(RzLmF}e{8rd zC@!P_y#Sjw&T19P>YI$}(>QgBQ9u1+0!}v|)xu18Z7~5+M!4(kE3pWNqS)>-LV{NB zk4!msR|30yVnWUb>GrVF`giflNL7%QcHcKM6yB@yvHmSsfXy}0a0Qsn0u`*~d#kkp zIudD}U}pO8If+ktVN(2n%`2IDtWM>j+U!P`!_+Tgi8yUz5<^PvPNEM++b23Uvx|?# znI9d5663h{sBUGUj0faVHa{@g>h4SYOb43gzK05<1n@IF-FMO#2*wocE#W?ix#uhl zXK1$OtDB>a_EToiFz@22bGGwhsNi0hL`8J;aiH~toL+BQj zWuk~Xs8HW+O;zazw%Z$b*FD^xyJDhgEJWQa=m*EA8)@K4rIjgz>t>!lX>-q|M&cM` zmMly?o?e(IOVY}_6ZnnPZbaT;gY+<`ZHaFxWcY3nDmGZXXsUiG2i6J#&D~a_X2f|o zB^UZVzKAiDtd&2g3p&c1Vz>{U-Rs+Lv2>6ZZ!pP8TrYe2X%IC(`FaB6NxZp}Yp3WY zGsOb9UOHuyK{lTFfZ#ZRUq;~5H>A*KzKz%dvsz!wJjaN&p0a&r2LIAtrYiz8JBcl; z4B9$9iBrbD!#h~l`IU;CfI7;O|LBLkM^lcREy7`GKYcdhKUhnf>KaqoAN_7{a}??z zTi&Y1C812F0;tq@XFHj0RE=`$u+UBU8h&%5b({}NEFw{n5TbL7Z&k}tPERmPLO{nsqN7sKh%NFUG!3# zJw5m*7#MvUqsQn4=;^kh2@+ar=*nFC#48+lMq@Rsk+KrlQ z&mESe{y~Z@iqwIWKaii|H#J$O_PTA>k=#7? zPE7Tl%E#*S4$H9CnIcjic_u}`M>XgkHJ0>fBq2PS?(d^mO)_WZsoYW~?YAlZzDke@ z$&!!BVm(t7sCd<3GYaysj{&8KZERabPlC6k*`=qY(^Bakds8D@niF}Xr&)(kd-2m^ zbOlw7zST{`2bQR}aX@E)CtY^Mc()u}T@rF$TU1`l_fT#buTMoymmKOlw^Nt)fM$8HCoj;(x6@Wt=zkS zUu&Icj@Sa19qVXImkpLnn%QIZTDtj4auE?{{!Cj$_V;al@>iOc!#;X{omI~#3CTl9 zLDxa-U8#LAdW#z?3_OM0oi@SCA8*XF?~qbS@M*IpSKh~d!g z$?LF^reNb~77k_ljvRzJVQ$OBi$e)d+_6KH0h1T<^V5g&#fwQrYU0n2CaPouVTKy5t3eOQ#CUAc@J0Jo!u@o2H!V`d z{?VxssKJdgbxEERy>8{ulsV(ckbMV`4z$vTlau9D zx%)Ry*)1~t5B-QHp`1y{OwE-S=vHO_8u z@25b)-rRuOF|Bf40`GXrU?1l=OKMb0E}h!6+QsL@m(Fln{=Ga`ZOaBP%=fyIzvvb} zm*a`tEp!M2{v`G3iz}(^oJ!CVkd#7IpNGzH30=+i#i~~Z>6&Y7i$<{9zyOT~M&+&3 z1r#cS;9%nc`M#mAVU{o%#cPP^SPv)$KsqM(etmL}XD);}wFTReE(|j zoS~qlc-B?!EaIcIzelwzxH~1n`@ppTp{YVw;ehjdx$9C|^2%?4r130mX@LXr~%bl`Y9?|_rpUA)_TO#BNnD>uQ1LrWePwC)@(;PEYE5;1ub3lq@|r!1P@5YY?WB zaS!*l>oA!{IEOy8&q0Xbi6j8o0fZ53nJcXLKSHy@s=I_a&PZE_ZL;dr7Dv7-VdI4{ zDNDYh8RP)(^~>G_`V68h@o$XYF3&lP+6|nJClO1=pwO5g)*c7$moPu{mzOR1-((dQ zNRWl2yU>`hAlUrN2sUE+3-%B?|DOnA>+nGmj!NeeVDZwQJA|!eVydwYPNdXL_eE@; zJbr79pPDDvt?`Q;2pw+x{Z8=8RuHEvlX{>+TB~5dqLpSaIFk~^LξjwXD-j!wFD z$gO0f%cBhnRpv_DOVdBs4FoY-+tkcvqnZApo_D)sp-W zGq_*gBL5#2+d2a~ zbkoP+M`;Mx3~`qS+9_b2B7$}E|7-8t!=#L0dmir31HiG z&i-TV=RD^;yL)S%hnY;qneY4k-f!mpP2M*jT#kI2jRdvQEajFxP_DxvsUK~{`Kq2O z>$e|g|CcGbWL9V6?3p4HeK{LRrl~~PUfBwTM)fR@AeJlLzck_jM3NG-s;1#KB9@jOt z(p5Uxo_(6I*aSKi{0m<{E&&-)SbkXFZz}R}uH?j)i2r>KE^u0mKg_~~>LT`wi>A1#oo|uIkK$oHfWsnMCN{R65hE)U%;v&=q^s z1zgOk-TIy3#6db17D2oyj0Sz;zabrj-5=+9_{~KF!gXok-|)r;;0g4zW`szhin82e z=Tg4cpM4_TQ2X#VIVaEk_3*bJ74(I@IMDFfw#a`Rq2=z_-*ElI&ms*^YC9b`c$;Im z=jB|vr9n^l?b|**UmkUO%udKw(TK5Du}I*a9i<_0s!AM_$FGW99>MFe4C0|5ad+(K z5po(E@zt<(#0l}X_wKc+G+Rlg}xrW7-l z7`M0*Vad*|s!c!=(1Dqt%)=yR3a$5V-$&y;7~YSC_m_hIl2<~8;4!9m$2rRN7Mh=D z)jlTTVq^VYs}^%B)67thyP*XB*;D3_OuqdtZ?6}R9z5E-%FAe={*ri zMMa#LL}uB^+@{Bk9&c59^u3cdLYEBnD3%3|Ppdd|LMk)dmOYiCzj*#UL&$Zf@~xF{ z(JXm@C*_rJVrIqMo|(*)MxLBEiXM)hRm?Q&G95;2L-db^MSgc(SnmCKLXv{3Qiebm- zQovnVP3;tg{JDC+5cvf$f?FpV22X?@7KJJ~4*D-&In2hT)KM*(%}9hrN-2il)~LFO zB@BVvB@Yi%zhF!<5ZtD4J2x)Vo`SYZOJRn~+7^ccpmu~jDi{xc@AwcgLSRfnklabB zdF-eB;pNA3?5ZHHSbm-}Qv+cOk|}ADrCKgi##e>dtJ;Q9zY?^3i7Rava-R(G1XQ+@ z0$IbL3kos%m9wJrj3(YFS;CWuPl20JHN0V|QvQrWf6eH|=>N*j@QB`t37D1t^3;17 zaz&h>5|=ob#h`VXf&iOkUXw^pM=9E0v7s?~I%$~F*YsqtogDT&I^(meYjO2qi08*ybivPBKL~GD1b>W zrMsquw#PWEB(1e-ahQ@nBj4L@mXNm^1`rV6i4lw%ksMVcjDpUpRv{F`)FkDPBPe$c z?ig#*Z5oaHcD%-)K4+yAsK|T?$q}|APmj42qK_U63kdfIzrg$afWkpe@NsnUQkZ#X zRnYLk^4^}iUXy#gr=CvMhSF;Jb-KnbCEVJ)OK)g~J7KkaquoN*|w(Z#N~z5lR}=S*4DXdge!^hRf*T-pw_^9Gh>E@F(+!Fo=+TViL1y|P=q_e~)gLehrtF1Bj8~Ikd zAf){EnWIKIMuqqTE=Qel_i@HK#zA-!oCzLk=xgL_%rs=~Wg2axEz`EpQfY3qKAI`b zI@mkdDmXT{trlBbTANwhRGU^ytR>V^BEFBfZ+_V3n9W(6Z?npFmLES|OADk;&_1Kx zp`D_=pdFyq(86g#+D4aRbDyjS$35N7xSe+May#YbS&kdPLH7>e2Ji=Q2k~{d8QKVK z^W@Ipl8CB^{0MTy4-o|scOq&ceu{V);Um*KbJ3^{-;FoM8RKne1Y89UYV2*yFpR{* z@jM(4uY=RU=iqYiNE!?$rNQvNIEqozUW|8`N0|3^!&t-Ey|G5I#$v-w_yoKhC&ync zKH6pcaj_TjAIOu)W5~0}Z;@w^$B}->Z;+>vEtYieuLCdC{qm9yRT`SQvLyhp>Ien6nD6bIqYK3%`J z%X6wcB6tLSu&&s%r3>E@-V%~i8zK*(-b$Zo3E^tD)ZV(dt$16*Hj|5QY6~x>R?63Z z^anI-jC)}J!-~l8Z+jC_iBa%De1)^W9GO&$+5j5)gV3&os z(d@=gzF1T&+RW(6JmpGYCrWcuR$j$HO@t&0oxEvClhJl2&n9Z2=-TKuGXoBV=lP}e ztC8j0)RFr@xHsmsE3K1?fNl<^rc9iISLSq;F#dEgI~2(-4Z-Sr^q67}(!Q-vg^0 zXsl*gi84|~?eDUuCQ%(@qI|@nBFEj@^0p)~##{x{;mY%cff3=rxX3Br^1vB{IBsDg zJap4!Y+mLWcZDmOYqBRvL{wRCg<@RT9vwKS#xVz-#jCxGjO_-#0ZZTX8ld*v+|`J- zvjK`5J{8$MBKVo}ixo}$JvU9#nDsA~2$Zn%aezcOy~fJn4m(_Mw?K|_J!qFQVv`G; zZMr5;>t5hc+cRz_6rl=SXG#Kxe=u=WDCmv3U+CrTFehFJpc5?x{5SJPSzjkGSravH z_0z*kt&Ds4i>43fI)CcSt7B!LtLnfK)tm^~_)hUE!ONwvoR4%pzX(AlUBaa#KSOAnmS7iotH#k`ra`;L{x9GA(~x zn&rExPLxDs-lR4`?J>`!c1fC8=-9Nh5M#Z^??4G+Nym5JGs?AVqUu;+JUy~%WOpuu z9Qvi@)!vR9C{q`e5XtB;*^T++l73V53lYPV4Q(SeTN>HTK2>yB>AAbLpz;#2Xi^eQ zIeL>GHtXN~RzXbP(AtIrWUnh72aj+M#jErY_r6FkY^p@g-)1ZjfJJ2g+n<}v$2XPU z^u2_-gA}w!E{tfwj+vJ6^YjO2o>xz|`h#FnO6{;L`E2I#mbTD>&WpU8VlzdRtn3ao zsZ>9gkP9KgCxoo=&PiC_9cTitvZKE}AHJw?HMMq#VzqpobqEuKepci)@!B{c)nRh5 zvvZCP!G~0r`A+;e%FBq|tN#oZv8W1s`-87VJi1pGi`8e1m`de~Oi}DU>EB0G&Af5} za`x4{66c;iR$8HD8(M2<6Dr#zCBj&{Lzb^_TeBHR0WP~LU4(ssKTFLI=q~g=ArEmoXM+p=}i;p-(k=GbGNgI zOu5h(XQ!i=An1CN#Brdr9b%XlqLMQg#`_$sA4y+Gj?!T~7ODf(I6u_B=26$tmnXZi zbv@l>c{=)-^;^wN=?GuNJ%Bev++ArF4`m0=G)|OX7l7G1oaw%C5Uzi+{R#P?Xsd&7 z_HD^qe2;BVss_6!H?zJ=agQDTt*1HfF?eOs%labwgmu#kQLJy9ZyDtE+yVV)!P&Cb z=s4?|#T!NZ-0Pi2So{3MMWnc6j~4F6SM8Z^N-_4fA&Xqn4hS$Xg)iW~4(ocWOYdSI zKptN3OCGOqCu7bJ11*?6b*Tc6;iW#CF3DC{pTNJxJ?@(1=c9D?QUj`UkhC)^F+zfs z1`3)sB9@RL|)LP-8+A%|2-NhoC`K5n$f#Kl!b z1uPsqb#ogGyBqhy-xk^7tcaIue*DN~eWT$O@mt$yirY8|mTS{GX%wBN@7tCz&Y^Cn z*k9zGDl`}Fgp!Yn`?C5sxw%l`vn@M_HPRUlzE{YJ=`-f;%w5hp$5_&+;t{veg`63o zqrm^?E>vvrcqB1b#JU$15KcOj-)C1T*cG(c{B%mvAYJ+u6dmL_VJbo)=uk8SJbqZ3 zQYBxKTTmkv5+bpC$zl!-A4A@PFB4DIH@=R4ZG&!m2oV#H$kdO})jM^_4q(iLyawhi zNLV!*I9HT5K$!`*r=~HHiuP}bN5a>Rn(rj2c%99$L5~e>rv#{!S44#nG2FPyAfH1G zr5Dmep7INP-8;e0!6(ve7z1Vue!$OFDPF+yoY{3V`CwDyWHx0Of#q=LO+l9UEi{8t z^;kf^rEGH;c)1grkKXcDUygbvNsvZEG`Ny*yEHa-+M?WN1nd=eQQ(eQ6H+8B%GnaT z$R~QT_LKnXN}tC1w(==J;x?i#7`WwO0yIA5?Q&hHxX%I%NdQH46@#Wg8woQCP>4Z< zT}okB;1kn4<*0eMi1*XlsV-w~t*j@^vI#HX{(h0_7>w?~q*=t%&&U~u=X_T~k)Vf4 zH4oDqkKfTJ zf3*w8WbNaGeqt|C zog}`!N~w(^zw4wJ8ef?VPX*N|Fj(eO`-=kY909EGa_Ey4Oz~S_j#~wTZzZj0O`f=I zl$C)7B`hL|#kUD(soRAz4%k!i-gxL2jGbV_K_!gBu6UF#V7cT|;5 zsM9Bu3y7<*j`_WZDjvi#p6EpqSjj^Mo4H-_Ru%!bB#1KCYnODvynJVFw%GFb& z{`oFOeK+Ovyl724Z?-{1Ii&+&h0QzD0ESGhV!8l|Ao;2Ty-~^zr@vc~_(^fUg z|B0>7q#s{t4@NcJ(5T$G6^z^x5AneD%3{7Y=QniZ&;2A9N1ssGTKcUGCp8|=)YYno z*39oiOHZjHzL8(M&Gw~OQu0Ay#}bp+*w?CNL6!P63=G$5_z&Yp{%S1&GssGeYDq8W zP3wC3E5;N<-HaLJBTNjz2$1ew8P7Zne8((i@toF5b}3}j6X>)Q;UBIcW6NX2z}5<`kz=Ej4REE77Z@VE|+Ll44K-YDp}9CxWb(`EthJQEF`z$cpWC zU4a}*rew*5+7C=I1hC}AaG4bZc4j}I5l@?jMC5y`+$3fVZYF1>v9)5_LR?2mz&xjwd6)OW& z+41d)fF}uO=@}Xz;ax4L@*Ujh9FT!=D^?m!yu?^92Pys&b%oX3ZfZPH(*~=B6n1r=>5 z#z?@CwR&nmkS0+I*0@=W-tQ#=cky=~opz$lTQny(K~P%eT980?35MS^sOEC_Nh4BlcrLRd{j3%E)Uo_g4|iSjh%UxBWuh&!IaD5la7c*wSpvzBp#HgrF^0hT^&hYa7Cr4prT;-~h1t!<{`4P^ z=fA3HIa~TzT|_Pz=ctQVv_7b#QRq&zbGTokk2%Z6ZUQ*;Biz zjV2qr1o$$XywhSUY4hIWl7~)0NyA26^B7%>3RHv-NG%bH`&wH`sZfk9_@7}p zm&ADibeNTHQxFAu$&C_H?+-rdUG4-W)S&rbc)4pRP`*DRXPwS0R+sKDS=Jgvsfxm! zi#auJ^I%}_vM{9o&36lUDK`~LIphE)O)TjySzjf^> zY%Lx4G`f>)>q0e2NSeN zWPir;6Bw0}25#mjgDwOC+ew8|KQx#BO_iTp&XI`Dvo~0hYX#0~M}*Z964aN2N_=_^ zjFQL3%y9%C1F9O)aGb(!8)!qYx6}jNo)ENv_G!>^0NMp89?Pg{27+RxiZxe3Ktlq1 z8;~%KN#n4!?9u?pB84gxB!f#(qt9&S{zZ<;ddq^UV&lr?{%(L<=xgzlDxO+1E@x|z z(4Zj!6wU9lTb92OhY{C@g=I-moR^m5~fB%^m*1X zBA3V$bQ7Q;j;>Q$t!t1f)a0rqg<`}Kah*Xx2tez?ii>4dY|swcKR&Kt`G^^PaGriF z%=92@9IKngLLr|HO)@Y;Mb~$XVRcJIq4{7G)V>c?KzsVa zn_US*QxJ0l38S_H2$EAu<$<8jU;@nmThEl{o+4`J9cA?sq;;#c65!~@-SLuLQSpl%SPz;WR<8&^@TcMP)ix-Fc zcx$UAWCMPyN|pUu1EX9d0(rH(+i7Tok^*qpMs^e+?=~h)DPOY`P(kZ~9>q*0zNf;z zf4z^u`w0Axj{sLZ0`%-SUHcmIZ}}kWA2dFAa@#NFWlzm+-JC&)nG)#5Z+CGm#@F8J zKX;|D-gDH5kKaNocX2BIzfzemlCgVrj5={C_zBt`+9$yt!A0gSBgeHXZ=W)EdGJxO z--w@fMMy!&L&f0{myvTLM@PIy+#a56^o}_hb0+3=j91L5r$(f`BqNeB2M@tP@a?#E zJPHTSm4j2_*WuRTH_{H%9@D;}eLHzFI6d8Si^mr4)RU=aQctIPrJhRlO!Y|hc00Kj z-$nZ(xIE%PM0!MR#I1<()5WhoDMK}?D6K)ee z8JCQA#ku0U`^ERoU9&3vs{9K49`>A475Nn_d|J-69L=imtMDuJBaYo4D;p~sd(`8h zdeGyoDj$2GIML!MB}L>!kZS+Y;@INR;@aZd;^b2{RxwsOMm%F|;SykE*g#uHyH0bY zb<%dx3TZ(!4s9zfgXTf|`{b2L?Ma_Wn@Qy4$CF1UcTWaSZkzO+w3yVV@n|}<9GWkU zNwcG!pCr&ov;>-*wtm`!e& zbed$-+-Y>0d2q{us}f7e>i4@+P=Dr}QH5b8-qOXlIIllo=KiDdBA9L-uFplQxT62W zOx5B^FOT@W?YKca%SFGK*za!SIdpR9%+To}uc1>iW8Pk#(K?)YFe>=I`LU6IXg>-m z+kEQgpFZoj-2awI*Tftw{(oX4GXGY)Jfvu|ho#p|;~c|XxD)t0G{fM7!S&`xy9|H& zJUjCj=)9(49q@kf=M4c}kd3*h;q&+(yZ^IynP~R+zigfV$o@n3wFS=UJsSS>142Za z03D59`%gMonS0}nyO(paOe$$7s(MnogPFKs<|R8X+;Mua9sL>J^ALECDXYAEU{|KD zMP$gh)!+VzhF`W$e)j&_@Tb2#D#9d}yomY6e>40yyhnJ@;(BNRXejFmPUzd2gSI3D z)EgqFNyx5_{+W$wKx7)GySwi7VAdV7rEl(8dJMp=Ny_)|yBs{XSfa&zJ^Pw@4EE~} z3a&_pebmJG+)#VUQ?~)i4>Z`w{x%>jAj=zWwW+RjcE^w+(JSJG2KFX{-0nb`*oxYe z5vzA5E|T7qo}JX*ncKKEBYi!Pvwg^qd}26g?!&Gnd9-GeU#;i#%W&?6tJquDfHeP2 zq%C9h*M2F4e45wC|9s2bOUOQL^5H_RINQN=Y`+ZGV@r?Tl9q4r_Yj-LpLeoQg(od) z>-UR(=2&jjX1(T!KJ~(DZmjA`Z}sDC{EpdJ+nc+H=&xHU6Oyx|t+cn= z@((m$qnqkz1xUDFNm{HpQyXEwApJ@Xj-?L;#($bD;QzBCqttTlvlt!V@M$~rg@`V* z$-prFxuo3#4CC2D(fp1Krx;C`{eZoF;;Ubd@(8kHdwN*Q+zXO0!q+VVv~$eP*szPz z^!{sIO}<*`PD3V}l^4I80M*NmSXY>P*l@e616p<95>ZDF2uCg~X^nTZLkd$T ziXb-)c&chdb1bu1+8xG+-BSh#I`kS`GlDkpZuB+i9~uQ5^X;wojlwTn?3mFvp!?fc zk1a7mB0PFcCmvwEj1pNCSRfIe-$Nw(GUnq%cLFbCp%$$wE#i_!23PF9h1cRKL4Wg! z2qE=0pG9(0I~;hTFBDG!o!Lo>6dOk33<3#1F4 zVu{cY?UzX_9XArRvVzZmH5WYI;?K~O)S5I<-H7#QnnMuz8+Kds=CUt8lT4c zj~FNt6POoz4`6SUrecgKJi70;WqIP|z|1-;6>m^2yMv>8Q6+-sBrX?(#WRNVV=IkD1zq0gAO zyFuBciGbOBCFQj0B2?27QT zcz(*iz4J%7)7K)s-WzsRu%_n=*#$|-b8lhy7LTS~>wZMYpe6`d)3SuzddE105}qzf zu+U8K8?yeU{vMQ^1YU2$q_r%wbBgTQ-XK>159i ziHeIJ34WRH=}hYB_l?~&Zfy6(_DZ{w`B&?bMOyky+Y3?Iu3;fMkC;7rFr1zHF!vRhlLRiW@XL8{cY~juk`}@}SQV-#r83 zoEL^&9Xs!fZEviDQuf&(^LjgBY@f<%&tP$fj?)&CM|Ixx>g6or*OQTTDGrzo_SLRm zw$7Dj9Vw00U{cxV~;@vyrgyq=9E2WMKg0@qVV z&30DLR$st6vD|MG+Rzf=3vgx?gs<^VsELtZZK%A%K`RjlTnpIa=TEjrMvlHdN@QP` zvL+tymhZOO?R(SnB5_;lm+?*uv;k(4lntaq7n$g!O;)V>=jABIwZVM)*Bp1lD~rl; zUMHk8D>K%0y`m1;y!>--hc4mx3Cl-&H4zFDBCxx1gVxPh|7_?~qiUky!}iWj!i62< z07Bk7)H!f5?9{mhsER0jc-lHA4xPy{&_g}@75qynfIQ1%Ba=5eov=rD^leRLuxUS0 z2|tsLyNAkXg&*ZMm@E>zfNyrJ;{w&s$Vw*9sDcdCCoTvs!Q8XTrVO$#95#tLNy(@j?J*U`7Y1}Qwf>zb5QOCO| z-IBzncqBgtJ{S}+3oJL>X)m0ojTkTZ35Hm8@psyKc zHCxyAsBK^>xPix)JAR3BB`oZM&iRmd9M_Il?J_(jMlsqT&=-eGc>3s#yWIgl>xHF+ z&OMjV`eC6;Vn<&kyyN@4Aya5@AZ65P3bj4n-~PVh$0iwr^?Yey|3zmX@WK0rm!@Y` zhnq!n4c)QYOVEn&WBfg+~CE>RROS(rq0zP)~TyXg*pUAO#zbL8m<4h_tm z!lReNkCZO#c8@6^7j|cUQxVHBW*1Rkjm`$au+EGIGkwZc?C59hFb&82FVB*@C{|^< zBUhZiY;p3rov$CrOiR_NPSlX)jgJF2}3fsk=AG)R2>`sqZ!5 z@K^~0p)4&*2zG0px_-bwF@ryI7`Jm!W|mc&cKq~iPxB_=t%XT^fWNpCa>4z+`(*5w{;_?~K8m2W-GI>x>=__VEu{<*NprfPV*0Es zl-f$&a97W#Zs~~pWpCCY39F}fgP)Ih%sB^kaH?>-crY(g*l%%lPhR-1x0ZG)zggnh zrBo&;WU$6ZS&NT1L`Pmc-fHL((89kZZw(esr^Uc3!up&55MqKtLRn6v6G0B28oPFB zyuC+yL4F`Mak+FGBcA+qLt986Wq(Mz@b+n` z3GcKhs)>0N7}{%A1w%j4&;z)7R=u#qFcn5Q*7Gb>0(@}N^~1mDt^4AGzkKk?2ZPB} zucwXivEk>}eSoc2exUKemmj?QSM-B=?Mm3S=%`53YgdBcA;F$!KK@``2l@L8(Q%tT e`0yk3=PQ3sx_-P~{pLgP