diff --git a/DP/Policy Evaluation Solution.ipynb b/DP/Policy Evaluation Solution.ipynb
index a8b949367..a4399cd77 100644
--- a/DP/Policy Evaluation Solution.ipynb
+++ b/DP/Policy Evaluation Solution.ipynb
@@ -2,10 +2,8 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 53,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 1,
+   "metadata": {},
    "outputs": [],
    "source": [
     "import numpy as np\n",
@@ -18,7 +16,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 54,
+   "execution_count": 2,
    "metadata": {
     "collapsed": true
    },
@@ -30,7 +28,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 55,
+   "execution_count": 3,
    "metadata": {
     "collapsed": true
    },
@@ -43,9 +41,9 @@
     "    Args:\n",
     "        policy: [S, A] shaped matrix representing the policy.\n",
     "        env: OpenAI env. env.P represents the transition probabilities of the environment.\n",
-    "            env.P[s][a] is a (prob, next_state, reward, done) tuple.\n",
+    "            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).\n",
     "        theta: We stop evaluation once our value function change is less than theta for all states.\n",
-    "        discount_factor: lambda discount factor.\n",
+    "        discount_factor: gamma discount factor.\n",
     "    \n",
     "    Returns:\n",
     "        Vector of length env.nS representing the value function.\n",
@@ -54,18 +52,23 @@
     "    V = np.zeros(env.nS)\n",
     "    while True:\n",
     "        delta = 0\n",
+    "        new_V = np.copy(V) # new_V is V_{k+1} \n",
     "        # For each state, perform a \"full backup\"\n",
     "        for s in range(env.nS):\n",
     "            v = 0\n",
     "            # Look at the possible next actions\n",
     "            for a, action_prob in enumerate(policy[s]):\n",
     "                # For each action, look at the possible next states...\n",
+    "                immediate_reward = env.P[s][a][0][2]\n",
+    "                v_prim = 0\n",
     "                for prob, next_state, reward, done in env.P[s][a]:\n",
     "                    # Calculate the expected value\n",
-    "                    v += action_prob * prob * (reward + discount_factor * V[next_state])\n",
+    "                    v_prim += discount_factor * prob * V[next_state]\n",
+    "                v += action_prob * (immediate_reward + v_prim)\n",
+    "            new_V[s] = v\n",
     "            # How much our value function changed (across any states)\n",
     "            delta = max(delta, np.abs(v - V[s]))\n",
-    "            V[s] = v\n",
+    "        V = new_V \n",
     "        # Stop evaluating once our value function change is below a threshold\n",
     "        if delta < theta:\n",
     "            break\n",
@@ -74,10 +77,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 56,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 4,
+   "metadata": {},
    "outputs": [],
    "source": [
     "random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
@@ -86,25 +87,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 57,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 5,
+   "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Value Function:\n",
-      "[ 0. -13.99993529 -19.99990698 -21.99989761 -13.99993529\n",
-      " -17.9999206 -19.99991379 -19.99991477 -19.99990698 -19.99991379\n",
-      " -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569 0. ]\n",
+      "[ 0. -13.99989315 -19.99984167 -21.99982282 -13.99989315\n",
+      " -17.99986052 -19.99984273 -19.99984167 -19.99984167 -19.99984273\n",
+      " -17.99986052 -13.99989315 -21.99982282 -19.99984167 -13.99989315 0. ]\n",
       "\n",
       "Reshaped Grid Value Function:\n",
-      "[[ 0. -13.99993529 -19.99990698 -21.99989761]\n",
-      " [-13.99993529 -17.9999206 -19.99991379 -19.99991477]\n",
-      " [-19.99990698 -19.99991379 -17.99992725 -13.99994569]\n",
-      " [-21.99989761 -19.99991477 -13.99994569 0. ]]\n",
+      "[[ 0. -13.99989315 -19.99984167 -21.99982282]\n",
+      " [-13.99989315 -17.99986052 -19.99984273 -19.99984167]\n",
+      " [-19.99984167 -19.99984273 -17.99986052 -13.99989315]\n",
+      " [-21.99982282 -19.99984167 -13.99989315 0. ]]\n",
       "\n"
      ]
     }
@@ -121,10 +120,8 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 51,
-   "metadata": {
-    "collapsed": false
-   },
+   "execution_count": 6,
+   "metadata": {},
    "outputs": [],
    "source": [
     "# Test: Make sure the evaluated policy is what we expected\n",
@@ -158,9 +155,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.5.1"
+   "version": "3.6.0"
   }
  },
 "nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 1
 }
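
For reference, a sketch of how the evaluation loop in policy_eval reads once the hunks above are applied, assembled from the + and context lines of the diff; the def line and its default arguments are not part of these hunks and are taken from the notebook, so treat them as assumed. Note that immediate_reward = env.P[s][a][0][2] reads the reward of the first listed transition only, which is safe as long as the reward does not vary across the transitions in env.P[s][a] (true for the deterministic GridworldEnv this notebook uses).

# Sketch only, not part of the patch: the patched cell's update loop with
# explanatory comments. Assumes the notebook's GridworldEnv-style env
# exposing env.nS and env.P[s][a] = [(prob, next_state, reward, done), ...].
import numpy as np

def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    # V holds the previous sweep's estimate V_k
    V = np.zeros(env.nS)
    while True:
        delta = 0
        new_V = np.copy(V)  # new_V accumulates V_{k+1}
        # For each state, perform a "full backup"
        for s in range(env.nS):
            v = 0
            # Look at the possible next actions
            for a, action_prob in enumerate(policy[s]):
                # Reward of the first listed transition; assumes it is the same
                # for every transition under (s, a), as in this gridworld
                immediate_reward = env.P[s][a][0][2]
                v_prim = 0
                # Discounted expected value of successor states, read from V_k
                for prob, next_state, reward, done in env.P[s][a]:
                    v_prim += discount_factor * prob * V[next_state]
                v += action_prob * (immediate_reward + v_prim)
            new_V[s] = v
            # How much our value function changed (across any states)
            delta = max(delta, np.abs(v - V[s]))
        # Commit the synchronous sweep: V_{k+1} replaces V_k only after all states
        V = new_V
        # Stop evaluating once our value function change is below a threshold
        if delta < theta:
            break
    return np.array(V)

Called as in the unchanged cell below it, policy_eval(random_policy, env), this synchronous sweep converges to the same -14/-18/-20/-22 gridworld values as the old in-place version, which is why only the low-order digits change in the output hunk above.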