|
7 | 7 | from utils import ( |
8 | 8 | is_in, argmin, argmax, argmax_random_tie, probability, weighted_sampler, |
9 | 9 | memoize, print_table, open_data, Stack, FIFOQueue, PriorityQueue, name, |
10 | | - distance |
| 10 | + distance, vector_add |
11 | 11 | ) |
12 | 12 |
|
13 | 13 | from collections import defaultdict |
@@ -526,39 +526,37 @@ def and_search(states, problem, path): |
526 | 526 | # body of and or search |
527 | 527 | return or_search(problem.initial, problem, []) |
528 | 528 |
|
| 529 | +# Pre-defined actions for PeakFindingProblem |
| 530 | +directions4 = { 'W':(-1, 0), 'N':(0, 1), 'E':(1, 0), 'S':(0, -1) } |
| 531 | +directions8 = dict(directions4) |
| 532 | +directions8.update({'NW':(-1, 1), 'NE':(1, 1), 'SE':(1, -1), 'SW':(-1, -1) }) |
529 | 533 |
|
530 | 534 | class PeakFindingProblem(Problem): |
531 | 535 | """Problem of finding the highest peak in a limited grid""" |
532 | 536 |
|
533 | | - def __init__(self, initial, grid): |
| 537 | + def __init__(self, initial, grid, defined_actions=directions4): |
534 | 538 | """The grid is a 2 dimensional array/list whose state is specified by tuple of indices""" |
535 | 539 | Problem.__init__(self, initial) |
536 | 540 | self.grid = grid |
| 541 | + self.defined_actions = defined_actions |
537 | 542 | self.n = len(grid) |
538 | 543 | assert self.n > 0 |
539 | 544 | self.m = len(grid[0]) |
540 | 545 | assert self.m > 0 |
541 | 546 |
|
542 | 547 | def actions(self, state): |
543 | | - """Allows movement in only 4 directions""" |
544 | | - # TODO: Add flag to allow diagonal motion |
| 548 | + """Returns the list of actions which are allowed to be taken from the given state""" |
545 | 549 | allowed_actions = [] |
546 | | - if state[0] > 0: |
547 | | - allowed_actions.append('N') |
548 | | - if state[0] < self.n - 1: |
549 | | - allowed_actions.append('S') |
550 | | - if state[1] > 0: |
551 | | - allowed_actions.append('W') |
552 | | - if state[1] < self.m - 1: |
553 | | - allowed_actions.append('E') |
| 550 | + for action in self.defined_actions: |
| 551 | + next_state = vector_add(state, self.defined_actions[action]) |
| 552 | + if next_state[0] >= 0 and next_state[1] >= 0 and next_state[0] <= self.n - 1 and next_state[1] <= self.m - 1: |
| 553 | + allowed_actions.append(action) |
| 554 | + |
554 | 555 | return allowed_actions |
555 | 556 |
|
556 | 557 | def result(self, state, action): |
557 | 558 | """Moves in the direction specified by action""" |
558 | | - x, y = state |
559 | | - x = x + (1 if action == 'S' else (-1 if action == 'N' else 0)) |
560 | | - y = y + (1 if action == 'E' else (-1 if action == 'W' else 0)) |
561 | | - return (x, y) |
| 559 | + return vector_add(state, self.defined_actions[action]) |
562 | 560 |
|
563 | 561 | def value(self, state): |
564 | 562 | """Value of a state is the value it is the index to""" |
@@ -1347,3 +1345,4 @@ def compare_graph_searchers(): |
1347 | 1345 | GraphProblem('Q', 'WA', australia_map)], |
1348 | 1346 | header=['Searcher', 'romania_map(Arad, Bucharest)', |
1349 | 1347 | 'romania_map(Oradea, Neamt)', 'australia_map']) |
| 1348 | + |
0 commit comments