 import numpy as np
 import torch
 import tqdm
+import dill
+import json

 from .dataframe import EncoderDataFrame
 from .logging import BasicLogger, IpynbLogger, TensorboardXLogger
 from .scalers import StandardScaler, NullScaler, GaussRankScaler

+
+
+
+def load_model(path):
+    """
+    Loads serialized model from input path.
+    """
+    with open(path, 'rb') as f:
+        loaded_serialized_model = f.read()
+    loaded_model = dill.loads(loaded_serialized_model)
+    return loaded_model
+
 def ohe(input_vector, dim, device="cpu"):
     """Does one-hot encoding of input vector."""
     batch_size = len(input_vector)
@@ -54,6 +68,7 @@ def transform(self, df):
         return df


+
 class CompleteLayer(torch.nn.Module):
     """
     Implements a layer with linear transformation
@@ -854,6 +869,76 @@ def get_deep_stack_features(self, df):
         result = torch.cat(result, dim=0)
         return result

+    def _deserialize_json(self, data):
+        """
+        Deserializes a JSON string into a dict
+        for inference.
+        "data" should be a string.
+        """
+        data = json.loads(data)
+        return data
+
+
+    def compute_targets_dict(self, data):
+        numeric = []
+        for num_name in self.num_names:
+            raw_value = data[num_name]
+            trans_value = self.numeric_fts[num_name]['scaler'].transform(np.array([raw_value]))
+            numeric.append(trans_value)
+        num = torch.tensor(numeric).reshape(1, -1).float().to(self.device)
+
+        binary = []
+        for bin_name in self.bin_names:
+            value = data[bin_name]
+            code = self.binary_fts[bin_name][value]
+            binary.append(int(code))
+        bin = torch.tensor(binary).reshape(1, -1).float().to(self.device)
+        codes = []
+        for ft in self.categorical_fts:
+            category = data[ft]
+            code = self.categorical_fts[ft]['cats'].index(category)
+            code = torch.tensor(code).to(self.device)
+            codes.append(code)
+        return num, bin, codes
+
+    def encode_input_dict(self, data):
+        """
+        Handles raw dict inputs.
+        Passes categories through embedding layers.
+        """
+        num, bin, codes = self.compute_targets_dict(data)
+        embeddings = []
+        for i, ft in enumerate(self.categorical_fts):
+            feature = self.categorical_fts[ft]
+            emb = feature['embedding'](codes[i]).reshape(1, -1)
+            embeddings.append(emb)
+        return [num], [bin], embeddings
+
+    def get_deep_stack_features_json(self, data):
+        """
+        Gets "deep stack" features for a single record;
+        intended for executing "inference" logic for a
+        network request.
+        data can be either a JSON string or a dict.
+        """
+        if isinstance(data, str):
+            data = self._deserialize_json(data)
+
+        self.eval()
+
+        with torch.no_grad():
+            this_batch = []
+            num, bin, embeddings = self.encode_input_dict(data)
+            x = torch.cat(num + bin + embeddings, dim=1)
+            for layer in self.encoder:
+                x = layer(x)
+                this_batch.append(x)
+            for layer in self.decoder:
+                x = layer(x)
+                this_batch.append(x)
+            z = torch.cat(this_batch, dim=1)
+        return z
+
     def get_anomaly_score(self, df):
         """
         Returns a per-row loss of the input dataframe.
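For illustration, a minimal sketch of how the new single-record inference path might be called (not part of the diff; the feature names and values below are hypothetical and would have to match the columns the AutoEncoder was fit on):

    import json

    record = {'amount': 12.5, 'merchant': 'grocery', 'is_flagged': False}  # hypothetical features
    payload = json.dumps(record)

    # Accepts either a JSON string or a plain dict.
    z = model.get_deep_stack_features_json(payload)
    # z has shape (1, D): the concatenated activations of every encoder and decoder layer.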
@@ -957,3 +1042,11 @@ def df_predict(self, df):
         output_df = self.decode_to_df(x, df=df)

         return output_df
+
+    def save(self, path):
+        """
+        Saves serialized model to input path.
+        """
+        with open(path, 'wb') as f:
+            serialized_model = dill.dumps(self)
+            f.write(serialized_model)
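Likewise, a sketch of the serialization round trip that `save` and `load_model` enable (illustrative only; the file path and DataFrames are hypothetical, and the import assumes the package re-exports both names):

    from dfencoder import AutoEncoder, load_model

    model = AutoEncoder()             # default hyperparameters, for brevity
    model.fit(train_df)               # train_df: a pandas DataFrame of training records

    model.save('autoencoder.dill')    # dill.dumps the whole fitted object, scalers and embeddings included
    restored = load_model('autoencoder.dill')
    scores = restored.get_anomaly_score(test_df)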