Commit 405b9e2 (20220513)

sun1638650145 committed May 13, 2022 · 1 parent 0c98181
Showing 2 changed files with 85 additions and 5 deletions.

API.md (30 additions & 4 deletions)

```python
env.action_space
```

#### 4.1.1.1.n

The total number of possible actions.|`int`

```python
env.action_space.n
```
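For a discrete action space, `n` means the valid actions are the integers `0` through `n - 1`, which is the range `sample()` draws from. A minimal sketch with a stand-in for a discrete space (the `ToyDiscrete` class and `n = 4` are illustrative, not part of Gym):

```python
import random

class ToyDiscrete:
    """Stand-in for a discrete action space with actions 0..n-1."""
    def __init__(self, n):
        self.n = n

    def sample(self):
        return random.randrange(self.n)

space = ToyDiscrete(4)
samples = [space.sample() for _ in range(100)]
# Every sampled action is a valid id in [0, n).
assert all(0 <= a < space.n for a in samples)
```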

#### 4.1.1.2.sample()

Sample a random action.|`int`

```python
action = env.action_space.sample()
```

### 4.1.2.close()

Close the environment.

```python
env.close()
```

### 4.1.3.observation_space

#### 4.1.3.1.sample()

Sample a random observation vector.|`numpy.ndarray`

```python
env.observation_space.sample()
```

#### 4.1.3.2.shape

The shape of the observation vector.|`tuple`

```python
env.observation_space.shape
```
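`shape` is typically what sizes the input layer of a policy or value network; flattening it gives the input dimension. A small sketch (the `(8,)` shape mirrors LunarLander-v2's 8-dimensional observation and is used here only as an assumed example):

```python
import math

shape = (8,)                  # assumed observation_space.shape
input_dim = math.prod(shape)  # flatten to a single input dimension
assert input_dim == 8
```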

### 4.1.4.render()

Render the environment.

```python
env.render()
```

### 4.1.5.reset()

Reset the environment.|`numpy.ndarray`

```python
observation = env.reset()
```

### 4.1.6.step()

Execute an action.|`numpy.ndarray`, `int`, `bool`, `dict`

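`reset()`, `step()`, `sample()`, and `close()` compose into the standard interaction loop. A minimal sketch of that control flow, using a toy stand-in environment (`ToyEnv` is illustrative, not a Gym class) so it runs without Gym installed:

```python
import random

class ToyEnv:
    """Tiny stand-in mimicking the Gym interface used above."""
    def __init__(self, episode_length=5):
        self.episode_length = episode_length
        self._t = 0

    def reset(self):
        self._t = 0
        return [0.0]  # initial observation

    def step(self, action):
        self._t += 1
        observation = [float(self._t)]
        reward = 1.0
        done = self._t >= self.episode_length
        return observation, reward, done, {}

    def close(self):
        pass

env = ToyEnv()
observation = env.reset()
total_reward, done = 0.0, False
while not done:
    action = random.choice([0, 1])  # stands in for env.action_space.sample()
    observation, reward, done, info = env.step(action)
    total_reward += reward
env.close()
assert total_reward == 5.0
```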
PyTorch.md (55 additions & 1 deletion)

```python
import numpy as np
from torchvision.transforms import ToTensor

arr = np.asarray([[1, 2, 3]])
tensor = ToTensor()(pic=arr)  # PIL Image or numpy.ndarray|The image to convert.
```

# 3.stable_baselines3

| Version | Description | Notes | M1 support |
| ------- | ------------------------------------------------------ | ----- | ---------- |
| 1.5.0 | Stable Baselines, reinforcement learning for Torch. | - | |

## 3.1.common

### 3.1.1.env_util

#### 3.1.1.1.make_vec_env()

Create a set of parallel environments.|`stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv`

```python
from stable_baselines3.common.env_util import make_vec_env

envs = make_vec_env(env_id='LunarLander-v2', # str|Environment id.
                    n_envs=16) # int|1|Number of parallel environments.
```
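A vectorized environment batches the per-environment API: `reset()` returns one observation per sub-environment. The idea, sketched with toy components (not the real `DummyVecEnv`):

```python
class ToyEnv:
    def reset(self):
        return [0.0, 0.0]  # one observation

class ToyVecEnv:
    """Illustrative: holds n environments and stacks their results."""
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        self.num_envs = len(self.envs)

    def reset(self):
        return [env.reset() for env in self.envs]

vec = ToyVecEnv([ToyEnv for _ in range(16)])
obs = vec.reset()
assert len(obs) == vec.num_envs == 16
```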

## 3.2.PPO()

Instantiate the Proximal Policy Optimization (PPO) algorithm.

```python
from stable_baselines3 import PPO

model = PPO(policy='MlpPolicy', # {'MlpPolicy', 'CnnPolicy'}|The policy to use.
            env=envs, # gym.Env|The Gym environment.
            n_steps=1024, # int|2048|Number of timesteps per environment per update.
            batch_size=64, # int|64|Batch size.
            n_epochs=4, # int|10|Number of epochs when optimizing the surrogate loss.
            gamma=0.999, # float|0.99|Discount factor.
            gae_lambda=0.98, # float|0.95|Bias-variance trade-off factor for the generalized advantage estimator.
            ent_coef=0.01, # float|0.0|Entropy coefficient in the loss.
            verbose=1) # {0, 1, 2}|0|Logging verbosity.
```
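Among these hyperparameters, `gamma` discounts future rewards when computing returns: G_t = r_t + γ·r_{t+1} + γ²·r_{t+2} + …. A quick check of the discounted return on a short reward sequence:

```python
def discounted_return(rewards, gamma):
    """Compute G_0 = sum_k gamma**k * r_k, accumulating from the end."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

rewards = [1.0, 1.0, 1.0]
assert abs(discounted_return(rewards, 1.0) - 3.0) < 1e-9   # undiscounted sum
assert abs(discounted_return(rewards, 0.5) - 1.75) < 1e-9  # 1 + 0.5 + 0.25
```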

### 3.2.1.learn()

Train the model.

```python
model.learn(total_timesteps=200000) # int|Number of training timesteps.
```
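`total_timesteps` counts transitions across all parallel environments. With the settings from the examples above (assumed: 16 environments, `n_steps=1024`), each rollout gathers 16 × 1024 = 16384 transitions, so 200000 timesteps amount to roughly 13 policy updates:

```python
n_envs, n_steps, total_timesteps = 16, 1024, 200_000
per_rollout = n_envs * n_steps                 # transitions gathered per update
updates = -(-total_timesteps // per_rollout)   # ceiling division
assert per_rollout == 16384
assert updates == 13
```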

### 3.2.2.save()

Save the model to a zip file.

```python
model.save(path='./ppo-LunarLander-v2') # str|Filename.
```
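In stable_baselines3 the saved archive can later be restored with `PPO.load(path)`. The round-trip idea, sketched with a stub model and the stdlib `zipfile` (the `ToyModel` class and its JSON format are illustrative, not the library's actual archive layout):

```python
import io
import json
import zipfile

class ToyModel:
    """Stub illustrating save-to-zip / load-from-zip round-tripping."""
    def __init__(self, params):
        self.params = params

    def save(self, path):
        with zipfile.ZipFile(path, "w") as zf:
            zf.writestr("params.json", json.dumps(self.params))

    @classmethod
    def load(cls, path):
        with zipfile.ZipFile(path) as zf:
            return cls(json.loads(zf.read("params.json")))

buf = io.BytesIO()  # in-memory stand-in for a file on disk
model = ToyModel({"lr": 0.0003})
model.save(buf)
restored = ToyModel.load(buf)
assert restored.params == {"lr": 0.0003}
```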
