Version info:

```
2024-09-19 11:57:06.630942: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-19 11:57:14.731284: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
0.5.1 1.0.0a2 2.2.2+cu121 1.26.2 3.10.12 (main, Mar 22 2024, 16:50:05) [GCC 11.4.0] linux
```
Hello,

I am trying to set up an actor network that outputs `(mu, sigma), state`:
```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Actor(nn.Module):
    def __init__(self, state_shape, action_shape, max_action, min_action, device=device):
        super().__init__()
        self.device = device
        self.cnn = ThreeDCNN().to(self.device)  # Our 3D CNN (defined elsewhere)
        self.well_conv = nn.Conv1d(
            in_channels=state_shape['well_observations'][1],  # Number of features (15)
            out_channels=32,  # You can adjust this
            kernel_size=3,
            padding=1,
        ).to(self.device)

        # Calculate the size of the flattened 3D CNN output
        with torch.no_grad():
            sample_input = torch.randn(1, 2, 4, 163, 120).to(self.device)  # Example input shape
            cnn_output_size = self.cnn(sample_input).view(1, -1).size(1)
            sample_well_obs = torch.randn(1, *state_shape['well_observations']).to(self.device)
            well_cnn_output_size = self.well_conv(sample_well_obs.permute(0, 2, 1)).view(1, -1).size(1)

        # Combine the CNN output size with the sizes of the other observations
        combined_size = cnn_output_size + well_cnn_output_size
        self.fc1 = nn.Linear(combined_size, 128).to(self.device)
        self.fc2 = nn.Linear(128, 128).to(self.device)
        self.fc_mu = nn.Linear(128, action_shape[0]).to(self.device)
        self.fc_std = nn.Linear(128, action_shape[0]).to(self.device)
        nn.init.xavier_uniform_(self.fc_mu.weight)
        nn.init.xavier_uniform_(self.fc_std.weight)
        self.max_action = max_action
        self.min_action = min_action
        # self.sigma_param = nn.Parameter(torch.ones(action_shape[0], 1)) * 0.1

    def forward(self, obs, state=None, info={}):
        image = torch.from_numpy(obs['res_state']).to(self.device)
        image_features = self.cnn(image)
        well_obs = torch.from_numpy(obs['well_observations']).to(self.device)
        well_obs = well_obs.permute(0, 2, 1).float()  # [batch_size, num_features, num_wells]
        well_features = torch.relu(self.well_conv(well_obs))
        well_features = well_features.view(well_features.size(0), -1)

        # Flatten the CNN output and combine with the other observations
        # This should change
        combined_obs = torch.cat([image_features, well_features], dim=1)
        x = torch.relu(self.fc1(combined_obs))
        logits = torch.relu(self.fc2(x))

        # Compute the action mean (mu) and standard deviation (sigma)
        mu = self.fc_mu(logits)
        sigma = self.fc_std(logits)

        max_action_tensor = torch.from_numpy(self.max_action).to(self.device).float()
        min_action_tensor = torch.from_numpy(self.min_action).to(self.device).float()
        action_range = (max_action_tensor - min_action_tensor) / 2
        action_midpoint = (max_action_tensor + min_action_tensor) / 2
        print(max_action_tensor, min_action_tensor)
        print(max_action_tensor.shape, min_action_tensor.shape)
        print(action_range, action_midpoint)

        # Bound the action mean with tanh and rescale into [min_action, max_action]
        mu = torch.tanh(mu) * action_range + action_midpoint
        # Clamp the log-std, then exponentiate so sigma is always positive
        sigma = torch.clamp(sigma, min=-3, max=1).exp()
        print(mu, sigma)

        # Return the action mean and standard deviation
        return (mu, sigma), state
```
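For context, here is a minimal sketch of how I exercise the actor in isolation. The `ThreeDCNN` stub, the observation shapes, and the action bounds below are all made-up placeholders standing in for my real setup, only so the sketch runs end to end:

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for the real ThreeDCNN, only to make this runnable
class ThreeDCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(2, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool3d((1, 4, 4))

    def forward(self, x):
        # (B, 2, 4, 163, 120) -> (B, 8, 1, 4, 4) -> (B, 128)
        return self.pool(torch.relu(self.conv(x))).flatten(start_dim=1)

state_shape = {'well_observations': (10, 15)}  # (num_wells, num_features) -- assumed
action_shape = (4,)                            # assumed action dimension
max_action = np.ones(4, dtype=np.float32)
min_action = -np.ones(4, dtype=np.float32)

actor = Actor(state_shape, action_shape, max_action, min_action,
              device=torch.device('cpu'))
obs = {
    'res_state': np.random.randn(1, 2, 4, 163, 120).astype(np.float32),
    'well_observations': np.random.randn(1, 10, 15).astype(np.float32),
}
(mu, sigma), _ = actor(obs)
print(mu.shape, sigma.shape)  # both torch.Size([1, 4])
```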
My issue is that the outputs of the actor do not match the received actions at all, for example:
If those values are given directly to the torch `Normal` distribution, the output is fine.
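This is the kind of direct check I mean (the mu/sigma values here are made up, not my actual outputs):

```python
import torch
from torch.distributions import Normal

# Made-up values standing in for the actual printed mu / sigma
mu = torch.tensor([[0.5, -0.2]])
sigma = torch.tensor([[0.05, 0.05]])

action = Normal(mu, sigma).sample()
print(action)  # stays close to mu, as expected

# Note: a Normal sample is unbounded, so even with a tanh-bounded mu the raw
# sample can leave [min_action, max_action]; Tianshou's BasePolicy.map_action()
# may also rescale/clip the sampled action when action_scaling is enabled,
# which could make the received actions differ from the raw (mu, sigma).
```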
These examples are all from the very first stage, before any training, when the network initially evaluates the test envs.
Here is the code with the training parameters I selected: