-
Notifications
You must be signed in to change notification settings - Fork 4
/
space.yaml
98 lines (72 loc) · 1.84 KB
/
space.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# @package _global_

# Configs this file is composed on top of; `_base` is defined elsewhere.
defaults:
  - _base

# NOTE(review): the trailing `# 12` presumably records an alternative or
# reference value -- confirm before relying on it.
batch_size: 32  # 12
# Trainer settings: total step count, per-quantity schedules, boundary-loss
# switches, gradient clipping and per-module optimizers.
#
# NOTE(review): the nesting below was reconstructed from key names and the
# section comments (the source had no indentation, which left `fg`/`bg` null
# and produced duplicate `_target_`/`alg`/`lr` keys). In particular, confirm
# that `clip_grad_norm` and `optimizer_config` belong to `trainer`.
# NOTE(review): trailing `# value` comments presumably record alternative or
# reference values -- confirm before relying on them.
trainer:
  _target_: models.space.trainer.SPACETrainer

  steps: 200_000

  # Each `*_start_step` / `*_end_step` / `*_start_value` / `*_end_value` group
  # below presumably describes a schedule that moves from the start value to
  # the end value between the two steps -- confirm against the trainer code.

  # z_pres prior
  z_pres_start_step: 4000
  z_pres_end_step: 10000
  z_pres_start_value: 0.5  # 0.1
  z_pres_end_value: 0.05  # 0.01

  # z_scale prior
  z_scale_mean_start_step: 10_000
  z_scale_mean_end_step: 20_000
  z_scale_mean_start_value: 0.0  # -1.0
  z_scale_mean_end_value: -1.0  # -2.0

  # Temperature for gumbel-softmax
  tau_start_step: 0
  tau_end_step: 50_000  # 20_000
  tau_start_value: 2.5
  tau_end_value: 0.5

  # Turn on boundary loss or not
  boundary_loss: true
  # When to turn off boundary loss
  bl_off_step: 20_000  # 100_000

  # Fix alpha for the first N steps
  # (N is `fix_alpha_steps`; with 0 presumably alpha is never fixed -- confirm)
  fix_alpha_steps: 0
  fix_alpha_value: 0.1

  clip_grad_norm: 1.0

  # One optimizer per module group. Learning rates are written with an
  # explicit decimal point (`3.0e-5`, not `3e-5`) so that YAML 1.1 loaders
  # also read them as floats rather than strings; the values are unchanged.
  optimizer_config:
    fg:
      alg: RMSprop
      lr: 3.0e-5  # 1e-5
    bg:
      alg: Adam
      lr: 1.0e-3
# Model settings: target class, slot count, and foreground/background
# hyperparameters.
#
# NOTE(review): the nesting below was reconstructed from key names and the
# section comments (the source had no indentation, which left `fg_params` and
# `bg_params` null and flattened their children to the top level). Verify
# against the model's constructor.
model:
  _target_: models.space.model.SPACE
  name: space

  # The number of slots is G*G+K and must be set at runtime because OmegaConf
  # doesn't support operators. With the values below that is 8*8 + 5 = 69.
  num_slots: null

  fg_params:
    # Grid size. There will be G*G slots
    G: 8  # 4
    # Foreground likelihood sigma
    fg_sigma: 0.15
    glimpse_size: 32
    # Encoded image feature channels
    img_enc_dim_fg: 128
    # Latent dimensions
    z_pres_dim: 1
    z_depth_dim: 1
    z_where_scale_dim: 2
    z_where_shift_dim: 2
    z_what_dim: 32
    # NOTE(review): placed under `fg_params` because it directly follows the
    # foreground latents in the source -- confirm it is not a `model`-level key.
    z_scale_std_value: 0.1

  bg_params:
    # Number of background components
    K: 5
    # Background likelihood sigma
    bg_sigma: 0.15
    # Image encoding dimension
    img_enc_dim_bg: 64
    # Latent dimensions
    z_mask_dim: 32
    z_comp_dim: 32
    # These should be the same
    rnn_mask_hidden_dim: 64
    rnn_mask_prior_hidden_dim: 64
    # Hidden layer dim for the network that computes q(z_c|z_m, x)
    predict_comp_hidden_dim: 64