1010from __future__ import print_function
1111
1212import tensorflow as tf
13- from tensorflow .python . ops import rnn , rnn_cell
13+ from tensorflow .contrib import rnn
1414import numpy as np
1515
1616# Import MNIST data
@@ -60,20 +60,20 @@ def BiRNN(x, weights, biases):
6060 # Reshape to (n_steps*batch_size, n_input)
6161 x = tf .reshape (x , [- 1 , n_input ])
6262 # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
63- x = tf .split (0 , n_steps , x )
63+ x = tf .split (x , n_steps , 0 )
6464
6565 # Define lstm cells with tensorflow
6666 # Forward direction cell
67- lstm_fw_cell = rnn_cell .BasicLSTMCell (n_hidden , forget_bias = 1.0 )
67+ lstm_fw_cell = rnn .BasicLSTMCell (n_hidden , forget_bias = 1.0 )
6868 # Backward direction cell
69- lstm_bw_cell = rnn_cell .BasicLSTMCell (n_hidden , forget_bias = 1.0 )
69+ lstm_bw_cell = rnn .BasicLSTMCell (n_hidden , forget_bias = 1.0 )
7070
7171 # Get lstm cell output
7272 try :
73- outputs , _ , _ = rnn .bidirectional_rnn (lstm_fw_cell , lstm_bw_cell , x ,
73+ outputs , _ , _ = rnn .static_bidirectional_rnn (lstm_fw_cell , lstm_bw_cell , x ,
7474 dtype = tf .float32 )
7575 except Exception : # Old TensorFlow version only returns outputs not states
76- outputs = rnn .bidirectional_rnn (lstm_fw_cell , lstm_bw_cell , x ,
76+ outputs = rnn .static_bidirectional_rnn (lstm_fw_cell , lstm_bw_cell , x ,
7777 dtype = tf .float32 )
7878
7979 # Linear activation, using rnn inner loop last output
@@ -82,15 +82,15 @@ def BiRNN(x, weights, biases):
8282pred = BiRNN (x , weights , biases )
8383
8484# Define loss and optimizer
85- cost = tf .reduce_mean (tf .nn .softmax_cross_entropy_with_logits (pred , y ))
85+ cost = tf .reduce_mean (tf .nn .softmax_cross_entropy_with_logits (logits = pred , labels = y ))
8686optimizer = tf .train .AdamOptimizer (learning_rate = learning_rate ).minimize (cost )
8787
8888# Evaluate model
8989correct_pred = tf .equal (tf .argmax (pred ,1 ), tf .argmax (y ,1 ))
9090accuracy = tf .reduce_mean (tf .cast (correct_pred , tf .float32 ))
9191
9292# Initializing the variables
93- init = tf .initialize_all_variables ()
93+ init = tf .global_variables_initializer ()
9494
9595# Launch the graph
9696with tf .Session () as sess :
0 commit comments