|
| 1 | +import json |
1 | 2 | from typing import Any, List
|
2 | 3 |
|
3 | 4 | from ..error.illegal_attr_checker import IllegalAttrChecker
|
4 | 5 | from ..error.uncallable_namespace import UncallableNamespace
|
5 |
| -import json |
6 | 6 |
|
7 | 7 |
|
8 | 8 | class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
|
9 |
| - def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str, |
10 |
| - target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": |
11 |
| - configMap = { |
| 9 | + def train( |
| 10 | + self, |
| 11 | + graph_name: str, |
| 12 | + model_name: str, |
| 13 | + feature_properties: List[str], |
| 14 | + target_property: str, |
| 15 | + target_node_label: str = None, |
| 16 | + node_labels: List[str] = None, |
| 17 | + ) -> "Series[Any]": # noqa: F821 |
| 18 | + mlConfigMap = { |
12 | 19 | "featureProperties": feature_properties,
|
13 | 20 | "targetProperty": target_property,
|
14 | 21 | "job_type": "train",
|
15 |
| - "nodeProperties": feature_properties + [target_property] |
| 22 | + "nodeProperties": feature_properties + [target_property], |
16 | 23 | }
|
17 | 24 |
|
18 | 25 | if target_node_label:
|
19 |
| - configMap["targetNodeLabel"] = target_node_label |
| 26 | + mlConfigMap["targetNodeLabel"] = target_node_label |
20 | 27 | if node_labels:
|
21 |
| - configMap["nodeLabels"] = node_labels |
| 28 | + mlConfigMap["nodeLabels"] = node_labels |
22 | 29 |
|
23 |
| - mlTrainingConfig = json.dumps(configMap) |
| 30 | + mlTrainingConfig = json.dumps(mlConfigMap) |
24 | 31 |
|
25 | 32 | # token and uri will be injected by arrow_query_runner
|
26 | 33 | self._query_runner.run_query(
|
27 |
| - f"CALL gds.upload.graph($graph_name, $config)", |
28 |
| - params={"graph_name": graph_name, "config": { |
29 |
| - "mlTrainingConfig": mlTrainingConfig, |
30 |
| - "modelName": model_name |
31 |
| - }} |
32 |
| - ) |
33 |
| - |
| 34 | + "CALL gds.upload.graph($graph_name, $config)", |
| 35 | + params={ |
| 36 | + "graph_name": graph_name, |
| 37 | + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, |
| 38 | + }, |
| 39 | + ) |
34 | 40 |
|
35 |
| - def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]": |
36 |
| - configMap = { |
| 41 | + def predict( |
| 42 | + self, |
| 43 | + graph_name: str, |
| 44 | + model_name: str, |
| 45 | + feature_properties: List[str], |
| 46 | + target_node_label: str = None, |
| 47 | + node_labels: List[str] = None, |
| 48 | + ) -> "Series[Any]": # noqa: F821 |
| 49 | + mlConfigMap = { |
37 | 50 | "featureProperties": feature_properties,
|
38 | 51 | "job_type": "predict",
|
| 52 | + "nodeProperties": feature_properties, |
39 | 53 | }
|
40 | 54 | if target_node_label:
|
41 |
| - configMap["targetNodeLabel"] = target_node_label |
42 |
| - mlTrainingConfig = json.dumps(configMap) |
43 |
| - # TODO query available node labels |
44 |
| - node_labels = ["Paper"] if not node_labels else node_labels |
| 55 | + mlConfigMap["targetNodeLabel"] = target_node_label |
| 56 | + if node_labels: |
| 57 | + mlConfigMap["nodeLabels"] = node_labels |
| 58 | + |
| 59 | + mlTrainingConfig = json.dumps(mlConfigMap) |
45 | 60 | self._query_runner.run_query(
|
46 |
| - f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})") |
| 61 | + "CALL gds.upload.graph($graph_name, $config)", |
| 62 | + params={ |
| 63 | + "graph_name": graph_name, |
| 64 | + "config": {"mlTrainingConfig": mlTrainingConfig, "modelName": model_name}, |
| 65 | + }, |
| 66 | + ) # type: ignore |
0 commit comments