11from  mdp  import  * 
22
# 4x3 grid world (AIMA Ch. 17 layout) with a mildly negative living reward
# of -0.1 per step: the agent is nudged toward the +1 terminal at (3, 2)
# but a short detour is still affordable.  None marks an impassable cell;
# terminals are (x, y) grid coordinates (origin at bottom-left).
sequential_decision_environment_1 = GridMDP([[-0.1, -0.1, -0.1, +1],
                                             [-0.1, None, -0.1, -1],
                                             [-0.1, -0.1, -0.1, -0.1]],
                                            terminals=[(3, 2), (3, 1)])
7+ 
# Same 4x3 grid as environment 1 but with a strongly negative living reward
# of -2 per step: staying alive is so costly that the optimal policy rushes
# to the nearest exit, even risking the -1 terminal.  None marks an
# impassable cell; terminals are (x, y) grid coordinates.
sequential_decision_environment_2 = GridMDP([[-2, -2, -2, +1],
                                             [-2, None, -2, -1],
                                             [-2, -2, -2, -2]],
                                            terminals=[(3, 2), (3, 1)])
12+ 
# Larger 6x5 grid with mixed rewards: two negative terminals (-1.0 at the
# top-left and bottom-right corners), two positive terminals (+1.0 and +3.0
# in the middle row), and several walls (None).  Exercises value/policy
# iteration on a maze-like layout with competing goals.
sequential_decision_environment_3 = GridMDP([[-1.0, -0.1, -0.1, -0.1, -0.1, 0.5],
                                             [-0.1, None, None, -0.5, -0.1, -0.1],
                                             [-0.1, None, 1.0, 3.0, None, -0.1],
                                             [-0.1, -0.1, -0.1, None, None, -0.1],
                                             [0.5, -0.1, -0.1, -0.1, -0.1, -1.0]],
                                            terminals=[(2, 2), (3, 2), (0, 4), (5, 0)])
