diff --git a/ckanext/charts/chart_builders/plotly.py b/ckanext/charts/chart_builders/plotly.py index 4dafb0b..8ad1ace 100644 --- a/ckanext/charts/chart_builders/plotly.py +++ b/ckanext/charts/chart_builders/plotly.py @@ -28,16 +28,45 @@ def get_supported_forms(cls) -> list[type[Any]]: class PlotlyBarBuilder(PlotlyBuilder): def to_json(self) -> str: - return cast(str, px.bar(self.df, **self.settings).to_json()) + return self.build_bar_chart() + def build_bar_chart(self) -> Any: + if self.settings.get("skip_null_values"): + self.df = self.df[self.df[self.settings["y"]].notna()] + + fig = px.bar( + data_frame = self.df, + x = self.settings["x"], + y = self.settings["y"], + ) + + fig.update_xaxes( + type="category", + ) + + return fig.to_json() -class PlotlyHorizontalBarBuilder(PlotlyBuilder): - def __init__(self, df: pd.DataFrame, settings: dict[str, Any]) -> None: - super().__init__(df, settings) - self.settings["orientation"] = "h" +class PlotlyHorizontalBarBuilder(PlotlyBuilder): def to_json(self) -> Any: - return px.bar(self.df, **self.settings).to_json() + return self.build_horizontal_bar_chart() + + def build_horizontal_bar_chart(self) -> Any: + if self.settings.get("skip_null_values"): + self.df = self.df[self.df[self.settings["y"]].notna()] + + fig = px.bar( + data_frame = self.df, + y = self.settings["x"], + x = self.settings["y"], + orientation="h", + ) + + fig.update_yaxes( + type="category", + ) + + return fig.to_json() class PlotlyPieBuilder(PlotlyBuilder): @@ -208,10 +237,34 @@ def build_line_chart(self) -> Any: class PlotlyScatterBuilder(PlotlyBuilder): def to_json(self) -> Any: - try: - return px.scatter(self.df, **self.settings).to_json() - except Exception as e: - raise exception.ChartBuildError(f"Error building the chart: {e}") + return self.build_scatter_chart() + + + def build_scatter_chart(self) -> Any: + self.df = self.df.fillna(0) + + if self.settings.get("skip_null_values"): + self.df = self.df.loc[self.df[self.settings["y"]] != 0] + + if self.df[self.settings["size"]].dtype not in ["int64", "float64"]: + raise exception.ChartBuildError( + """The 'size' source should be a field of positive integer + or float type.""" + ) + + fig = px.scatter( + data_frame = self.df, + x = self.settings["x"], + y = self.settings["y"], + size = self.settings["size"], + size_max = self.settings["size_max"], + ) + + fig.update_xaxes( + type="category", + ) + + return fig.to_json() class BasePlotlyForm(BaseChartForm): @@ -243,6 +296,7 @@ def get_form_fields(self): self.log_y_field(), self.sort_x_field(), self.sort_y_field(), + self.skip_null_values_field(), self.color_field(columns), self.animation_frame_field(columns), self.opacity_field(), @@ -372,6 +426,7 @@ def get_form_fields(self): self.log_y_field(), self.sort_x_field(), self.sort_y_field(), + self.skip_null_values_field(), self.size_field(columns), self.size_max_field(), self.limit_field(),