 from mdp import *

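+# A 4x3 grid with the same layout and terminals as sequential_decision_environment,
+# but with a living reward of -0.1 in every non-terminal cell.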
+sequential_decision_environment_1 = GridMDP([[-0.1, -0.1, -0.1, +1],
+                                             [-0.1, None, -0.1, -1],
+                                             [-0.1, -0.1, -0.1, -0.1]],
+                                            terminals=[(3, 2), (3, 1)])
+
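+# The same 4x3 layout with a much harsher living reward of -2 per cell, so the
+# optimal behaviour is to reach any terminal as quickly as possible.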
+sequential_decision_environment_2 = GridMDP([[-2, -2, -2, +1],
+                                             [-2, None, -2, -1],
+                                             [-2, -2, -2, -2]],
+                                            terminals=[(3, 2), (3, 1)])
+
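+# A larger 6x5 grid with several blocked cells and four terminal states,
+# two with positive reward (1.0 and 3.0) and two with reward -1.0.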
+sequential_decision_environment_3 = GridMDP([[-1.0, -0.1, -0.1, -0.1, -0.1, 0.5],
+                                             [-0.1, None, None, -0.5, -0.1, -0.1],
+                                             [-0.1, None, 1.0, 3.0, None, -0.1],
+                                             [-0.1, -0.1, -0.1, None, None, -0.1],
+                                             [0.5, -0.1, -0.1, -0.1, -0.1, -1.0]],
+                                            terminals=[(2, 2), (3, 2), (0, 4), (5, 0)])

 def test_value_iteration():
     assert value_iteration(sequential_decision_environment, .01) == {
@@ -10,6 +26,30 @@ def test_value_iteration():
         (2, 0): 0.34461306281476806, (2, 1): 0.48643676237737926,
         (2, 2): 0.79536093684710951}

+    assert value_iteration(sequential_decision_environment_1, .01) == {
+        (3, 2): 1.0, (3, 1): -1.0,
+        (3, 0): -0.0897388258468311, (0, 1): 0.146419707398967840,
+        (0, 2): 0.30596200514385086, (1, 0): 0.010092796415625799,
+        (0, 0): 0.00633408092008296, (1, 2): 0.507390193380827400,
+        (2, 0): 0.15072242145212010, (2, 1): 0.358309043654212570,
+        (2, 2): 0.71675493618997840}
+
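+    # With a living reward of -2, every non-terminal utility is pushed well below zero.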
+    assert value_iteration(sequential_decision_environment_2, .01) == {
+        (3, 2): 1.0, (3, 1): -1.0,
+        (3, 0): -3.5141584808407855, (0, 1): -7.8000009574737180,
+        (0, 2): -6.1064293596058830, (1, 0): -7.1012549580376760,
+        (0, 0): -8.5872244532783200, (1, 2): -3.9653547121245810,
+        (2, 0): -5.3099468802901630, (2, 1): -3.3543366255753995,
+        (2, 2): -1.7383376462930498}
+
+    assert value_iteration(sequential_decision_environment_3, .01) == {
+        (0, 0): 4.350592130345558, (0, 1): 3.640700980321895, (0, 2): 3.0734806370346943, (0, 3): 2.5754335063434937, (0, 4): -1.0,
+        (1, 0): 3.640700980321895, (1, 1): 3.129579352304856, (1, 4): 2.0787517066719916,
+        (2, 0): 3.0259220379893352, (2, 1): 2.5926103577982897, (2, 2): 1.0, (2, 4): 2.507774181360808,
+        (3, 0): 2.5336747364500076, (3, 2): 3.0, (3, 3): 2.292172805400873, (3, 4): 2.996383110867515,
+        (4, 0): 2.1014575936349886, (4, 3): 3.1297590518608907, (4, 4): 3.6408806798779287,
+        (5, 0): -1.0, (5, 1): 2.5756132058995282, (5, 2): 3.0736603365907276, (5, 3): 3.6408806798779287, (5, 4): 4.350771829901593}
+

 def test_policy_iteration():
     assert policy_iteration(sequential_decision_environment) == {
@@ -18,6 +58,26 @@ def test_policy_iteration():
         (2, 1): (0, 1), (2, 2): (1, 0), (3, 0): (-1, 0),
         (3, 1): None, (3, 2): None}

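+    # Policies map each state to an action vector (dx, dy); terminal states map to None.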
+    assert policy_iteration(sequential_decision_environment_1) == {
+        (0, 0): (0, 1), (0, 1): (0, 1), (0, 2): (1, 0),
+        (1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (0, 1),
+        (2, 1): (0, 1), (2, 2): (1, 0), (3, 0): (-1, 0),
+        (3, 1): None, (3, 2): None}
+
+    assert policy_iteration(sequential_decision_environment_2) == {
+        (0, 0): (1, 0), (0, 1): (0, 1), (0, 2): (1, 0),
+        (1, 0): (1, 0), (1, 2): (1, 0), (2, 0): (1, 0),
+        (2, 1): (1, 0), (2, 2): (1, 0), (3, 0): (0, 1),
+        (3, 1): None, (3, 2): None}
+
+    assert policy_iteration(sequential_decision_environment_3) == {
+        (0, 0): (-1, 0), (0, 1): (0, -1), (0, 2): (0, -1), (0, 3): (0, -1), (0, 4): None,
+        (1, 0): (-1, 0), (1, 1): (-1, 0), (1, 4): (1, 0),
+        (2, 0): (-1, 0), (2, 1): (0, -1), (2, 2): None, (2, 4): (1, 0),
+        (3, 0): (-1, 0), (3, 2): None, (3, 3): (1, 0), (3, 4): (1, 0),
+        (4, 0): (-1, 0), (4, 3): (1, 0), (4, 4): (1, 0),
+        (5, 0): None, (5, 1): (0, 1), (5, 2): (0, 1), (5, 3): (0, 1), (5, 4): (1, 0)}
+

 def test_best_policy():
     pi = best_policy(sequential_decision_environment,
@@ -26,6 +86,26 @@ def test_best_policy():
                                                              ['^', None, '^', '.'],
                                                              ['^', '>', '^', '<']]

+    pi_1 = best_policy(sequential_decision_environment_1,
+                       value_iteration(sequential_decision_environment_1, .01))
+    assert sequential_decision_environment_1.to_arrows(pi_1) == [['>', '>', '>', '.'],
+                                                                 ['^', None, '^', '.'],
+                                                                 ['^', '>', '^', '<']]
+
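+    # With the -2 living reward, the policy races to the nearest exit: note the '>'
+    # at (2, 1), which accepts the -1 terminal rather than pay for a longer route.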
+    pi_2 = best_policy(sequential_decision_environment_2,
+                       value_iteration(sequential_decision_environment_2, .01))
+    assert sequential_decision_environment_2.to_arrows(pi_2) == [['>', '>', '>', '.'],
+                                                                 ['^', None, '>', '.'],
+                                                                 ['>', '>', '>', '^']]
+
+    pi_3 = best_policy(sequential_decision_environment_3,
+                       value_iteration(sequential_decision_environment_3, .01))
+    assert sequential_decision_environment_3.to_arrows(pi_3) == [['.', '>', '>', '>', '>', '>'],
+                                                                 ['v', None, None, '>', '>', '^'],
+                                                                 ['v', None, '.', '.', None, '^'],
+                                                                 ['v', '<', 'v', None, None, '^'],
+                                                                 ['<', '<', '<', '<', '<', '.']]
+

 def test_transition_model():
     transition_model = {