319
420def  test_value_iteration ():
521    assert  value_iteration (sequential_decision_environment , .01 ) ==  {
@@ -10,6 +26,30 @@ def test_value_iteration():
1026        (2 , 0 ): 0.34461306281476806 , (2 , 1 ): 0.48643676237737926 ,
1127        (2 , 2 ): 0.79536093684710951 }
1228
29+     assert  value_iteration (sequential_decision_environment_1 , .01 ) ==  {
30+         (3 , 2 ): 1.0 , (3 , 1 ): - 1.0 ,  
31+         (3 , 0 ): - 0.0897388258468311 , (0 , 1 ): 0.146419707398967840 , 
32+         (0 , 2 ): 0.30596200514385086 , (1 , 0 ): 0.010092796415625799 ,
33+         (0 , 0 ): 0.00633408092008296 , (1 , 2 ): 0.507390193380827400 , 
34+         (2 , 0 ): 0.15072242145212010 , (2 , 1 ): 0.358309043654212570 , 
35+         (2 , 2 ): 0.71675493618997840 }
36+ 
37+     assert  value_iteration (sequential_decision_environment_2 , .01 ) ==  {
38+         (3 , 2 ): 1.0 , (3 , 1 ): - 1.0 , 
39+         (3 , 0 ): - 3.5141584808407855 , (0 , 1 ): - 7.8000009574737180 ,
40+         (0 , 2 ): - 6.1064293596058830 , (1 , 0 ): - 7.1012549580376760 ,
41+         (0 , 0 ): - 8.5872244532783200 , (1 , 2 ): - 3.9653547121245810 ,
42+         (2 , 0 ): - 5.3099468802901630 , (2 , 1 ): - 3.3543366255753995 ,
43+         (2 , 2 ): - 1.7383376462930498 }
44+ 
45+     assert  value_iteration (sequential_decision_environment_3 , .01 ) ==  {
46+         (0 , 0 ): 4.350592130345558 , (0 , 1 ): 3.640700980321895 , (0 , 2 ): 3.0734806370346943 , (0 , 3 ): 2.5754335063434937 , (0 , 4 ): - 1.0 ,
47+         (1 , 0 ): 3.640700980321895 , (1 , 1 ): 3.129579352304856 , (1 , 4 ): 2.0787517066719916 ,
48+         (2 , 0 ): 3.0259220379893352 , (2 , 1 ): 2.5926103577982897 , (2 , 2 ): 1.0 , (2 , 4 ): 2.507774181360808 ,
49+         (3 , 0 ): 2.5336747364500076 , (3 , 2 ): 3.0 , (3 , 3 ): 2.292172805400873 , (3 , 4 ): 2.996383110867515 ,
50+         (4 , 0 ): 2.1014575936349886 , (4 , 3 ): 3.1297590518608907 , (4 , 4 ): 3.6408806798779287 ,
51+         (5 , 0 ): - 1.0 , (5 , 1 ): 2.5756132058995282 , (5 , 2 ): 3.0736603365907276 , (5 , 3 ): 3.6408806798779287 , (5 , 4 ): 4.350771829901593 }
52+ 
1353
1454def  test_policy_iteration ():
1555    assert  policy_iteration (sequential_decision_environment ) ==  {
@@ -18,6 +58,26 @@ def test_policy_iteration():
1858        (2 , 1 ): (0 , 1 ), (2 , 2 ): (1 , 0 ), (3 , 0 ): (- 1 , 0 ),
1959        (3 , 1 ): None , (3 , 2 ): None }
2060
61+     assert  policy_iteration (sequential_decision_environment_1 ) ==  {
62+         (0 , 0 ): (0 , 1 ), (0 , 1 ): (0 , 1 ), (0 , 2 ): (1 , 0 ),
63+         (1 , 0 ): (1 , 0 ), (1 , 2 ): (1 , 0 ), (2 , 0 ): (0 , 1 ),
64+         (2 , 1 ): (0 , 1 ), (2 , 2 ): (1 , 0 ), (3 , 0 ): (- 1 , 0 ),
65+         (3 , 1 ): None , (3 , 2 ): None }
66+ 
67+     assert  policy_iteration (sequential_decision_environment_2 ) ==  {
68+         (0 , 0 ): (1 , 0 ), (0 , 1 ): (0 , 1 ), (0 , 2 ): (1 , 0 ),
69+         (1 , 0 ): (1 , 0 ), (1 , 2 ): (1 , 0 ), (2 , 0 ): (1 , 0 ),
70+         (2 , 1 ): (1 , 0 ), (2 , 2 ): (1 , 0 ), (3 , 0 ): (0 , 1 ),
71+         (3 , 1 ): None , (3 , 2 ): None }
72+ 
73+     assert  policy_iteration (sequential_decision_environment_3 ) ==  {
74+         (0 , 0 ): (- 1 , 0 ), (0 , 1 ): (0 , - 1 ), (0 , 2 ): (0 , - 1 ), (0 , 3 ): (0 , - 1 ), (0 , 4 ): None ,
75+         (1 , 0 ): (- 1 , 0 ), (1 , 1 ): (- 1 , 0 ), (1 , 4 ): (1 , 0 ),
76+         (2 , 0 ): (- 1 , 0 ), (2 , 1 ): (0 , - 1 ), (2 , 2 ): None , (2 , 4 ): (1 , 0 ),
77+         (3 , 0 ): (- 1 , 0 ), (3 , 2 ): None , (3 , 3 ): (1 , 0 ), (3 , 4 ): (1 , 0 ),
78+         (4 , 0 ): (- 1 , 0 ), (4 , 3 ): (1 , 0 ), (4 , 4 ): (1 , 0 ),
79+         (5 , 0 ): None , (5 , 1 ): (0 , 1 ), (5 , 2 ): (0 , 1 ), (5 , 3 ): (0 , 1 ), (5 , 4 ): (1 , 0 )}
80+ 
2181
2282def  test_best_policy ():
2383    pi  =  best_policy (sequential_decision_environment ,
@@ -26,6 +86,26 @@ def test_best_policy():
2686                                                             ['^' , None , '^' , '.' ],
2787                                                             ['^' , '>' , '^' , '<' ]]
2888
89+     pi_1  =  best_policy (sequential_decision_environment_1 ,
90+                      value_iteration (sequential_decision_environment_1 , .01 ))
91+     assert  sequential_decision_environment_1 .to_arrows (pi_1 ) ==  [['>' , '>' , '>' , '.' ],
92+                                                                  ['^' , None , '^' , '.' ],
93+                                                                  ['^' , '>' , '^' , '<' ]]
94+ 
95+     pi_2  =  best_policy (sequential_decision_environment_2 ,
96+                      value_iteration (sequential_decision_environment_2 , .01 ))
97+     assert  sequential_decision_environment_2 .to_arrows (pi_2 ) ==  [['>' , '>' , '>' , '.' ],
98+                                                                  ['^' , None , '>' , '.' ],
99+                                                                  ['>' , '>' , '>' , '^' ]]
100+ 
101+     pi_3  =  best_policy (sequential_decision_environment_3 ,
102+                      value_iteration (sequential_decision_environment_3 , .01 ))
103+     assert  sequential_decision_environment_3 .to_arrows (pi_3 ) ==  [['.' , '>' , '>' , '>' , '>' , '>' ], 
104+                                                                  ['v' , None , None , '>' , '>' , '^' ], 
105+                                                                  ['v' , None , '.' , '.' , None , '^' ], 
106+                                                                  ['v' , '<' , 'v' , None , None , '^' ], 
107+                                                                  ['<' , '<' , '<' , '<' , '<' , '.' ]]                                                               
108+ 
29109
30110def  test_transition_model ():
31111    transition_model  =  {
0 commit comments