Skip to content

Commit db8f094

Browse files
authored
feat: add a base model to the repository
1 parent e0b785c commit db8f094

File tree

6 files changed

+73
-32
lines changed

6 files changed

+73
-32
lines changed

model/ver20220624/model.json

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_1","layers":[{"class_name":"Dense","config":{"units":32,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true,"batch_input_shape":[null,512]}},{"class_name":"BatchNormalization","config":{"axis":-1,"momentum":0.99,"epsilon":0.001,"center":true,"scale":true,"beta_initializer":{"class_name":"Zeros","config":{}},"gamma_initializer":{"class_name":"Ones","config":{}},"moving_mean_initializer":{"class_name":"Zeros","config":{}},"moving_variance_initializer":{"class_name":"Ones","config":{}},"beta_regularizer":null,"gamma_regularizer":null,"beta_constraint":null,"gamma_constraint":null,"name":"batch_normalization_BatchNormalization1","trainable":true}},{"class_name":"Dense","config":{"units":32,"activation":"relu","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense2","trainable":true}},{"class_name":"BatchNormalization","config":{"axis":-1,"momentum":0.99,"epsilon":0.001,"center":true,"scale":true,"beta_initializer":{"class_name":"Zeros","config":{}},"gamma_initializer":{"class_name":"Ones","config":{}},"moving_mean_initializer":{"class_name":"Zeros","config":{}},"moving_variance_initializer":{"class_name":"Ones","config":{}},"beta_regularizer":null,"gamma_regularizer":null,"beta_constraint":null,"gamma_constraint":null,"name":"batch_normalization_BatchNormalization2","trainable":true}},{"class_name":"Dense","config":{"units":1,"activation":"sigmoid","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense3","trainable":true}}]},"keras_version":"tfjs-layers 3.18.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"dense_Dense1/kernel","shape":[512,32],"dtype":"float32"},{"name":"dense_Dense1/bias","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization1/gamma","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization1/beta","shape":[32],"dtype":"float32"},{"name":"dense_Dense2/kernel","shape":[32,32],"dtype":"float32"},{"name":"dense_Dense2/bias","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization2/gamma","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization2/beta","shape":[32],"dtype":"float32"},{"name":"dense_Dense3/kernel","shape":[32,1],"dtype":"float32"},{"name":"dense_Dense3/bias","shape":[1],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization1/moving_mean","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization1/moving_variance","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization2/moving_mean","shape":[32],"dtype":"float32"},{"name":"batch_normalization_BatchNormalization2/moving_variance","shape":[32],"dtype":"float32"},{"name":"iter","shape":[],"dtype":"int32","group":"optimizer"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v3.18.0","convertedBy":null,"trainingConfig":{"loss":{},"metrics":["binary_accuracy","precision","recall"],"optimizer_config":{"class_name":"Adam","config":{"learningRate":0.001,"beta1":0.9,"beta2":0.999,"epsilon":1e-7}}}}

model/ver20220624/weights.bin

69.4 KB
Binary file not shown.

trainer/model.ts

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import * as tf from "@tensorflow/tfjs-node";
2+
3+
export async function getModel(
4+
modelPath: string
5+
): Promise<tf.LayersModel | tf.Sequential> {
6+
try {
7+
console.info(`Trying to load a model from ${modelPath}`);
8+
return await tf.loadLayersModel(modelPath);
9+
} catch (e) {
10+
console.warn(`Unable to load a model. Creating a new model`);
11+
return tf.sequential({
12+
layers: [
13+
tf.layers.dense({
14+
inputDim: 512,
15+
units: 32,
16+
activation: "relu",
17+
}),
18+
tf.layers.batchNormalization(),
19+
tf.layers.dense({
20+
units: 32,
21+
activation: "relu",
22+
}),
23+
tf.layers.batchNormalization(),
24+
tf.layers.dense({
25+
units: 1,
26+
activation: "sigmoid",
27+
}),
28+
],
29+
});
30+
}
31+
}

trainer/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"description": "",
55
"main": "index.js",
66
"scripts": {
7-
"build": "npx ts-node trainer.ts"
7+
"start": "npx ts-node trainer.ts"
88
},
99
"keywords": [],
1010
"author": "",

trainer/trainer.ts

+33-24
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
import * as use from "@tensorflow-models/universal-sentence-encoder";
22
import * as tf from "@tensorflow/tfjs-node";
33

