import os
import h5py
import torch
import random
import numpy as np
from gym.spaces import Box, Discrete, Tuple

from envs import get_dim
from replay_buffer import ReplayBuffer


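# NOTE (assumption, not part of the original file): the ReplayBuffer base class
# is expected to provide the storage arrays (state, action, next_state, reward,
# done), the counters (ptr, size, max_size) and the sampling device; the
# subclass below re-initialises all of them so the offline data sits at the
# front of the buffer and online transitions are written after it.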
class MixedReplayBuffer(ReplayBuffer):
    """Replay buffer that pre-loads a random subset of an offline D4RL dataset
    and stores online transitions after it, so sampled batches can mix both."""

    def __init__(self, reward_scale, reward_bias, clip_action, state_dim, action_dim,
                 task="halfcheetah", data_source="medium_replay", device="cuda",
                 scale_rewards=True, scale_state=False, buffer_ratio=1, residual_ratio=0.1):
        super().__init__(state_dim, action_dim, device=device)

        self.scale_rewards = scale_rewards
        self.scale_state = scale_state
        self.buffer_ratio = buffer_ratio
        self.residual_ratio = residual_ratio

        # load a random subset of the offline D4RL dataset into the replay buffer
        path = os.path.join("../../d4rl_mujoco_dataset", "{}_{}-v2.hdf5".format(task, data_source))
        with h5py.File(path, "r") as dataset:
            total_num = dataset['observations'].shape[0]
            # idx = random.sample(range(total_num), int(total_num * self.residual_ratio))
            idx = np.random.choice(range(total_num), int(total_num * self.residual_ratio), replace=False)
            # np.vstack leaves (N, dim) arrays unchanged and reshapes 1-D arrays to (N, 1)
            s = np.vstack(np.array(dataset['observations'])).astype(np.float32)[idx, :]        # (N, dim_observation) observations
            a = np.vstack(np.array(dataset['actions'])).astype(np.float32)[idx, :]             # (N, dim_action) actions
            r = np.vstack(np.array(dataset['rewards'])).astype(np.float32)[idx, :]             # (N, 1) rewards
            s_ = np.vstack(np.array(dataset['next_observations'])).astype(np.float32)[idx, :]  # (N, dim_observation) next observations
            done = np.vstack(np.array(dataset['terminals']))[idx, :]                           # (N, 1) terminal flags

        # scale and shift the rewards
        r = r * reward_scale + reward_bias
        # clip actions to the valid range
        a = np.clip(a, -clip_action, clip_action)

        # the offline portion stays fixed; online transitions are written after it
        fixed_dataset_size = r.shape[0]
        self.fixed_dataset_size = fixed_dataset_size
        self.ptr = fixed_dataset_size
        self.size = fixed_dataset_size
        self.max_size = (self.buffer_ratio + 1) * fixed_dataset_size

        # pre-allocate the full buffer: offline data first, zero padding reserved for online data
        self.state = np.vstack((s, np.zeros((self.max_size - self.fixed_dataset_size, state_dim))))
        self.action = np.vstack((a, np.zeros((self.max_size - self.fixed_dataset_size, action_dim))))
        self.next_state = np.vstack((s_, np.zeros((self.max_size - self.fixed_dataset_size, state_dim))))
        self.reward = np.vstack((r, np.zeros((self.max_size - self.fixed_dataset_size, 1))))
        self.done = np.vstack((done, np.zeros((self.max_size - self.fixed_dataset_size, 1))))
        self.device = torch.device(device)

        # state normalization
        self.normalize_states()

    def normalize_states(self, eps=1e-3):
        # standard normalization statistics, computed over the filled portion of the buffer
        # (the zero-padded rows reserved for future online data would otherwise bias them)
        self.state_mean = self.state[:self.size].mean(0, keepdims=True)
        self.state_std = self.state[:self.size].std(0, keepdims=True) + eps
        if self.scale_state:
            self.state = (self.state - self.state_mean) / self.state_std
            self.next_state = (self.next_state - self.state_mean) / self.state_std

    def append(self, s, a, r, s_, done):
        self.state[self.ptr] = s
        self.action[self.ptr] = a
        self.next_state[self.ptr] = s_
        self.reward[self.ptr] = r
        self.done[self.ptr] = done

        # keep the offline portion fixed and cycle the write pointer over the simulated portion only
        self.ptr = (self.ptr + 1 - self.fixed_dataset_size) % (self.max_size - self.fixed_dataset_size) + self.fixed_dataset_size
        self.size = min(self.size + 1, self.max_size)

    def append_traj(self, observations, actions, rewards, next_observations, dones):
        for o, a, r, no, d in zip(observations, actions, rewards, next_observations, dones):
            self.append(o, a, r, no, d)

    def sample(self, batch_size, scope=None, type=None):
        # scope: None = mixed, "real" = offline (D4RL) data only, "sim" = online data only
        if scope is None:
            ind = np.random.randint(0, self.size, size=batch_size)
        elif scope == "real":
            ind = np.random.randint(0, self.fixed_dataset_size, size=batch_size)
        elif scope == "sim":
            ind = np.random.randint(self.fixed_dataset_size, self.size, size=batch_size)
        else:
            raise RuntimeError("Misspecified range for replay buffer sampling")

        if type is None:
            return {
                'observations': torch.FloatTensor(self.state[ind]).to(self.device),
                'actions': torch.FloatTensor(self.action[ind]).to(self.device),
                'rewards': torch.FloatTensor(self.reward[ind]).to(self.device),
                'next_observations': torch.FloatTensor(self.next_state[ind]).to(self.device),
                'dones': torch.FloatTensor(self.done[ind]).to(self.device)
            }
        elif type == "sas":
            return {
                'observations': torch.FloatTensor(self.state[ind]).to(self.device),
                'actions': torch.FloatTensor(self.action[ind]).to(self.device),
                'next_observations': torch.FloatTensor(self.next_state[ind]).to(self.device)
            }
        elif type == "sa":
            return {
                'observations': torch.FloatTensor(self.state[ind]).to(self.device),
                'actions': torch.FloatTensor(self.action[ind]).to(self.device)
            }
        else:
            raise RuntimeError("Misspecified return data types for replay buffer sampling")

    def get_mean_std(self):
        return torch.FloatTensor(self.state_mean).to(self.device), torch.FloatTensor(self.state_std).to(self.device)
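

if __name__ == "__main__":
    # Minimal usage sketch (an illustration, not part of the original training code).
    # It assumes the HalfCheetah medium-replay HDF5 file exists at the relative path
    # used in __init__; the dimensions (17 observations, 6 actions) match HalfCheetah-v2.
    buffer = MixedReplayBuffer(
        reward_scale=1.0, reward_bias=0.0, clip_action=1.0,
        state_dim=17, action_dim=6,
        task="halfcheetah", data_source="medium_replay", device="cpu",
    )

    # push a few fake online transitions so that the "sim" scope is non-empty
    for _ in range(10):
        s = np.random.randn(17).astype(np.float32)
        a = np.random.uniform(-1.0, 1.0, 6).astype(np.float32)
        s_ = np.random.randn(17).astype(np.float32)
        buffer.append(s, a, 0.0, s_, False)

    mixed = buffer.sample(32)                  # offline + online transitions
    offline = buffer.sample(32, scope="real")  # offline (D4RL) transitions only
    online = buffer.sample(32, scope="sim")    # online transitions only
    sa_only = buffer.sample(32, type="sa")     # only (state, action) pairs
    print({k: v.shape for k, v in mixed.items()})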