Run the example code below in MATLAB. It shows a real-time visualization of the agent as it learns.
% Windy gridworld solved with SARSA (cf. Sutton & Barto, Example 6.5).
% Set up grid
nrows = 7;
ncols = 10;
% Initial and terminal states, given as [row, column]
S0 = [4,1];
ST = [4,8];
% Wind strength per column (one entry per column, ncols in total); it pushes
% the agent toward row 1, the same direction as the "Up" action.
wind = [0 0 0 1 1 1 2 2 1 0];
actions = 1:4; % Up, down, left, right
num_actions = length(actions);
num_states = nrows * ncols;
% Initialize Q and set other parameters
Q = zeros(num_states,num_actions); % action-value table: one row per state, one column per action
epsilon = 0.1; % exploration probability of the epsilon-greedy policy
alpha = 0.5;   % step size of the TD update
gamma = 1;     % discount factor (1 = undiscounted episodic task)
% Run algorithm: SARSA (on-policy TD control) with an epsilon-greedy policy.
% Uses nrows, ncols, S0, ST, wind, num_actions, Q, epsilon, alpha and gamma
% from the setup section. Q is updated in place; performance(t) records the
% episode number that time step t belonged to, for the learning curve.
max_steps = 8000; % training budget in total time steps
e = 1;            % current episode number
t = 1;            % total time-step counter across all episodes
performance = zeros(max_steps,1);
cmap = [0 0 0; parula(64)]; % Colormap: never-visited cells (count 0) are drawn black
while t <= max_steps
    t0 = t; % time step at which this episode started
    % Choose initial state
    s1 = S0(1); % Grid row index
    s2 = S0(2); % Grid column index
    s = (s1-1)*ncols + s2; % State index (row-major)
    % Per-episode visit counts, drawn as a heat map in figure 1
    map = zeros(nrows,ncols);
    map(s1,s2) = map(s1,s2) + 1;
    figure(1); clf;
    han = imagesc(map); axis image; colormap(cmap);
    str = sprintf('Episode %3d, Step %4d, Total Steps %4d',e,t-t0+1,t);
    hant = title(str); drawnow;
    % Choose A from S with the epsilon-greedy policy derived from Q.
    % The probability vector is named p (not pi) so MATLAB's built-in pi is
    % not shadowed. max() breaks ties toward the lowest action index.
    [~,a] = max(Q(s,:));
    p = (epsilon/num_actions)*ones(num_actions,1);
    p(a) = 1 - epsilon + (epsilon/num_actions);
    c = cumsum(p);
    % Round-off can leave c(end) slightly below 1; since rand() < 1,
    % forcing c(end) = 1 guarantees find() never returns [].
    c(end) = 1;
    a = find(rand() <= c, 1);
    % Run episode. It ends only at the goal, so the last episode may run past
    % max_steps; performance then grows beyond its preallocated length.
    while true
        % Take action A, observe R and S'
        r = -1; % every step is penalized until the goal is reached
        switch a
            case 1, s1 = s1 - 1; % Up
            case 2, s1 = s1 + 1; % Down
            case 3, s2 = s2 - 1; % Left
            case 4, s2 = s2 + 1; % Right
        end
        % Keep the column on the grid
        s2 = min(ncols, max(1, s2));
        % Apply wind: shifts the agent toward row 1 (the "Up" direction).
        % NOTE(review): this uses the wind of the column moved *into*; many
        % windy-gridworld implementations use the column moved *from*.
        % Confirm which is intended.
        s1 = s1 - wind(s2);
        % Keep the row on the grid
        s1 = min(nrows, max(1, s1));
        sp = (s1-1)*ncols + s2;
        map(s1,s2) = map(s1,s2) + 1;
        str = sprintf('Episode %3d, Step %4d, Total Steps %4d',e,t-t0+1,t);
        % Update both graphics objects, then flush the display once per step
        set(han,'CData',map);
        set(hant,'String',str);
        drawnow;
        % Choose A' from S' with the same epsilon-greedy policy
        [~,ap] = max(Q(sp,:));
        p = (epsilon/num_actions)*ones(num_actions,1);
        p(ap) = 1 - epsilon + (epsilon/num_actions);
        c = cumsum(p);
        c(end) = 1;
        ap = find(rand() <= c, 1);
        % SARSA update: Q(S,A) <- Q(S,A) + alpha*[R + gamma*Q(S',A') - Q(S,A)].
        % No action is ever taken from the goal, so Q(goal,:) is never updated
        % and stays 0, as required for a terminal state.
        delta = r + gamma*Q(sp,ap) - Q(s,a);
        Q(s,a) = Q(s,a) + alpha*delta;
        % S <- S', A <- A'
        s = sp;
        a = ap;
        performance(t) = e;
        t = t + 1;
        % Termination: stop the episode once the goal state is reached
        if s1 == ST(1) && s2 == ST(2)
            break;
        end
    end
    e = e + 1;
end
% Learning curve: episode number (y) against cumulative time step (x).
% performance(t) holds the episode that step t belonged to, so a steeper
% slope means episodes are finishing in fewer steps.
figure(2);
clf;
plot(performance);
xlabel('Total Time Steps', 'FontSize', 16);
ylabel('Episodes', 'FontSize', 16);