4+
import path from "path";
5+
import { getModel } from "./model";
6+
7+
const MODEL_PATH = `file://${path.join(
8+
__dirname,
9+
"..",
10+
"model",
11+
"ver20220624"
12+
)}`;
13+
414
async function main() {
515
const encoder = await use.load();
6-
const trainData = tf.data
16+
const trainData = await tf.data
717
.csv(
818
"https://raw.githubusercontent.com/smilegate-ai/korean_unsmile_dataset/main/unsmile_train_v1.0.tsv",
919
{
@@ -28,10 +38,11 @@ async function main() {
2838
ys: Object.values(data.ys),
2939
};
3040
})
41+
.prefetch(10000)
3142
.batch(32)
3243
.shuffle(32);
3344

34-
const valData = tf.data
45+
const valData = tf.data
3546
.csv(
3647
"https://raw.githubusercontent.com/smilegate-ai/korean_unsmile_dataset/main/unsmile_valid_v1.0.tsv",
3748
{
@@ -56,38 +67,36 @@ async function main() {
5667
ys: Object.values(data.ys),
5768
};
5869
})
70+
.prefetch(10000)
5971
.batch(32)
60-
.shuffle(32);
72+
.shuffle(32);
6173

62-
const model = tf.sequential({
63-
layers: [
64-
tf.layers.dense({
65-
inputDim: 512,
66-
units: 512,
67-
activation: "relu",
68-
}),
69-
tf.layers.batchNormalization(),
70-
tf.layers.dense({
71-
units: 512,
72-
activation: "relu",
73-
}),
74-
tf.layers.batchNormalization(),
75-
tf.layers.dense({
76-
units: 1,
77-
activation: "sigmoid",
78-
}),
79-
],
80-
});
74+
const model = await getModel(MODEL_PATH);
8175

8276
model.compile({
8377
optimizer: tf.train.adam(),
8478
loss: tf.losses.sigmoidCrossEntropy,
85-
metrics: [tf.metrics.binaryAccuracy],
79+
metrics: [
80+
tf.metrics.binaryAccuracy,
81+
tf.metrics.precision,
82+
tf.metrics.recall,
83+
],
8684
});
8785

86+
model.summary();
87+
8888
model.fitDataset(trainData, {
8989
epochs: 5,
90-
validationData: valData
90+
validationData: valData,
91+
callbacks: [
92+
tf.callbacks.earlyStopping({
93+
patience: 1,
94+
}),
95+
],
96+
});
97+
98+
model.save(MODEL_PATH, {
99+
includeOptimizer: true,
91100
});
92101
}
93102

trainer/tsconfig.json

+7-7
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
1212

1313
/* Language and Environment */
14-
"target": "es2016", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
14+
"target": "es2016" /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */,
1515
// "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
1616
// "jsx": "preserve", /* Specify what JSX code is generated. */
1717
// "experimentalDecorators": true, /* Enable experimental support for TC39 stage 2 draft decorators. */
@@ -25,7 +25,7 @@
2525
// "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
2626

2727
/* Modules */
28-
"module": "commonjs", /* Specify what module code is generated. */
28+
"module": "commonjs" /* Specify what module code is generated. */,
2929
// "rootDir": "./", /* Specify the root folder within your source files. */
3030
// "moduleResolution": "node", /* Specify how TypeScript looks up a file from a given module specifier. */
3131
// "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
@@ -57,7 +57,7 @@
5757
// "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
5858
// "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
5959
// "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
60-
// "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
60+
"inlineSourceMap": true /* Include sourcemap files inside the emitted JavaScript. */,
6161
// "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
6262
// "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
6363
// "newLine": "crlf", /* Set the newline character for emitting files. */
@@ -71,12 +71,12 @@
7171
/* Interop Constraints */
7272
// "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
7373
// "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
74-
"esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
74+
"esModuleInterop": true /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */,
7575
// "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
76-
"forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */
76+
"forceConsistentCasingInFileNames": true /* Ensure that casing is correct in imports. */,
7777

7878
/* Type Checking */
79-
"strict": true, /* Enable all strict type-checking options. */
79+
"strict": true /* Enable all strict type-checking options. */,
8080
// "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
8181
// "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
8282
// "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
@@ -98,6 +98,6 @@
9898

9999
/* Completeness */
100100
// "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
101-
"skipLibCheck": true /* Skip type checking all .d.ts files. */
101+
"skipLibCheck": true /* Skip type checking all .d.ts files. */
102102
}
103103
}

0 commit comments

Comments
 (0)