Skip to content

Commit aea0278

Browse files
sagunbtensorflower-gardener
authored andcommitted
No public description
PiperOrigin-RevId: 771154434
1 parent cc038da commit aea0278

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

orbit/utils/loop_fns.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
import tensorflow as tf, tf_keras
2121

22+
# pylint: disable=g-direct-tensorflow-import
23+
from tensorflow.python.tpu.google.sparse_core import tpu_embedding_v3
24+
# pylint: enable=g-direct-tensorflow-import
25+
2226

2327
def create_loop_fn(step_fn):
2428
"""Creates a loop function driven by a Python `while` loop.
@@ -200,7 +204,8 @@ class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
200204

201205
def __call__(self, iterator, num_steps):
202206
if tf.summary.should_record_summaries():
203-
output = self.with_summaries(iterator, tf.constant(1))
207+
with tpu_embedding_v3.SequentialEmbeddingContext():
208+
output = self.with_summaries(iterator, tf.constant(1))
204209
num_steps -= 1
205210
if num_steps >= 1:
206211
output = self.without_summaries(iterator, num_steps)

0 commit comments

Comments
 (0)