阅读(11k) 书签 (0)

TensorFlow定义控制流程操作

2018-09-07 13:52 更新

#版权所有2015 TensorFlow作者.版权所有.

#根据Apache许可证2.0版(“许可证”)许可;

#你不能使用这个文件,除非符合许可证.

#您可以获得许可证的副本

#http      ://www.apache.org/licenses/LICENSE-2.0

#除非适用法律要求或书面同意软件

根据许可证分发的#分发在“按原样”基础上,

#无明示或暗示的任何形式的担保或条件.

#查看有关权限的特定语言的许可证

#许可证下的限制.

# =============================================== =============================

“”“控制流程操作“”“.

See the @{$python/control_flow_ops} guide.

@@identity

@@tuple

@@group

@@no_op

@@count_up_to

@@cond

@@case

@@while_loop

@@logical_and

@@logical_not

@@logical_or

@@logical_xor

@@equal

@@not_equal

@@less

@@less_equal

@@greater

@@greater_equal

@@where

@@is_finite

@@is_inf

@@is_nan

@@verify_tensor_all_finite

@@check_numerics

@@add_check_numerics_ops

@@Assert

@@Print

"""

# pylint: disable=g-bad-name

from __future__ import absolute_import

from __future__ import division

from __future__ import print_function

import collections

import six

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.protobuf import control_flow_pb2

from tensorflow.python.framework import constant_op

from tensorflow.python.framework import dtypes

from tensorflow.python.framework import ops

from tensorflow.python.framework import sparse_tensor

from tensorflow.python.framework import tensor_shape

from tensorflow.python.ops import array_ops

from tensorflow.python.ops import gen_array_ops

from tensorflow.python.ops import gen_control_flow_ops

from tensorflow.python.ops import gen_data_flow_ops

from tensorflow.python.ops import gen_logging_ops

from tensorflow.python.ops import math_ops

from tensorflow.python.ops import tensor_array_ops

# go/tf-wildcard-import

# pylint: disable=wildcard-import,undefined-variable

from tensorflow.python.ops.gen_control_flow_ops import *

# pylint: enable=wildcard-import

from tensorflow.python.platform import tf_logging as logging

from tensorflow.python.util import deprecation

from tensorflow.python.util import nest

from tensorflow.python.util import tf_should_use

# We override the 'tuple' for a control flow op, so we keep python's

# existing 'tuple' for later use in this module.

_basetuple = tuple

# pylint: disable=protected-access

# Assert and Print are special symbols in python, so we must

# use an upper-case version of them.

@tf_should_use.should_use_result

def Assert(condition, data, summarize=None, name=None):

  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.

  `summarize` determines how many entries of the tensors to print.

  NOTE: To ensure that Assert executes, one usually attaches a dependency:

  ```python

  # Ensure maximum element of x is smaller or equal to 1

  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])

  with tf.control_dependencies([assert_op]):

    ... code using x ...

  ```

  Args:

    condition: The condition to evaluate.

    data: The tensors to print out when condition is false.

    summarize: Print this many entries of each tensor.

    name: A name for this operation (optional).

  Returns:

    assert_op: An `Operation` that, when executed, raises a

    `tf.errors.InvalidArgumentError` if `condition` is not true.

  """

  with ops.name_scope(name, "Assert", [condition, data]) as name:

    xs = ops.convert_n_to_tensor(data)

    if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):

      # As a simple heuristic, we assume that string and int32 are

      # on host to avoid the need to use cond. If it is not case,

      # we will pay the price copying the tensor to host memory.

      return gen_logging_ops._assert(

          condition, data, summarize, name="Assert")

    else:

      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():

        return gen_logging_ops._assert(

            condition, data, summarize, name="Assert")

      guarded_assert = cond(

          condition, no_op, true_assert, name="AssertGuard")

      return guarded_assert.op

def _Identity(data, name=None):

  """Return a tensor with the same shape and contents as the input tensor.

  Args:

    data: A Tensor.

    name: A name for this operation (optional).

  Returns:

    A Tensor with the same type and value as the input Tensor.

  """

  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)

  if isinstance(data, ops.Tensor):

    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access

      return gen_array_ops._ref_identity(data, name=name)

    else:

      return array_ops.identity(data, name=name)

  else:

    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

      raise TypeError("Type %s not supported" % type(data))

    values = _Identity(data.values, name=name)

    indices = array_ops.identity(data.indices, name="indices")

    if isinstance(data, ops.IndexedSlices):

      dense_shape = data.dense_shape

      if dense_shape is not None:

        dense_shape = array_ops.identity(dense_shape, name="dense_shape")

      return ops.IndexedSlices(values, indices, dense_shape)

    else:

      dense_shape = array_ops.identity(data.dense_shape, name="dense_shape")

      return sparse_tensor.SparseTensor(indices, values, dense_shape)

def _NextIteration(data, name=None):

  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)

  if isinstance(data, ops.Tensor):

    if data.dtype._is_ref_dtype:   # pylint: disable=protected-access

      return ref_next_iteration(data, name=name)

    else:

      return next_iteration(data, name=name)

  else:

    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

      raise TypeError("Type %s not supported" % type(data))

    values = _NextIteration(data.values, name=name)

    indices = next_iteration(data.indices, name="indices")

    if isinstance(data, ops.IndexedSlices):

      dense_shape = data.dense_shape

      if dense_shape is not None:

        dense_shape = next_iteration(dense_shape, name="dense_shape")

      return ops.IndexedSlices(values, indices, dense_shape)

    else:

      dense_shape = next_iteration(data.dense_shape, name="dense_shape")

      return sparse_tensor.SparseTensor(indices, values, dense_shape)

def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,

           use_ref=True, use_input_shape=True, name=None):

  """Creates or finds a child frame, and makes `data` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If

  `is_constant` is true, `data` is a constant in the child frame; otherwise

  it may be changed in the child frame. At most `parallel_iterations`

  iterations are run in parallel in the child frame.

  Args:

    data: The tensor to be made available to the child frame.

    frame_name: The name of the child frame.

    is_constant: If true, the output is constant within the child frame.

    parallel_iterations: The number of iterations allowed to run in parallel.

    use_ref: If true, use ref_enter if data is of ref type.

    name: A name for this operation (optional).

  Returns:

    The same tensor as `data`.

  """

  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)

  if isinstance(data, ops.Tensor):

    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access

      result = ref_enter(data, frame_name, is_constant, parallel_iterations,

                         name=name)

    else:

      result = enter(data, frame_name, is_constant, parallel_iterations,

                     name=name)

    if use_input_shape:

      result.set_shape(data.get_shape())

    return result

  else:

    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

      raise TypeError("Type %s not supported" % type(data))

    values = _Enter(data.values, frame_name, is_constant,

                    parallel_iterations=parallel_iterations,

                    use_input_shape=use_input_shape, name=name)

    indices = enter(data.indices, frame_name, is_constant,

                    parallel_iterations, name="indices")

    if use_input_shape:

      indices.set_shape(data.indices.get_shape())

    if isinstance(data, ops.IndexedSlices):

      dense_shape = data.dense_shape

      if dense_shape is not None:

        dense_shape = enter(dense_shape, frame_name, is_constant,

                            parallel_iterations, name="dense_shape")

        if use_input_shape:

          dense_shape.set_shape(data.dense_shape.get_shape())

      return ops.IndexedSlices(values, indices, dense_shape)

    else:

      dense_shape = enter(data.dense_shape, frame_name, is_constant,

                          parallel_iterations, name="dense_shape")

      if use_input_shape:

        dense_shape.set_shape(data.dense_shape.get_shape())

      return sparse_tensor.SparseTensor(indices, values, dense_shape)

def exit(data, name=None):

  """Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:

    data: The tensor to be made available to the parent frame.

    name: A name for this operation (optional).

  Returns:

    The same tensor as `data`.

  """

  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)

  if isinstance(data, ops.Tensor):

    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access

      return gen_control_flow_ops._ref_exit(data, name)

    else:

      return gen_control_flow_ops._exit(data, name)

  else:

    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

      raise TypeError("Type %s not supported" % type(data))

    values = exit(data.values, name=name)

    indices = gen_control_flow_ops._exit(data.indices, name="indices")

    if isinstance(data, ops.IndexedSlices):

      dense_shape = data.dense_shape

      if dense_shape is not None:

        dense_shape = gen_control_flow_ops._exit(dense_shape, name)

      return ops.IndexedSlices(values, indices, dense_shape)

    else:

      dense_shape = gen_control_flow_ops._exit(data.dense_shape, name)

      return sparse_tensor.SparseTensor(indices, values, dense_shape)

def switch(data, pred, dtype=None, name=None):

  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.

  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:

    data: The tensor to be forwarded to the appropriate output.

    pred: A scalar that specifies which output port will receive data.

    dtype: Optional element type for the returned tensor. If missing,

           the type is inferred from the type of `value`.

    name: A name for this operation (optional).

  Returns:

    `(output_false, output_true)`: If `pred` is true, data will be forwarded

    to `output_true`, otherwise it goes to `output_false`.

  """

  with ops.name_scope(name, "Switch", [data, pred]) as name:

    data = ops.internal_convert_to_tensor_or_indexed_slices(

        data, dtype=dtype, name="data", as_ref=True)

    pred = ops.convert_to_tensor(pred, name="pred")

    if isinstance(data, ops.Tensor):

      return gen_control_flow_ops._switch(data, pred, name=name)

    else:

      if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

        raise TypeError("Type %s not supported" % type(data))

      val, ind = data.values, data.indices

      val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)

      ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")

      if isinstance(data, ops.IndexedSlices):

        dense_shape = data.dense_shape

        if dense_shape is not None:

          dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(

              dense_shape, pred, name="dense_shape")

        else:

          dense_shape_f, dense_shape_t = None, None

        return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),

                ops.IndexedSlices(val_t, ind_t, dense_shape_t))

      else:

        dense_shape = data.dense_shape

        dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(

            data.dense_shape, pred, name="dense_shape")

        return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f),

                sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t))

def _SwitchRefOrTensor(data, pred, name="Switch"):

  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwared to the first output.

  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:

    data: The tensor to be forwarded to the appropriate output.

    pred: A scalar that specifies which output port will receive data.

    name: A name for this operation (optional).

  Returns:

    `(output_false, output_true)`: If `pred` is true, data will be forwarded to

    `output_true`, otherwise it goes to `output_false`.

  Raises:

    TypeError: if data is not a Tensor or IndexedSlices

  """

  data = ops.convert_to_tensor_or_indexed_slices(data, name="data")

  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below

  # addresses the following scenario.

  #

  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().

  #

  # 1. The update op is created inside a `with ops.colocate(var):` block

  #

  # 2. Some tensor `data` is captured and a switch is created in a

  #    `with ops.colocate_with(data):` block.

  #

  # with ops.colocate_with(var):

  #  with ops.colocate_with(data):

  #    op = ...

  #

  # var and data may be pinned to different devices, so we want to ops

  # created within ops.colocate_with(data) to ignore the existing stack.

  with ops.colocate_with(data, ignore_existing=True):

    if isinstance(data, ops.Tensor):

      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access

        return ref_switch(data, pred, name=name)

    return switch(data, pred, name=name)

def merge(inputs, name=None):

  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of

  them is available. If it finds an available tensor, it returns it and its

  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor

  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of

  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices

  before merging.

  Args:

    inputs: The input tensors, at most one of which is available.

    name: A name for this operation (optional).

  Returns:

    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:

    ValueError: If any of the inputs is None, or inputs are IndexedSlices and

      some but not all have a dense_shape property.

  """

  if any([inp is None for inp in inputs]):

    raise ValueError("At least one of the merge inputs is None: %s" % inputs)

  with ops.name_scope(name, "Merge", inputs) as name:

    inputs = [ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)

              for inp in inputs]

    if all([isinstance(v, ops.Tensor) for v in inputs]):

      if all([v.dtype._is_ref_dtype for v in inputs]):  # pylint: disable=protected-access

        return gen_control_flow_ops._ref_merge(inputs, name)

      else:

        return gen_control_flow_ops._merge(inputs, name)

    elif all([isinstance(v, sparse_tensor.SparseTensor) for v in inputs]):

      # Only handle the case when all inputs are SparseTensor.

      values, _ = merge([inp.values for inp in inputs], name=name)

      indices, chosen_index = gen_control_flow_ops._merge(

          [inp.indices for inp in inputs], name="indices")

      dense_shape, _ = gen_control_flow_ops._merge(

          [inp.dense_shape for inp in inputs], name="dense_shape")

      return (sparse_tensor.SparseTensor(indices, values, dense_shape),

              chosen_index)

    else:

      # For now convert all the inputs as IndexedSlices.

      inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)

      values, _ = merge([inp.values for inp in inputs], name=name)

      indices, chosen_index = gen_control_flow_ops._merge(

          [inp.indices for inp in inputs], name="indices")

      if any(inp.dense_shape is not None for inp in inputs):

        if any(inp.dense_shape is None for inp in inputs):

          raise ValueError("Either all merged IndexedSlices must have a "

                           "dense_shape, or none must have a dense_shape.")

        dense_shape, _ = gen_control_flow_ops._merge(

            [inp.dense_shape for inp in inputs], name="dense_shape")

      else:

        dense_shape = None

      return ops.IndexedSlices(values, indices, dense_shape), chosen_index

# pylint: enable=protected-access

def _convert_tensorarray_to_flow(tensor_or_tensor_array):

  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):

    return tensor_or_tensor_array.flow

  else:

    return tensor_or_tensor_array

def _make_tensor_array(ta, t_or_flow):

  # pylint: disable=protected-access

  new_ta = tensor_array_ops.TensorArray(

      dtype=ta.dtype, handle=ta.handle, flow=t_or_flow,

      infer_shape=ta._infer_shape,

      colocate_with_first_write_call=ta._colocate_with_first_write_call)

  new_ta._colocate_with = ta._colocate_with

  new_ta._element_shape = ta._element_shape

  # pylint: enable=protected-access

  return new_ta

def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):

  if len(tensors_or_tensorarrays) != len(tensors_or_flows):

    raise ValueError(

        "Lengths of original Tensor list and new list do not match: %d vs. %d"

        % (len(tensors_or_tensorarrays), len(tensors_or_flows)))

  return [

      _make_tensor_array(ta, t_or_flow)

      if isinstance(ta, tensor_array_ops.TensorArray)

      else t_or_flow

      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)]

def _IsLoopConstantEnter(op):

  """Return true iff op is a loop invariant."""

  is_enter = (op.type == "Enter" or op.type == "RefEnter")

  return is_enter and op.get_attr("is_constant")

def _GetLoopConstantEnter(value):

  """Return the enter op if we can infer `value` to be a loop invariant."""

  id_ops = {"Switch", "RefSwitch", "Identity", "RefIdentity"}

  op = value.op

  while op.type in id_ops:

    op = op.inputs[0].op

  return op if _IsLoopConstantEnter(op) else None

def _GetOutputContext(op):

  """Return the control flow context for the output of an op."""

  ctxt = op._get_control_flow_context()

  if IsLoopExit(op):

    ctxt = ctxt.outer_context

  return ctxt

def _ShapeLessThanOrEqual(shape1, shape2):

  if shape2.dims is None:

    return True

  if shape1.ndims != shape2.ndims:

    return False

  for dim1, dim2 in zip(shape1.dims, shape2.dims):

    if dim2.value is not None and dim1.value != dim2.value:

      return False

  return True

def _SetShapeInvariants(input_vars, enter_vars, shapes):

  """Set the shapes of the tensors in `enter_vars` to `shapes`.

  Args:

    input_vars: A list of tensors that are inputs to `enter_vars`.

    enter_vars: A list of tensors whose shapes will be set.

    shapes: A (possibly nested) list of shapes.

  Raises:

    ValueError: If any tensor in `enter_vars` has a less specific shape

      than its corresponding shape in `shapes`.

  """

  if shapes is None:

    return

  flat_shapes = nest.flatten(shapes)

  if not all([isinstance(s, tensor_shape.TensorShape) for s in flat_shapes]):

    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")

  # Check that the shapes of the inputs are less than the shape invariants,

  # and set the shapes of `enter_vars` to the shape invariants.

  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):

    if isinstance(var, ops.Tensor):

      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):

        raise ValueError(

            "The shape invariant specified for %s is not compatible with "

            "the initial shape of the loop variable. It enters the loop "

            "with shape %s, but the specified shape invariant is %s."

            % (inp.name, inp.get_shape(), shape))

      var.set_shape(shape)

    else:

      if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

        raise TypeError("Type %s not supported" % type(var))

      if isinstance(var, ops.IndexedSlices):

        if not _ShapeLessThanOrEqual(inp.values.get_shape(), shape):

          raise ValueError(

              "The shape invariant specified for %s is not compatible with "

              "the initial shape of the values tensor of this IndexedSlices. "

              "It enters the loop with shape %s, but the specified shape "

              "invariant is %s."

              % (inp.values.name, inp.values.get_shape(), shape))

        var.values.set_shape(shape)

        var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))

        if var.dense_shape is not None:

          var.dense_shape.set_shape(tensor_shape.TensorShape([shape.ndims]))

      else:

        if not _ShapeLessThanOrEqual(inp.dense_shape.get_shape(), shape):

          raise ValueError(

              "The shape invariant specified for %s is not compatible with "

              "the initial shape of the shape tensor of this SparseTensor. "

              "It enters the loop with shape %s, but the specified shape "

              "invariant is %s."

              % (inp.dense_shape.name, inp.dense_shape.get_shape(), shape))

        var.values.set_shape(tensor_shape.TensorShape([None]))

        var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))

        var.dense_shape.set_shape(shape)

def _EnforceShapeInvariant(merge_var, next_var):

  """Check if the shapes of the loops variables are invariants.

  Args:

    merge_vars: The list of tensors representing the initial values of the

      loop variables.

    next_vars: The list of tensors representing the values of the loop

      variables after one loop iteration.

  Raises:

    ValueError: If any tensor in `merge_vars` has a more specific shape than

      its correspnding tensor in `next_var`.

  """

  if isinstance(merge_var, ops.Tensor):

    m_shape = merge_var.get_shape()

    n_shape = next_var.get_shape()

    if not _ShapeLessThanOrEqual(n_shape, m_shape):

      raise ValueError(

          "The shape for %s is not an invariant for the loop. It enters "

          "the loop with shape %s, but has shape %s after one iteration. "

          "Provide shape invariants using either the `shape_invariants` "

          "argument of tf.while_loop or set_shape() on the loop variables."

          % (merge_var.name, m_shape, n_shape))

  else:

    if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

      raise TypeError("Type %s not supported" % type(var))

    if isinstance(var, ops.IndexedSlices):

      m_values_shape = merge_var.values.get_shape()

      m_indices_shape = merge_var.indices.get_shape()

      m_shape_shape = tensor_shape.TensorShape(None)

      if merge_var.dense_shape is not None:

        m_shape_shape = merge_var.dense_shape.get_shape()

      n_values_shape = next_var.values.get_shape()

      n_indices_shape = next_var.indices.get_shape()

      n_shape_shape = tensor_shape.TensorShape(None)

      if next_var.dense_shape is not None:

        n_shape_shape = next_var.dense_shape.get_shape()

      if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or

          not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape)):

        if not _ShapeLessThanOrEqual(n_values_shape, m_values_shape):

          raise ValueError(

              "The shape for %s is not an invariant for the loop. It enters "

              "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "

              "after one iteration. Provide shape invariants using either the "

              "`shape_invariants` argument of tf.while_loop or set_shape() "

              "on the loop variables."

              % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,

                 n_values_shape, n_indices_shape, n_shape_shape))

    else:

      m_values_shape = merge_var.values.get_shape()

      m_indices_shape = merge_var.indices.get_shape()

      m_shape_shape = merge_var.dense_shape.get_shape()

      n_values_shape = next_var.values.get_shape()

      n_indices_shape = next_var.indices.get_shape()

      n_shape_shape = next_var.dense_shape.get_shape()

      if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or

          not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or

          not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)):

        raise ValueError(

          "The shape for %s is not an invariant for the loop. It enters "

          "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "

          "after one iteration. Provide shape invariants using either "

          "the `shape_invariants` argument of tf.while_loop or set_shape() "

          "on the loop variables."

          % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,

             n_values_shape, n_indices_shape, n_shape_shape))

def _AddNextAndBackEdge(m, v):

  """Add NextIteration and back edge from v to m."""

  if isinstance(m, ops.Tensor):

    v = ops.convert_to_tensor(v)

    v = _NextIteration(v)

    m.op._update_input(1, v)   # pylint: disable=protected-access

  elif isinstance(m, ops.IndexedSlices):

    # pylint: disable=protected-access

    v = math_ops._as_indexed_slices(v, optimize=False)

    v = _NextIteration(v)

    m.values.op._update_input(1, v.values)

    m.indices.op._update_input(1, v.indices)

    # pylint: enable=protected-access

    if m.dense_shape is not None:

      if v.dense_shape is None:

        raise ValueError("Must have dense shape: %s" % v.name)

      m.dense_shape.op._update_input(1, v.dense_shape)

  elif isinstance(m, sparse_tensor.SparseTensor):

    if not isinstance(v, sparse_tensor.SparseTensor):

      raise ValueError("Must be a sparse tensor: %s" % v.name)

    v = _NextIteration(v)

    # pylint: disable=protected-access

    m.values.op._update_input(1, v.values)

    m.indices.op._update_input(1, v.indices)

    m.dense_shape.op._update_input(1, v.dense_shape)

    # pylint: enable=protected-access

  else:

    raise TypeError("Type %s not supported" % type(m))

  return v

class GradLoopState(object):

  """The state used for constructing the gradient graph for a while loop.

  We create a GradLoopState for each while loop in forward and its

  corresponding while loop in backprop. This gives us access to both

  the forward and the backprop WhileContexts.

  During the construction of gradient graph, any time when we detect

  a forward value that is needed for backprop, we create a history

  accumulator and add it to `history_map`. Any time when we backprop

  a loop switch op (in _SwitchGrad), we add the grad merge op in

  `switch_map`.

  """

  def __init__(self, forward_ctxt, outer_grad_state):

    # The grad loop state for the outer while loop.

    self._outer_grad_state = None

    # The while loop context for forward.

    self._forward_context = None

    # The loop counter added by AddForwardLoopCounter. It is the value

    # of the loop counter for the next iteration.

    self._forward_index = None

    # A sync op for forward.

    self._forward_sync = None

    # The while loop context for backprop.

    self._grad_context = None

    # The loop counter added by AddBackPropLoopCounter. It is the value

    # of the loop counter for the current iteration.

    self._grad_index = None

    # A sync op for backprop.

    self._grad_sync = None

    # Information needed by backprop.

    self._history_map = {}

    self._switch_map = {}

    self._unused_exits = []

    self._deferred_exits = []

    self._forward_loop_exits = list(forward_ctxt.loop_exits)

    self._pending_exits_count = len(forward_ctxt.loop_exits)

    self._outer_grad_state = outer_grad_state

    if outer_grad_state:

      outer_forward_ctxt = outer_grad_state.forward_context

    else:

      outer_forward_ctxt = forward_ctxt.outer_context

    # Add the forward loop counter.

    if outer_forward_ctxt: outer_forward_ctxt.Enter()

    cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)

    if outer_forward_ctxt: outer_forward_ctxt.Exit()

    self._forward_context = forward_ctxt

    self._forward_index = forward_index

    # Add the backprop WhileContext, and the backprop loop counter.

    if outer_grad_state:

      # This is a nested loop. Remember the iteration counts for each

      # execution of this inner loop.

      outer_forward_ctxt.AddName(cnt.name)

      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)

      outer_grad_ctxt = outer_grad_state.grad_context

      outer_grad_ctxt.Enter()

      self._grad_context = WhileContext(forward_ctxt.parallel_iterations,

                                        forward_ctxt.back_prop,

                                        forward_ctxt.swap_memory,

                                        forward_ctxt.name,

                                        self)

      real_cnt = outer_grad_state.AddBackPropAccumulatedValue(history_cnt, cnt)

      self._grad_index = self._grad_context.AddBackPropLoopCounter(

          real_cnt, outer_grad_state)

      outer_grad_ctxt.Exit()

    else:

      if outer_forward_ctxt: outer_forward_ctxt.Enter()

      self._grad_context = WhileContext(forward_ctxt.parallel_iterations,

                                        forward_ctxt.back_prop,

                                        forward_ctxt.swap_memory,

                                        forward_ctxt.name,

                                        self)

      self._grad_index = self._grad_context.AddBackPropLoopCounter(

          cnt, outer_grad_state)

      if outer_forward_ctxt: outer_forward_ctxt.Exit()

  @property

  def outer_grad_state(self):

    """The grad loop state for outer loop."""

    return self._outer_grad_state

  @property

  def forward_context(self):

    """The while loop context for forward."""

    return self._forward_context

  @property

  def forward_index(self):

    """The loop index of forward loop."""

    return self._forward_index

  @property

  def forward_sync(self):

    """A control trigger node for synchronization in the forward loop.

    One main use is to keep the push ops of a stack executed in the

    iteration order.

    """

    if self._forward_sync is None:

      with ops.control_dependencies(None):

        self._forward_sync = control_trigger(name="f_sync")

      self._forward_sync._set_control_flow_context(self._forward_context)

      self._forward_index.op._add_control_input(self._forward_sync)

    return self._forward_sync

  @property

  def grad_context(self):

    """The corresponding WhileContext for gradient."""

    return self._grad_context

  @property

  def grad_index(self):

    """The loop index of backprop loop."""

    return self._grad_index

  @property

  def grad_sync(self):

    """A control trigger node for synchronization in the grad loop.

    One main use is to keep the pop ops of a stack executed in the

    iteration order.

    """

    if self._grad_sync is None:

      with ops.control_dependencies(None):

        self._grad_sync = control_trigger(name="b_sync")

      self._grad_sync._set_control_flow_context(self._grad_context)

      self._grad_index.op._add_control_input(self._grad_sync)

    return self._grad_sync

  @property

  def history_map(self):

    """The map that records all the tensors needed for backprop."""

    return self._history_map

  @property

  def switch_map(self):

    """The map that records all the Switch ops for the while loop."""

    return self._switch_map

  @property

  def unused_exits(self):

    """The list of "unused" exits."""

    return self._unused_exits

  @property

  def deferred_exits(self):

    """The list of "deferred" exits."""

    return self._deferred_exits

  @property

  def forward_loop_exits(self):

    """The list of exits of the forward loop."""

    return self._forward_loop_exits

  @property

  def pending_exits_count(self):

    """The number of exits we expect to see but haven't."""

    return self._pending_exits_count

  @pending_exits_count.setter

  def pending_exits_count(self, cnt):

    """Set the pending count to cnt."""

    self._pending_exits_count = cnt

  def AddForwardAccumulator(self, value, dead_branch=False):

    """Add an accumulator for each forward tensor that is needed in backprop.

    This is added to the forward loop at the first time when a tensor

    in the forward loop is used by backprop gradient computation loop.

    We create an accumulator that accumulates the value of tensor at each

    iteration. Called in the control flow context where gradients() is called.

    The pseudocode is:

    ```

      acc = stack();

      while (_pivot) {

        acc = stack_push(acc, value);

      }

    ```

    We make sure that the stack push op in one iteration is executed before

    next iteration. This is achieved by adding a control edge from

    `forward_index.op.inputs[0].op` to the push op, and another control

    edge from the push op to either `forward_index.op` or `forward_sync`.

    Args:

      value: The source tensor in forward that is to be accumulated.

      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:

      The stack that contains the accumulated history of the tensor.

    Raises:

      TypeError: For internal errors involving the value condition context.

    """

    curr_ctxt = ops.get_default_graph()._get_control_flow_context()

    with ops.control_dependencies(None):

      if curr_ctxt: curr_ctxt.Enter()

      with ops.colocate_with(value):

        # pylint: disable=protected-access

        acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc")

        # pylint: enable=protected-access

      if curr_ctxt: curr_ctxt.Exit()

      # Make acc available in the forward context.

      enter_acc = self.forward_context.AddValue(acc)

      # Add the stack_push op in the context of value.op.

      swap_enabled = self.forward_context.swap_memory

      value_ctxt = _GetOutputContext(value.op)

      if value_ctxt == self.forward_context:

        # value is not nested in the forward context.

        self.forward_context.Enter()

        push = gen_data_flow_ops._stack_push(

            enter_acc, value, swap_memory=swap_enabled)

        self.forward_context.Exit()

        # Protect stack push and order it before forward_index.

        self.forward_index.op._add_control_input(push.op)

      else:

        # value is in a cond context within the forward context.

        if not isinstance(value_ctxt, CondContext):

          raise TypeError(

              "value_ctxt is not a CondContext: %s" % value_ctxt)

        if dead_branch:

          # The special case for creating a zero tensor for a dead

          # branch of a switch. See ControlFlowState.ZerosLike().

          value_ctxt.outer_context.Enter()

          push = gen_data_flow_ops._stack_push(

              enter_acc, value, swap_memory=swap_enabled)

          value_ctxt.outer_context.Exit()

          push.op._set_control_flow_context(value_ctxt)

        else:

          value_ctxt.Enter()

          push = gen_data_flow_ops._stack_push(

              enter_acc, value, swap_memory=swap_enabled)

          value_ctxt.Exit()

        # Protect stack push and order it before forward_sync.

        self.forward_sync._add_control_input(push.op)

      # Order stack push after the successor of forward_index

      add_op = self.forward_index.op.inputs[0].op

      push.op._add_control_input(add_op)

      return acc

  def AddBackPropAccumulatedValue(self, history_value, value,

                                  dead_branch=False):

    """Add the getter for an accumulated value in the grad context.

    This is added to the backprop loop. Called in the grad context to

    get the value of an accumulated value. The stack pop op must be guarded

    by the pred of the controlling cond.

    Args:

      history_value: The history (a stack) of a value.

      value: The value that is pushed onto the stack.

      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:

      The current value (the top of the stack).

    """

    history_ctxt = history_value.op._get_control_flow_context()

    # Find the cond context that controls history_value if any.

    cond_ctxt = None

    value_ctxt = value.op._get_control_flow_context()

    while value_ctxt and value_ctxt != history_ctxt:

      if isinstance(value_ctxt, CondContext):

        cond_ctxt = value_ctxt

        break

      value_ctxt = value_ctxt.outer_context

    with ops.control_dependencies(None):

      self.grad_context.Enter()

      if cond_ctxt:

        # Guard stack pop with a switch if it is controlled by a cond.

        grad_state = self

        pred = None

        while pred is None and grad_state:

          pred = grad_state.history_map.get(cond_ctxt.pred.name)

          grad_state = grad_state.outer_grad_state

        if pred is None:

          pred = cond_ctxt.pred

        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch

        history_value = _SwitchRefOrTensor(history_value, pred)[branch]

      pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype)

      pop.set_shape(value.get_shape())

      self.grad_context.Exit()

    parallel_iterations = self.grad_context.parallel_iterations

    if parallel_iterations > 1:

      # All pops are ordered after pivot_for_body and before grad_sync.

      self.grad_sync._add_control_input(pop.op)

    return pop

  def GetRealValue(self, value):

    """Get the real value of `value`.

    If backprop "uses" a value produced by forward inference, an accumulator

    is added in the forward loop to accumulate its values.  We use the

    accumulated value. This method must be called in the grad loop context.

    `value` must be in forward and needed for backprop.

    Args:

      value: A tensor to be captured.

    Returns:

      The same tensor obtained from the saved history.

    """

    assert value.op.type not in ["Variable", "VariableV2"]

    real_value = self._history_map.get(value.name)

    if real_value is None:

      cur_value = value

      cur_grad_state = self

      while True:

        enter_op = _GetLoopConstantEnter(cur_value)

        if enter_op:

          # Special case: cur_value comes from a constant Enter node.

          cur_value = enter_op.inputs[0]

          cur_grad_state = cur_grad_state.outer_grad_state

          if cur_grad_state is None:

            # We are now outside all nested loops for this gradient(),

            # so `value` is a loop invariant and there is no need to

            # save the history of value. Just make cur_value to enter

            # the right control flow context.

            real_value = self._grad_context.AddValue(cur_value)

            break

        else:

          # Record the history of this value in forward_ctxt.

          # TODO(yuanbyu): Avoid recording constants.

          self._grad_context.Exit()

          history_value = cur_grad_state.AddForwardAccumulator(cur_value)

          self._grad_context.Enter()

          break

      if real_value is None:

        # Add the stack pop op in the grad context.

        real_value = cur_grad_state.AddBackPropAccumulatedValue(history_value,

                                                                cur_value)

        if cur_grad_state != self:

          real_value = self._grad_context.AddValue(real_value)

      self._history_map[value.name] = real_value

    return real_value

def _GetWhileContext(op):

  """Get the WhileContext to which this op belongs."""

  ctxt = op._get_control_flow_context()

  if ctxt:

    ctxt = ctxt.GetWhileContext()

  return ctxt

class ControlFlowState(object):

  """Maintain the mapping from the loops to their grad states."""

  def __init__(self):

    self._map = {}   # maps forward loop context to GradLoopState

  def GetGradState(self, op, before):

    """Return the grad state for this op if it's in a forward loop context."""

    if before and IsLoopExit(op):

      forward_ctxt = op._get_control_flow_context()

      forward_ctxt = forward_ctxt.outer_context

      if forward_ctxt:

        forward_ctxt = forward_ctxt.GetWhileContext()

    else:

      forward_ctxt = _GetWhileContext(op)

    if forward_ctxt:

      return self._map.get(forward_ctxt)

    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):

    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is

    unused if its pending_count is 0. If there is an exit with real gradient,

    all these deferred exits will enter the backprop loop with zero gradient.

    Otherwise, they will enter the backprop loop with None. As an example,

    people often write:

           ```

           v1, _ = tf.while_loop(p, b, [x1, x2])

           result = gradients(v1, x1)

           ```

    The exit node for x2 is not included by the betweenness analysis. But we

    need to backprop x2 if x2 is involved in computing v1.

    Args:

      pending_count: The number of backprop inputs for every op.

      to_ops_set: The set of ops for ys in gradients(ys, xs)

    Returns:

      The set of unused loop exits that we know at this point we need

      to backprop.

    """

    loop_exits = []

    for _, grad_state in self._map.items():

      # pylint: disable=protected-access

      for y in grad_state.forward_loop_exits:

        if pending_count[y.op._id] == 0:

          grad_state.pending_exits_count -= 1

          if y.op._id not in to_ops_set:

            grad_state.unused_exits.append(y)

          if grad_state.pending_exits_count == 0:

            loop_exits.extend(grad_state.unused_exits)

      # Need to include Enters in backprop for higher-order gradients.

      for y in grad_state.forward_context.loop_enters:

        if pending_count[y.op._id] == 0:

          pending_count[y.op._id] = 1

      # pylint: enable=protected-access

    return loop_exits

  def EnterGradWhileContext(self, op, before):

    """Enter the WhileContext for gradient computation."""

    grad_state = self.GetGradState(op, before)

    if grad_state:

      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):

    """Exit the WhileContext for gradient computation."""

    grad_state = self.GetGradState(op, before)

    if grad_state:

      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):

    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in

    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.

    """

    forward_ctxt = _GetWhileContext(op)

    grad_state = self._map.get(forward_ctxt)

    if grad_state is None:

      # This is a new while loop so create a grad state for it.

      outer_forward_ctxt = forward_ctxt.outer_context

      if outer_forward_ctxt:

        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()

      outer_grad_state = None

      if outer_forward_ctxt:

        outer_grad_state = self._map.get(outer_forward_ctxt)

      grad_state = GradLoopState(forward_ctxt, outer_grad_state)

      self._map[forward_ctxt] = grad_state

      # We need to include all exits of a loop for backprop.

      for loop_exit in grad_state.forward_loop_exits:

        if not between_ops[loop_exit.op._id]:

          between_ops[loop_exit.op._id] = True

          between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):

    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in

    computing the result of some needed loop variable, we create a

    zero-valued tensor that is fed as gradient for the Exit node of that

    loop variable. Note that val.op is an Exit, and this method must be

    called in the control flow context where gradients() is called.

    Args:

      val: The output tensor of an Exit op.

    Returns:

      A zero tensor of the same shape of val.

    """

    val_shape = val.get_shape()

    forward_ctxt = val.op._get_control_flow_context()

    outer_forward_ctxt = forward_ctxt.outer_context

    if outer_forward_ctxt:

      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()

    outer_grad_state = None

    if outer_forward_ctxt:

      outer_grad_state = self._map.get(outer_forward_ctxt)

    if outer_grad_state:

      # This is a nested loop.

      if val_shape.is_fully_defined():

        # If the shape is known statically, just create a zero tensor

        # with the right shape in the right context.

        outer_grad_state.grad_context.Enter()

        result = array_ops.zeros(val_shape.dims, val.dtype)

        outer_grad_state.grad_context.Exit()

      else:

        # Only the shape of value is needed for backprop.

        forward_ctxt.outer_context.Enter()

        shape = array_ops.shape_internal(val, optimize=False)

        forward_ctxt.outer_context.Exit()

        # Save the shape to a stack.

        history_shape = outer_grad_state.AddForwardAccumulator(shape)

        # Get the shape back from the stack.

        outer_grad_ctxt = outer_grad_state.grad_context

        outer_grad_ctxt.Enter()

        real_shape = outer_grad_state.AddBackPropAccumulatedValue(

            history_shape, shape)

        result = array_ops.zeros(real_shape, val.dtype)

        outer_grad_ctxt.Exit()

    else:

      # This is not a nested loop.

      if val_shape.is_fully_defined():

        # If the shape is known statically, just create a zero tensor

        # with the right shape.

        result = array_ops.zeros(val_shape.dims, val.dtype)

      else:

        result = array_ops.zeros_like(val, optimize=False)

    return result

  def ZerosLike(self, op, index):

    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method

    must be called in its grad loop context.

    Args:

      op: A tensorflow operation.

      index: the index for a specific output of the op.

    Returns:

      A zero tensor of the same shape of op.outputs[index].

    """

    if IsLoopSwitch(op): return None

    dead_branch = IsSwitch(op)

    forward_ctxt = _GetWhileContext(op)

    grad_state = self._map.get(forward_ctxt)

    if grad_state is None:

      # op is not in a while loop that is part of gradients().

      return ZerosLikeOutsideLoop(op, index)

    op_ctxt = op._get_control_flow_context()

    val = ops.convert_to_tensor(op.outputs[index], name="tensor")

    shape = val.get_shape()

    if shape.is_fully_defined():

      # If the shape is known statically, just create a zero tensor with

      # the right shape in the grad loop context.

      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)

      if dead_branch:

        # op is a cond switch. Guard the zero tensor with a switch.

        pred = grad_state.history_map.get(op_ctxt.pred.name)

        branch = op_ctxt.branch

        result = _SwitchRefOrTensor(result, pred)[1 - branch]

    else:

      # Unknown shape so keep a history of the shape at runtime.

      if dead_branch:

        # Need to add a special switch to guard the value.

        pred = op_ctxt.pred

        branch = op_ctxt.branch

        op_ctxt.outer_context.Enter()

        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]

        zeros_shape = array_ops.shape_internal(val, optimize=False)

        op_ctxt.outer_context.Exit()

        val.op._set_control_flow_context(op_ctxt)

        zeros_shape.op._set_control_flow_context(op_ctxt)

      else:

        op_ctxt.Enter()

        zeros_shape = array_ops.shape_internal(val, optimize=False)

        op_ctxt.Exit()

      # Add forward accumulator for shape.

      grad_state.grad_context.Exit()

      history_zeros_shape = grad_state.AddForwardAccumulator(

          zeros_shape, dead_branch=dead_branch)

      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.

      shape = grad_state.AddBackPropAccumulatedValue(

          history_zeros_shape, zeros_shape, dead_branch)

      result = array_ops.zeros(shape, val.dtype)

    return result

  def PostProcessing(self):

    """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function

    can be used to perform any postprocessing on the gradient graph.

    We currently perform the following postprocessing:

      1. Patch the gradient graph if the output of a loop variable

         doesn't depend on its input.

    """

    for _, grad_state in self._map.items():

      for _, b_merge in grad_state.switch_map.items():

        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:

          # The value of this loop variable at iteration i+1 doesn't

          # depend on its value at iteration i. So use zeros as the

          # gradients for all iterations > 0.

          dtype = b_merge.op.inputs[0].dtype

          shape = b_merge.op.inputs[0].get_shape()

          # pylint: disable=protected-access

          if shape.is_fully_defined():

            grad_state.grad_context.Enter()

            # Create a zeros and use it for iterations > 0.

            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)

            next_grad_val = _NextIteration(grad_val)

            grad_state.grad_context.Exit()

          else:

            # Create a zeros in the outer grad context.

            outer_grad_ctxt = grad_state.grad_context.outer_context

            if outer_grad_ctxt: outer_grad_ctxt.Enter()

            enter_grad_op = b_merge.op.inputs[0].op

            enter_grad = enter_grad_op.inputs[0]

            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)

            grad_val = array_ops.zeros(grad_shape)

            if outer_grad_ctxt: outer_grad_ctxt.Exit()

            # Use the zeros for iterations > 0.

            grad_state.grad_context.Enter()

            next_grad_val = _NextIteration(grad_val)

            grad_state.grad_context.Exit()

          b_merge.op._update_input(1, next_grad_val)

          # pylint: enable=protected-access

def MaybeCreateControlFlowState(between_op_list, between_ops,

                                colocate_gradients_with_ops):

  """Create the state for all the while loops involved in one gradients().

  We create a ControlFlowState when there are while loops involved in

  gradients(). In gradients(), control flow logic is only invoked when

  the ControlFlowState is not None.

  Note that this method modifies `between_op_list` and `between_ops`.

  """

  loop_state = None

  for op in between_op_list:

    if IsLoopExit(op):

      if loop_state is None:

        loop_state = ControlFlowState()

      if colocate_gradients_with_ops:

        with ops.colocate_with(op):

          loop_state.AddWhileContext(op, between_op_list, between_ops)

      else:

        loop_state.AddWhileContext(op, between_op_list, between_ops)

  return loop_state

def IsSwitch(op):

  """Return true if `op` is a Switch."""

  return op.type == "Switch" or op.type == "RefSwitch"

def IsLoopExit(op):

  """Return true if `op` is an Exit."""

  return op.type == "Exit" or op.type == "RefExit"

def IsLoopSwitch(op):

  """Return true if `op` is the Switch for a while loop."""

  if IsSwitch(op):

    ctxt = op._get_control_flow_context()

    return ctxt and isinstance(ctxt, WhileContext)

  return False

def ZerosLikeOutsideLoop(op, index):

  """Create zeros_like for the specified output of an op."""

  val = op.outputs[index]

  if not IsSwitch(op):

    return array_ops.zeros_like(val, optimize=False)

  else:

    op_ctxt = op._get_control_flow_context()

    if op_ctxt:

      # We are in a cond context. Use a switch to create zeros only when needed.

      pred = op_ctxt.pred

      branch = op_ctxt.branch

      switch_val = switch(op.inputs[0], pred)[1 - branch]

      zeros_shape = array_ops.shape_internal(switch_val, optimize=False)

      return array_ops.zeros(zeros_shape, dtype=val.dtype)

    else:

      return array_ops.zeros_like(val, optimize=False)

class ControlFlowContext(object):

  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final

  ExitResult.

  We maintain the following state for control flow contexts during graph

  construction:

   1. graph has _control_flow_context: the current context used to

      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()

   2. op has _control_flow_context: the context to which the op belongs.

      Set at the time the op is created. Immutable.

   3. A ControlFlowContext has _outer_context: the context in which this

      context is created. Set at the time a context is created. Immutable.

   4. A ControlFlowContext has _context_stack.

      Pushed and popped by ctxt.Enter() and ctxt.Exit()

  """

  def __init__(self, values_def=None, import_scope=None):

    self._outer_context = ops.get_default_graph()._get_control_flow_context()

    self._context_stack = []

    if values_def:

      self._init_values_from_proto(values_def,

                                   import_scope=import_scope)

    else:

      # Values that have been already seen in this context.

      self._values = set()

      # Values referenced by but external to this context.

      self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):

    """Initializes values and external_values from `ValuesDef` protocol buffer.

    Args:

      values_def: `ValuesDef` protocol buffer.

      import_scope: Optional `string`. Name scope to add.

    """

    assert isinstance(values_def, control_flow_pb2.ValuesDef)

    self._values = set(values_def.values)

    g = ops.get_default_graph()

    self._external_values = {}

    for k, v in values_def.external_values.items():

      self._external_values[k] = g.as_graph_element(

          ops.prepend_name_scope(v, import_scope))

    op_names = set([op.split(":")[0]

                    for op in self._values - set(self._external_values)])

    for op in op_names:

      # pylint: disable=protected-access

      g.as_graph_element(ops.prepend_name_scope(

          op, import_scope))._set_control_flow_context(self)

      # pylint: enable=protected-access

  @property

  def outer_context(self):

    """Return the context containing this context."""

    return self._outer_context

  @property

  def grad_state(self):

    raise NotImplementedError("Abstract method")

  @property

  def back_prop(self):

    raise NotImplementedError("Abstract method")

  def _to_proto(self, export_scope=None):

    """Converts the values to a `ValuesDef` protocol buffer.

    Args:

      export_scope: Optional `string`. Name scope to remove.

    Returns:

      A `ValuesDef` protocol buffer.

    """

    values_def = control_flow_pb2.ValuesDef()

    values_def.values.extend(

        [ops.strip_name_scope(v, export_scope)

         for v in sorted(self._values)])

    for k, v in self._external_values.items():

      values_def.external_values[k] = ops.strip_name_scope(

          v.name, export_scope)

    return values_def

  @staticmethod

  def _from_proto(values_def, import_scope=None):

    """Returns a `ControlFlowContext` created from `values_def`."""

    return ControlFlowContext(values_def=values_def,

                              import_scope=import_scope)

  def AddName(self, name):

    self._values.add(name)

  # pylint: disable=protected-access

  def Enter(self):

    """Enter this control flow context."""

    graph = ops.get_default_graph()

    self._context_stack.append(graph._get_control_flow_context())

    graph._set_control_flow_context(self)

  def Exit(self):

    """Exit this control flow context."""

    graph = ops.get_default_graph()

    last_context = self._context_stack.pop()

    graph._set_control_flow_context(last_context)

  def ExitResult(self, result):

    """Make a list of tensors available in the outer context."""

    if self._outer_context:

      nest.map_structure(lambda x: self._outer_context.AddName(x.name), result)

  def GetWhileContext(self):

    """Return the while context containing this context."""

    if self._outer_context:

      return self._outer_context.GetWhileContext()

    return None

  def _IsInOuterContext(self, op):

    op_ctxt = _GetOutputContext(op)

    outer_ctxt = self.outer_context

    while outer_ctxt != op_ctxt:

      if outer_ctxt is None:

        return False

      outer_ctxt = outer_ctxt.outer_context

    return True

  def _RemoveExternalControlEdges(self, op):

    """Remove any external control dependency on this op."""

    while_ctxt = self.GetWhileContext()

    # A control input of `op` is internal if it is in the same while

    # loop context as the enclosing while loop context of self.

    if while_ctxt is None:

      internal_control_inputs = op.control_inputs

    else:

      internal_control_inputs = []

      for x in op.control_inputs:

        ctxt = _GetOutputContext(x)

        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:

          internal_control_inputs.append(x)

    if len(internal_control_inputs) != len(op.control_inputs):

      del op.control_inputs[:]

      op._add_control_inputs(internal_control_inputs)

    return internal_control_inputs

  # pylint: enable=protected-access

  def AddInnerOp(self, op):

    """Notifies a scope about an operator added to an inner scope."""

    pass

  def GetControlPivot(self):

    """Returns the pivot node for this context, or None."""

    return None

class CondContext(ControlFlowContext):

  """The context for the conditional construct."""

  def __init__(self, pred=None, pivot=None, branch=None,

               name="cond_text", context_def=None, import_scope=None):

    """Creates a `CondContext`.

    Args:

      pred: The `boolean` tensor for the conditional predicate.

      pivot: The predicate tensor in this branch.

      branch: 0 or 1 representing this branch.

      name: Name of the `CondContext` python object.

      context_def: Optional `ContextDef` protocol buffer to initialize the

        `CondContext` object from.

      import_scope: Optional `string`. Name scope to add. Only used when

        initialing from protocol buffer.

    """

    self._name = ops.get_default_graph().unique_name(name)

    if context_def:

      self._init_from_proto(context_def, import_scope=import_scope)

    else:

      # Initializes the default fields.

      ControlFlowContext.__init__(self)

      self._pred = pred         # The boolean tensor for the cond predicate

      self._pivot = pivot       # The predicate tensor in this branch

      self._branch = branch     # 0 or 1 representing this branch

      # Values considered to have been already seen in this context.

      self._values.add(pred.name)

      self._values.add(pivot.name)

  def _init_from_proto(self, context_def, import_scope=None):

    """Creates a new `CondContext` from protocol buffer.

    Args:

      context_def: `CondContextDef` protocol buffer.

      import_scope: Optional `string`. Name scope to add.

    """

    assert isinstance(context_def, control_flow_pb2.CondContextDef)

    # Create from context_def.

    g = ops.get_default_graph()

    self._name = ops.prepend_name_scope(

        context_def.context_name, import_scope)

    self._pred = g.as_graph_element(ops.prepend_name_scope(

        context_def.pred_name, import_scope))

    self._pivot = g.as_graph_element(ops.prepend_name_scope(

        context_def.pivot_name, import_scope))

    self._branch = context_def.branch

    super(CondContext, self).__init__(values_def=context_def.values_def,

                                      import_scope=import_scope)

  @property

  def name(self):

    return self._name

  @property

  def pred(self):

    return self._pred

  @property

  def pivot(self):

    return self._pivot

  @property

  def branch(self):

    return self._branch

  @property

  def grad_state(self):

    if self.GetWhileContext():

      return self.GetWhileContext().grad_state

    return None

  @property

  def back_prop(self):

    if self.GetWhileContext():

      self.GetWhileContext().back_prop

    return False

  def GetControlPivot(self):

    return self._pivot

  def to_proto(self, export_scope=None):

    """Converts a `CondContext` to a `CondContextDef` protocol buffer.

    Args:

      export_scope: Optional `string`. Name scope to remove.

    Returns:

      A `CondContextDef` protocol buffer.

    """

    if (export_scope is None or

        self.name.startswith(export_scope)):

      context_def = control_flow_pb2.CondContextDef()

      context_def.context_name = ops.strip_name_scope(

          self.name, export_scope)

      context_def.pred_name = ops.strip_name_scope(

          self._pred.name, export_scope)

      context_def.pivot_name = ops.strip_name_scope(

          self._pivot.name, export_scope)

      context_def.branch = self._branch

      context_def.values_def.MergeFrom(super(CondContext, self)._to_proto(

          export_scope))

      return context_def

    else:

      return None

  @staticmethod

  def from_proto(context_def, import_scope=None):

    """Returns a `CondContext` object created from `context_def`."""

    return CondContext(context_def=context_def,

                       import_scope=import_scope)

  def AddValue(self, val):

    """Add `val` to the current context and its outer context recursively."""

    if val.name in self._values:

      # Use the real value if it comes from outer context. This is needed in

      # particular for nested conds.

      result = self._external_values.get(val.name)

      result = val if result is None else result

    else:

      result = val

      self._values.add(val.name)

      if self._outer_context:

        result = self._outer_context.AddValue(val)

        self._values.add(result.name)

      with ops.control_dependencies(None):

        result = _SwitchRefOrTensor(result, self._pred)[self._branch]

      result.op.graph.prevent_fetching(result.op)

      # pylint: disable=protected-access

      result.op._set_control_flow_context(self)

      # pylint: enable=protected-access

      self._values.add(result.name)

      self._external_values[val.name] = result

    return result

  def AddOp(self, op):

    self._AddOpInternal(op)

  def _AddOpInternal(self, op):

    """Add `op` to the current context."""

    if not op.inputs:

      # Remove any external control dependency on this op

      self._RemoveExternalControlEdges(op)

      # pylint: disable=protected-access

      op._add_control_input(self._pivot.op)

      # pylint: enable=protected-access

      for x in op.outputs:

        self._values.add(x.name)

    else:

      for index in range(len(op.inputs)):

        x = op.inputs[index]

        real_x = self.AddValue(x)

        if real_x != x:

          # pylint: disable=protected-access

          op._update_input(index, real_x)

          # pylint: enable=protected-access

      for x in op.outputs:

        self._values.add(x.name)

      # pylint: disable=protected-access

      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":

        op._add_control_input(self._pivot.op)

      # pylint: enable=protected-access

    if self._outer_context or not IsLoopExit(op):

      op.graph.prevent_fetching(op)

  def _ProcessOutputTensor(self, val):

    """Process an output tensor of a conditional branch."""

    real_val = val

    if val.name not in self._values:

      # Handle the special case of lambda: x

      self._values.add(val.name)

      if self._outer_context:

        real_val = self._outer_context.AddValue(val)

        self._values.add(real_val.name)

      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]

      self._external_values[val.name] = real_val

    else:

      external_val = self._external_values.get(val.name)

      if external_val is not None:

        real_val = external_val

    return real_val

  def _BuildCondTensor(self, v):

    if isinstance(v, ops.Operation):

      # Use pivot as the proxy for this op.

      return with_dependencies([v], self._pivot)

    elif isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

      values = self._ProcessOutputTensor(v.values)

      indices = self._ProcessOutputTensor(v.indices)

      if isinstance(v, ops.IndexedSlices):

        dense_shape = v.dense_shape

        if dense_shape is not None:

          dense_shape = self._ProcessOutputTensor(dense_shape)

        return ops.IndexedSlices(values, indices, dense_shape)

      else:

        dense_shape = self._ProcessOutputTensor(v.dense_shape)

        return sparse_tensor.SparseTensor(indices, values, dense_shape)

    else:

      v = nest.map_structure(_convert_tensorarray_to_flow, v)

      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):

    """Add the subgraph defined by fn() to the graph."""

    original_result = fn()

    if original_result is None:

      return None, None

    result = nest.map_structure(self._BuildCondTensor, original_result)

    if not isinstance(result, (list, _basetuple)):

      result = [result]

    return original_result, result

def _UnpackIfSingleton(res):

  if isinstance(res, (list, _basetuple)) and len(res) == 1:

    return res[0]

  else:

    return res

# pylint: disable=g-doc-args

@deprecation.deprecated_args(

    None,

    "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",

    "fn1", "fn2")

def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,

         fn1=None, fn2=None):

  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and

  `false_fn` must have the same non-zero number and type of outputs.

  Note that the conditional execution applies only to the operations defined in

  `true_fn` and `false_fn`. Consider the following simple program:

  ```python

  z = tf.multiply(a, b)

  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`

  operation will not be executed. Since `z` is needed for at least one

  branch of the `cond`, the `tf.multiply` operation is always executed,

  unconditionally.

  Although this behavior is consistent with the dataflow model of TensorFlow,

  it has occasionally surprised some users who expected a lazier semantics.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the

  call to `cond`, and not at all during `Session.run()`). `cond`

  stitches together the graph fragments created during the `true_fn` and

  `false_fn` calls with some additional graph nodes to ensure that the right

  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in

  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the

  same (possibly nested) value structure of lists, tuples, and/or named tuples.

  Singleton lists and tuples form the only exceptions to this: when returned by

  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.

  This behavior is disabled by passing `strict=True`.

  Args:

    pred: A scalar determining whether to return the result of `true_fn` or

      `false_fn`.

    true_fn: The callable to be performed if pred is true.

    false_fn: The callable to be performed if pred is false.

    strict: A boolean that enables/disables 'strict' mode; see above.

    name: Optional name prefix for the returned tensors.

  Returns:

    Tensors returned by the call to either `true_fn` or `false_fn`. If the

    callables return a singleton list, the element is extracted from the list.

  Raises:

    TypeError: if `true_fn` or `false_fn` is not callable.

    ValueError: if `true_fn` and `false_fn` do not return the same number of

      tensors, or return tensors of different types.

  Example:

  ```python

    x = tf.constant(2)

    y = tf.constant(5)

    def f1(): return tf.multiply(x, 17)

    def f2(): return tf.add(y, 23)

    r = tf.cond(tf.less(x, y), f1, f2)

    # r is set to f1().

    # Operations in f2 (e.g., tf.add) are not executed.

  ```

  """

  # We needed to make true_fn/false_fn keyword arguments for

  # backwards-compatibility. This check exists so that we can convert back to

  # having them be positional arguments.

  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after

  # `fn1` and `fn2` are deleted.

  if fn1 is not None:

    if true_fn is not None:

      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")

    true_fn = fn1

  elif true_fn is None:

    raise TypeError("cond(): true_fn argument required")

  if fn2 is not None:

    if false_fn is not None:

      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")

    false_fn = fn2

  elif false_fn is None:

    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):

    raise TypeError("true_fn must be callable.")

  if not callable(false_fn):

    raise TypeError("false_fn must be callable.")

  with ops.name_scope(name, "cond", [pred]) as name:

    # Add the Switch to the graph.

    if isinstance(pred, bool):

      raise TypeError("pred must not be a Python bool")

    p_2, p_1 = switch(pred, pred)

    pivot_1 = array_ops.identity(p_1, name="switch_t")

    pivot_2 = array_ops.identity(p_2, name="switch_f")

    pred = array_ops.identity(pred, name="pred_id")

    # Disable the fetching of tensors that are only on one branch of cond.

    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:

      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.

    context_t = CondContext(pred, pivot_1, branch=1)

    context_t.Enter()

    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)

    if orig_res_t is None:

      raise ValueError("true_fn must have a return value.")

    context_t.ExitResult(res_t)

    context_t.Exit()

    # Build the graph for the false branch in a new context.

    context_f = CondContext(pred, pivot_2, branch=0)

    context_f.Enter()

    orig_res_f, res_f = context_f.BuildCondBranch(false_fn)

    if orig_res_f is None:

      raise ValueError("false_fn must have a return value.")

    context_f.ExitResult(res_f)

    context_f.Exit()

    if not strict:

      orig_res_t = _UnpackIfSingleton(orig_res_t)

      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.

    try:

      nest.assert_same_structure(orig_res_t, orig_res_f)

    except TypeError as e:

      raise TypeError(

          "Incompatible return types of true_fn and false_fn: {}".format(e))

    except ValueError as e:

      raise ValueError(

          "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.

    if not res_t:

      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t)

    res_f_flat = nest.flatten(res_f)

    for x, y in zip(res_t_flat, res_f_flat):

      assert ((isinstance(x, ops.IndexedSlices) and

               isinstance(y, ops.IndexedSlices)) or

              (isinstance(x, sparse_tensor.SparseTensor) and

               isinstance(y, sparse_tensor.SparseTensor)) or

              (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))

      val_x = x if isinstance(x, ops.Tensor) else x.values

      val_y = y if isinstance(y, ops.Tensor) else y.values

      if val_x.dtype.base_dtype != val_y.dtype.base_dtype:

        raise ValueError(

            "Outputs of true_fn and false_fn must have the same type: %s, %s" %

            (val_x.dtype.name, val_y.dtype.name))

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]

    merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)

    # Add to collections

    ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)

    ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)

    # Singleton lists and tuples are automatically unpacked if strict == False.

    if not strict:

      merges = _UnpackIfSingleton(merges)

    return merges

# pylint: enable=g-doc-args

def _resource_safe_shape(t):

  """Returns the shape of t or the variable it points to."""

  if t.dtype == dtypes.resource:

    while t.op.inputs:

      t = t.op.inputs[0]

    return tensor_shape.TensorShape(t.op.get_attr("shape"))

  return array_ops.shape_internal(t, optimize=False)

# TODO(yuanbyu): Consider having a unified notion of context for

# not only conditionals and loops but also control dependency and

# subgraphs.

class WhileContext(ControlFlowContext):

  """The context for the loop construct."""

  def __init__(self, parallel_iterations=10, back_prop=True, swap_memory=False,

               name="while_context", grad_state=None, context_def=None,

               import_scope=None):

    """"Creates a `WhileContext`.

    Args:

      parallel_iterations: The number of iterations allowed to run in parallel.

      back_prop: Whether backprop is enabled for this while loop.

      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.

      name: Optional name prefix for the returned tensors.

      grad_state: The gradient loop state.

      context_def: Optional `WhileContextDef` protocol buffer to initialize

        the `Whilecontext` python object from.

      import_scope: Optional `string`. Name scope to add. Only used when

        initialing from protocol buffer.

    """

    if context_def:

      self._init_from_proto(context_def, import_scope=import_scope)

    else:

      ControlFlowContext.__init__(self)

      self._init_from_args(parallel_iterations, back_prop, swap_memory,

                           name)

    # The gradient loop state.

    self._grad_state = grad_state

  def _init_from_args(self, parallel_iterations, back_prop, swap_memory,

                      name):

    """Creates a new `WhileContext` from arguments.

    Args:

      parallel_iterations: The number of iterations allowed to run in parallel.

      back_prop: Whether backprop is enabled for this while loop.

      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.

      name: Optional name prefix for the returned tensors.

    Raises:

      ValueError: If `parallel_iterations` has invalid value.

    """

    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):

      raise ValueError("`parallel_iterations` must be a positive integer: "

                       "%s" % parallel_iterations)

    self._name = ops.get_default_graph().unique_name(name)

    self._parallel_iterations = parallel_iterations

    self._back_prop = back_prop

    self._swap_memory = swap_memory

    # We use this node to control constants created by the pred lambda.

    self._pivot_for_pred = None

    # We use this node to control constants created by the body lambda.

    self._pivot_for_body = None

    # The boolean tensor for loop termination condition. Used in code

    # generation for gradient computation

    self._pivot = None

    # The list of exit tensors for loop variables.

    self._loop_exits = []

    # The list of enter tensors for loop variables.

    self._loop_enters = []

  def _init_from_proto(self, context_def, import_scope=None):

    """Creates a new `WhileContext` from protocol buffer.

    Args:

      context_def: `WhileContextDef` protocol buffer.

      import_scope: Optional `string`. Name scope to add.

    """

    assert isinstance(context_def, control_flow_pb2.WhileContextDef)

    # Create from context_def.

    g = ops.get_default_graph()

    self._name = ops.prepend_name_scope(

        context_def.context_name, import_scope)

    self._parallel_iterations = context_def.parallel_iterations

    self._back_prop = context_def.back_prop

    self._swap_memory = context_def.swap_memory

    self._pivot_for_pred = g.as_graph_element(ops.prepend_name_scope(

        context_def.pivot_for_pred_name, import_scope))

    # We use this node to control constants created by the body lambda.

    self._pivot_for_body = g.as_graph_element(ops.prepend_name_scope(

        context_def.pivot_for_body_name, import_scope))

    # The boolean tensor for loop termination condition. Used in code

    # generation for gradient computation.

    self._pivot = g.as_graph_element(

        ops.prepend_name_scope(context_def.pivot_name, import_scope))

    # The list of exit tensors for loop variables.

    self._loop_exits = [g.as_graph_element(

        ops.prepend_name_scope(exit_name, import_scope))

                        for exit_name in context_def.loop_exit_names]

    # The list of enter tensors for loop variables.

    self._loop_enters = [g.as_graph_element(

        ops.prepend_name_scope(enter_name, import_scope))

                         for enter_name in context_def.loop_enter_names]

    super(WhileContext, self).__init__(values_def=context_def.values_def,

                                       import_scope=import_scope)

  @property

  def name(self):

    return self._name

  @property

  def parallel_iterations(self):

    """The number of iterations allowed to run in parallel."""

    return self._parallel_iterations

  @property

  def back_prop(self):

    """True iff backprop is enabled for this while loop."""

    return self._back_prop

  @property

  def swap_memory(self):

    """True iff GPU-CPU memory swap is enabled for this while loop."""

    return self._swap_memory

  @property

  def pivot(self):

    """The boolean tensor representing the loop termination condition."""

    return self._pivot

  @property

  def loop_enters(self):

    """The list of enter tensors for loop variables."""

    return self._loop_enters

  @property

  def loop_exits(self):

    """The list of exit tensors for loop variables."""

    return self._loop_exits

  @property

  def grad_state(self):

    """The gradient loop state."""

    return self._grad_state

  def to_proto(self, export_scope=None):

    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.

    Args:

      export_scope: Optional `string`. Name scope to remove.

    Returns:

      A `WhileContextDef` protocol buffer.

    """

    if (export_scope is None or

        self.name.startswith(export_scope)):

      context_def = control_flow_pb2.WhileContextDef()

      context_def.context_name = ops.strip_name_scope(

          self.name, export_scope)

      context_def.parallel_iterations = self._parallel_iterations

      context_def.back_prop = self._back_prop

      context_def.swap_memory = self._swap_memory

      context_def.pivot_for_pred_name = ops.strip_name_scope(

          self._pivot_for_pred.name, export_scope)

      context_def.pivot_for_body_name = ops.strip_name_scope(

          self._pivot_for_body.name, export_scope)

      context_def.pivot_name = ops.strip_name_scope(

          self._pivot.name, export_scope)

      context_def.loop_exit_names.extend(

          [ops.strip_name_scope(l.name, export_scope)

           for l in self._loop_exits])

      context_def.loop_enter_names.extend(

          [ops.strip_name_scope(l.name, export_scope)

           for l in self._loop_enters])

      context_def.values_def.MergeFrom(

          super(WhileContext, self)._to_proto(

              export_scope=export_scope))

      return context_def

    else:

      return None

  @staticmethod

  def from_proto(context_def, import_scope=None):

    """Returns a `WhileContext` object created from `context_def`.

    Args:

      context_def: A `WhileContextDef` protocol buffer.

      import_scope: Optional `string`. Name scope to add.

    Returns:

      A `WhileContext` Python object.

    """

    return WhileContext(context_def=context_def,

                        import_scope=import_scope)

  def GetWhileContext(self):

    return self

  def GetControlPivot(self):

    if self._pivot_for_body is not None:

      return self._pivot_for_body

    return self._pivot_for_pred

  def AddValue(self, val):

    """Add `val` to the current context and its outer context recursively."""

    result = val

    if val.name not in self._values:

      self._values.add(val.name)

      # If we are in a grad context and val is from its forward context,

      # use GetRealValue(), which adds the logic to save the history of

      # val in forward.

      grad_ctxt = ops.get_default_graph()._get_control_flow_context()

      if grad_ctxt:

        grad_ctxt = grad_ctxt.GetWhileContext()

        if grad_ctxt.grad_state:

          forward_ctxt = _GetWhileContext(val.op)

          if IsLoopExit(val.op):

            forward_ctxt = forward_ctxt.outer_context

            if forward_ctxt:

              forward_ctxt = forward_ctxt.GetWhileContext()

          if forward_ctxt == grad_ctxt.grad_state.forward_context:

            real_val = grad_ctxt.grad_state.GetRealValue(val)

            self._external_values[val.name] = real_val

            return real_val

      if self._outer_context is not None:

        result = self._outer_context.AddValue(val)

      # Create an Enter to make `result` known to this loop context.

      with ops.control_dependencies(None):

        enter = _Enter(result, self._name, is_constant=True,

                       parallel_iterations=self._parallel_iterations)

        enter.graph.prevent_feeding(enter)

        if self._outer_context:

          self._outer_context.AddInnerOp(enter.op)

      # Fix the control inputs and control flow context of these enter ops.

      self._FixControlInputsAndContext([enter])

      # Add `enter` in this context.

      self._values.add(enter.name)

      self._external_values[val.name] = enter

      result = enter

    else:

      actual_val = self._external_values.get(val.name)

      if actual_val is not None:

        result = actual_val

    return result

  def AddOp(self, op):

    """Add `op` to the current context."""

    # For a reduction op, if op is in a grad context and its input is from

    # its forward context, moving op to the forward context means we would

    # store the tensor after the reduction as opposed to the tensor before

    # reduction, and therefore could significantly reduce memory consumption.

    # For now, we do this only for a few ops.

    if op.type in {"Shape", "Size", "Rank"}:

      grad_ctxt = ops.get_default_graph()._get_control_flow_context()

      if grad_ctxt:

        grad_ctxt = grad_ctxt.GetWhileContext()

        if grad_ctxt.grad_state:

          op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op)

          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:

            op_input_ctxt = op.inputs[0].op._get_control_flow_context()

            op._set_control_flow_context(op_input_ctxt)

            op_input_ctxt._AddOpInternal(op)

            return

    self._AddOpInternal(op)

  def _AddOpInternal(self, op):

    """Add `op` to the current context.

    In the case that op has only external data inputs, we remove all of its

    external control inputs so all its inputs are in the same while loop

    context. This is valid because op now has an Enter input that has all

    the right control dependency.

    """

    if not op.inputs:

      # Remove any external control dependency on this op

      control_inputs = self._RemoveExternalControlEdges(op)

      # Add a control edge from the control pivot to this op.

      if not control_inputs:

        # pylint: disable=protected-access

        op._add_control_input(self.GetControlPivot().op)

        # pylint: enable=protected-access

      for x in op.outputs:

        self._values.add(x.name)

    else:

      for index in range(len(op.inputs)):

        x = op.inputs[index]

        real_x = self.AddValue(x)

        if real_x != x:

          op._update_input(index, real_x)

      # Remove any external control dependency on this op.

      self._RemoveExternalControlEdges(op)

      # Add a control dependency to prevent loop invariants from

      # enabling ops that should not be executed.

      self._MaybeAddControlDependency(op)

      for x in op.outputs:

        self._values.add(x.name)

    if self._outer_context or not IsLoopExit(op):

      op.graph.prevent_fetching(op)

      for x in op.outputs:

        op.graph.prevent_feeding(x)

    if self._outer_context:

      self._outer_context.AddInnerOp(op)

  def _MaybeAddControlDependency(self, op):

    """Add a control input to the op if it only depends on loop invariants."""

    def _IsOpFree(op):

      """Determines if `op` needs a control dependency."""

      if op.control_inputs:

        return False

      # pylint: disable=protected-access

      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":

        return True

      # pylint: enable=protected-access

      for x in op.inputs:

        if not _IsLoopConstantEnter(x.op):

          return False

      return True

    if _IsOpFree(op):

      # pylint: disable=protected-access

      op._add_control_input(self.GetControlPivot().op)

      # pylint: enable=protected-access

  def AddForwardLoopCounter(self, outer_grad_state):

    """Adds a loop that counts the number of iterations.

    This is added to the forward loop at the time when we start to

    create the loop for backprop gradient computation. Called in

    the outer context of this forward context.

    The pseudocode is:

      `n = 0; while (_pivot) { n++; }`

    Note that a control dependency is added to `n` to ensure the correct

    execution order of stack push ops.

    Args:

      outer_grad_state: The outer grad state. None if not nested.

    Returns:

      The number of iterations taken by the forward loop and the loop index.

    """

    n = constant_op.constant(0, name="f_count")

    if outer_grad_state is not None:

      # Force the stack pushes of i-th execution of an inner loop to be ordered

      # before the pushes of (i+1)-th execution of the same inner loop.

      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op

      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access

    self.Enter()

    self.AddName(n.name)

    enter_n = _Enter(n, self._name, is_constant=False,

                     parallel_iterations=self._parallel_iterations,

                     name="f_count")

    self.loop_enters.append(enter_n)

    merge_n = merge([enter_n, enter_n])[0]

    switch_n = switch(merge_n, self._pivot)

    index = math_ops.add(switch_n[1], 1)

    next_n = _NextIteration(index)

    merge_n.op._update_input(1, next_n)

    total_iterations = exit(switch_n[0], name="f_count")

    self.loop_exits.append(total_iterations)

    self.ExitResult([total_iterations])

    self.Exit()

    return total_iterations, next_n

  def AddBackPropLoopCounter(self, count, outer_grad_state):

    """Add the backprop loop that controls the iterations.

    This is added to the backprop loop. It is used to control the loop

    termination of the backprop loop. Called in the outer context of

    this grad context.

    The pseudocode is:

      `n = count; while (n >= 1) { n--; }`

    Note that a control dependency is added to `final_zero` to ensure the

    correct execution order of stack pop ops.

    Args:

      count: The number of iterations for backprop.

      outer_grad_state: The outer grad state. None if not nested.

    Returns:

      The loop index.

    """

    one = constant_op.constant(1, name="b_count")

    self.Enter()

    self.AddName(count.name)

    enter_count = _Enter(count, self._name, is_constant=False,

                         parallel_iterations=self._parallel_iterations,

                         name="b_count")

    self.loop_enters.append(enter_count)

    merge_count = merge([enter_count, enter_count])[0]

    self._pivot_for_pred = merge_count

    pred = math_ops.greater_equal(merge_count, one)

    self._pivot = loop_cond(pred, name="b_count")

    switch_count = switch(merge_count, self._pivot)

    index = math_ops.subtract(switch_count[1], one)

    self._pivot_for_body = index

    next_count = _NextIteration(index)

    merge_count.op._update_input(1, next_count)

    final_zero = exit(switch_count[0], name="b_count")

    self.loop_exits.append(final_zero)

    if outer_grad_state is not None:

      # Force the stack pops of i-th execution of an inner loop to be ordered

      # before the pops of (i+1)-th execution of the same inner loop.

      # pylint: disable=protected-access

      outer_grad_state.grad_sync._add_control_input(final_zero.op)

      # pylint: enable=protected-access

    self.ExitResult([final_zero])

    self.Exit()

    return next_count

  def AddBackPropAccumulator(self, op, grad):

    """Add an accumulation loop for every loop invariant.

    This is added to the backprop loop. It is used to accumulate partial

    gradients within each loop iteration. Called when in the gradient while

    context.

    The pseudocode is:

      ```

      acc = 0.0;

      while (_pivot) {

        acc += grad;

      }

      ```

    Args:

      op: The Enter op for a loop invariant.

      grad: The partial gradient of an iteration for a loop invariant.

    Returns:

      The gradient for a loop invariant.

    """

    self.Exit()

    # Create a zeros tensor with the right shape for acc. If we don't

    # know the full shape statically, we will have to get the shape

    # dynamically from the forward inference. Getting the shape right

    # for the zeros is only needed for the base case when the loop exits

    # without running any iterations.

    shape = grad.get_shape()

    if shape.is_fully_defined():

      if self.outer_context: self.outer_context.Enter()

      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")

      if self.outer_context: self.outer_context.Exit()

    else:

      value = op.inputs[0]

      if (isinstance(self.outer_context, WhileContext) and

          self.outer_context.grad_state is not None):

        # We are in a nested while loop.

        forward_ctxt = self.grad_state.forward_context

        forward_ctxt.outer_context.Enter()

        zeros_shape = array_ops.shape_internal(value, optimize=False)

        forward_ctxt.outer_context.Exit()

        outer_grad_state = self.grad_state.outer_grad_state

        history_zeros_shape = outer_grad_state.AddForwardAccumulator(

            zeros_shape)

        self.outer_context.Enter()

        real_shape = outer_grad_state.AddBackPropAccumulatedValue(

            history_zeros_shape, zeros_shape)

        acc = array_ops.zeros(real_shape, grad.dtype)

        self.outer_context.Exit()

      else:

        if self.outer_context: self.outer_context.Enter()

        zeros_shape = array_ops.shape_internal(value, optimize=False)

        acc = array_ops.zeros(zeros_shape, grad.dtype)

        if self.outer_context: self.outer_context.Exit()

      acc._shape = grad.get_shape()  # pylint: disable=protected-access

    self.Enter()

    self.AddName(acc.name)

    enter_acc = _Enter(acc, self._name, is_constant=False,

                       parallel_iterations=self._parallel_iterations,

                       name="b_acc")

    self.loop_enters.append(enter_acc)

    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]

    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)

    add_acc = math_ops.add(switch_acc_true, grad)

    next_acc = _NextIteration(add_acc)

    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access

    result_acc = exit(switch_acc_false, name="b_acc")

    self.loop_exits.append(result_acc)

    self.ExitResult([result_acc])

    return result_acc

  def AddBackPropIndexedSlicesAccumulator(self, op, grad):

    """This is used for accumulating gradients that are IndexedSlices.

    This is essentially the equavalent of AddBackPropAccumulator but optimized

    for things like updating embeddings from within a while loop.

    Args:

      op: The Enter op for a loop invariant.

      grad: The partial gradients represented as an IndexedSlices.

    Returns:

      The accumulated IndexedSlices gradient of the loop invariant.

    """

    values = grad.values

    indices = grad.indices

    dense_shape = grad.dense_shape

    self.Exit()

    if self.outer_context: self.outer_context.Enter()

    if values.get_shape().is_fully_defined():

      values_shape = tensor_shape.TensorShape(

          [tensor_shape.Dimension(1)] + values.get_shape().dims[1:])

      if self.outer_context: self.outer_context.Enter()

      values_acc = constant_op.constant(0, values.dtype, shape=values_shape,

                                        name="b_acc")

      if self.outer_context: self.outer_context.Exit()

    else:

      values_shape = _resource_safe_shape(op.inputs[0])[1:]

      values_shape = array_ops.concat([[1], values_shape], 0)

      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)

    indices_acc = constant_op.constant([0], indices.dtype)

    shape_acc = None

    if dense_shape is not None:

      if dense_shape.get_shape().is_fully_defined():

        if self.outer_context: self.outer_context.Enter()

        shape_acc = constant_op.constant(0, dense_shape.dtype,

                                         shape=dense_shape.get_shape())

        if self.outer_context: self.outer_context.Exit()

      else:

        shape_acc = array_ops.zeros_like(

            array_ops.shape_internal(op.inputs[0], optimize=False),

            optimize=False)

    if self.outer_context: self.outer_context.Exit()

    self.Enter()

    self.AddName(values_acc.name)

    self.AddName(indices_acc.name)

    init_acc = [indices_acc, values_acc]

    if shape_acc is not None:

      self.AddName(shape_acc.name)

      init_acc.append(shape_acc)

    enter_acc = [_Enter(x, self._name, is_constant=False,

                        parallel_iterations=self._parallel_iterations,

                        name="b_acc") for x in init_acc]

    self.loop_enters.extend(enter_acc)

    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]

    switch_acc = [switch(x, self._pivot) for x in merge_acc]

    # The actual accumulation.

    acc_indexed_slices = [

        array_ops.concat([xa[1], xv], 0)

        for xa, xv in zip(switch_acc[:2], [indices, values])

    ]

    if shape_acc is not None:

      # For the shape we just keep the maximum

      acc_indexed_slices.append(

          math_ops.maximum(dense_shape, switch_acc[2][1]))

    next_acc = [_NextIteration(x) for x in acc_indexed_slices]

    for xm, xn in zip(merge_acc, next_acc):

      xm.op._update_input(1, xn)  # pylint: disable=protected-access

    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]

    self.loop_exits.extend(exit_acc)

    self.ExitResult(exit_acc)

    return ops.IndexedSlices(

        indices=exit_acc[0], values=exit_acc[1],

        dense_shape=exit_acc[2] if shape_acc is not None else None)

  def _InitializeValues(self, values):

    """Makes the values known to this context."""

    self._values = set()

    for x in values:

      if isinstance(x, ops.Tensor):

        self._values.add(x.name)

      else:

        self._values.add(x.values.name)

        self._values.add(x.indices.name)

        if isinstance(x, ops.IndexedSlices):

          dense_shape = x.dense_shape

        elif isinstance(x, sparse_tensor.SparseTensor):

          dense_shape = x.dense_shape

        else:

          raise TypeError("Type %s not supported" % type(x))

        if dense_shape is not None:

          self._values.add(dense_shape.name)

  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,

                 shape_invariants):

    """Core: Add the loop termination condition and body to the graph."""

    flat_loop_vars = nest.flatten(original_loop_vars)

    # Let the context know the loop variables so the loop variables

    # would be added in the outer contexts properly.

    self._InitializeValues(loop_vars)

    real_vars = loop_vars

    if self._outer_context:

      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]

    with ops.control_dependencies(None):

      enter_vars = [_Enter(x, self._name, is_constant=False,

                           parallel_iterations=self._parallel_iterations,

                           use_input_shape=(shape_invariants is None))

                    for x in real_vars]

      for x in enter_vars:

        x.graph.prevent_feeding(x)

        if self._outer_context:

          self._outer_context.AddInnerOp(x.op)

    # Finds the closest enclosing non-None control pivot.

    outer_context = self._outer_context

    control_pivot = None

    while outer_context is not None and control_pivot is None:

      control_pivot = outer_context.GetControlPivot()

      # pylint: disable=protected-access

      outer_context = outer_context._outer_context

      # pylint: enable=protected-access

    if control_pivot is not None:

      for var in enter_vars:

        if _IsLoopConstantEnter(var.op.inputs[0].op):

          # pylint: disable=protected-access

          var.op._add_control_input(control_pivot.op)

          # pylint: enable=protected-access

    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)

    # Fix the control inputs and control flow context of these enter ops.

    self._FixControlInputsAndContext(enter_vars)

    self._InitializeValues(enter_vars)

    self._loop_enters = enter_vars

    merge_vars = [merge([x, x])[0] for x in enter_vars]

    self._pivot_for_pred = merge_vars[0]

    # Build the graph for pred.

    merge_vars_with_tensor_arrays = (

        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))

    packed_vars = nest.pack_sequence_as(

        structure=original_loop_vars,

        flat_sequence=merge_vars_with_tensor_arrays)

    c = ops.convert_to_tensor(pred(*packed_vars))

    self._pivot = loop_cond(c, name="LoopCond")

    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]

    # Build the graph for body.

    vars_for_body = [_Identity(x[1]) for x in switch_vars]

    self._pivot_for_body = vars_for_body[0]

    # Convert TensorArray flow variables inside the context back into

    # their associated TensorArrays for calling the body.

    vars_for_body_with_tensor_arrays = (

        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))

    packed_vars_for_body = nest.pack_sequence_as(

        structure=original_loop_vars,

        flat_sequence=vars_for_body_with_tensor_arrays)

    body_result = body(*packed_vars_for_body)

    if not nest.is_sequence(body_result):

      body_result = [body_result]

    # Compare the structure types of input and output of body.

    # For backwards compatibility, the first layer is forced to a list

    # during this comparison, because inputs are typically lists and

    # outputs of the body are typically tuples.

    nest.assert_same_structure(list(packed_vars_for_body), list(body_result))

    # Store body_result to keep track of TensorArrays returned by body

    original_body_result = body_result

    # Convert TensorArrays returned by body into their flow variables

    result = nest.map_structure(_convert_tensorarray_to_flow,

                                nest.flatten(body_result))

    result = ops.convert_n_to_tensor_or_indexed_slices(result)

    # Add NextIteration and the back edges to complete the loop.

    if len(merge_vars) != len(result):

      raise ValueError("Number of inputs and outputs of body must match "

                       "loop_vars: %d, %d" % (len(merge_vars), len(result)))

    next_vars = []

    for m, v in zip(merge_vars, result):

      next_vars.append(_AddNextAndBackEdge(m, v))

    # Add the exit ops.

    exit_vars = [exit(x[0]) for x in switch_vars]

    self._loop_exits = exit_vars

    # Make sure the shapes of loop outputs are correct.

    for m_var, n_var in zip(merge_vars, next_vars):

      if isinstance(m_var, ops.Tensor):

        _EnforceShapeInvariant(m_var, n_var)

    # Exit the loop.

    self.ExitResult(exit_vars)

    return original_body_result, exit_vars

  def BuildLoop(self, pred, body, loop_vars, shape_invariants):

    """Add the loop termination condition and body to the graph."""

    # Keep original_loop_vars to identify which are TensorArrays

    original_loop_vars = loop_vars

    # Convert TensorArrays to their flow variables

    loop_vars = nest.map_structure(_convert_tensorarray_to_flow,

                                   nest.flatten(loop_vars))

    loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)

    try:

      self.Enter()

      original_body_result, exit_vars = self._BuildLoop(

          pred, body, original_loop_vars, loop_vars, shape_invariants)

    finally:

      self.Exit()

    flat_result = nest.flatten(original_body_result)

    # Convert TensorArray flow variables outside the context back into

    # their associated TensorArrays for returning to caller.

    exit_vars_with_tensor_arrays = (

        _convert_flows_to_tensorarrays(flat_result, exit_vars))

    packed_exit_vars = nest.pack_sequence_as(

        structure=original_body_result,

        flat_sequence=exit_vars_with_tensor_arrays)

    return (packed_exit_vars[0] if len(exit_vars) == 1

            else packed_exit_vars)

  def _FixControlInputsAndContext(self, enters):

    graph = ops.get_default_graph()

    # pylint: disable=protected-access

    for e in enters:

      if isinstance(e, ops.Tensor):

        xs = [e]

      else:

        if not isinstance(e, (ops.IndexedSlices, sparse_tensor.SparseTensor)):

          raise TypeError("Type %s not supported" % type(e))

        xs = [e.values, e.indices]

        shape = e.dense_shape

        if shape is not None:

          xs.append(shape)

      for x in xs:

        inp_op = x.op.inputs[0]

        control_inputs = graph._control_dependencies_for_inputs([inp_op])

        outer_control_inputs = [op for op in control_inputs

                                if self._IsInOuterContext(op)]

        x.op._set_control_flow_context(self)

        x.op._add_control_inputs(outer_control_inputs)

        graph._record_op_seen_by_control_dependencies(x.op)

    # pylint: enable=protected-access

def while_loop(cond, body, loop_vars, shape_invariants=None,

               parallel_iterations=10, back_prop=True, swap_memory=False,

               name=None):

  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable

  returning a (possibly nested) tuple, namedtuple or list of tensors of the same

  arity (length and structure) and types as `loop_vars`. `loop_vars` is a

  (possibly nested) tuple, namedtuple or list of tensors that is passed to both

  `cond` and `body`. `cond` and `body` both take as many arguments as there are

  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and

  return TensorArray objects.  The flows of the TensorArray objects will

  be appropriately forwarded between loops and during gradient calculations.

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the

  call to `while_loop`, and not at all during `Session.run()`). `while_loop`

  stitches together the graph fragments created during the `cond` and `body`

  calls with some additional graph nodes to make something the repeats

  `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for

  the loop variables. A shape invariant is a (possibly partial) shape that

  is unchanged across the iterations of the loop. An error will be raised

  if the shape of a loop variable after an iteration is determined to be more

  general than or incompatible with its shape invariant. For example, a shape

  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not

  compatible with [11, 17]. By default (if the argument `shape_invariants` is

  not specified), it is assumed that the initial shape of each tensor in

  `loop_vars` is the same in every iteration. The `shape_invariants` argument

  allows the caller to specify a less specific shape invariant for each loop

  variable, which is needed if the shape varies between iterations. The

  @{tf.Tensor.set_shape}

  function may also be used in the `body` function to indicate that

  the output loop variable has a particular shape. The shape invariant for

  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be

  TensorShape([r]) where r is the rank of the dense tensor represented

  by the sparse tensor. It means the shapes of the three tensors of the

  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here

  is the shape of the SparseTensor.dense_shape property. It must be the shape of

  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be

  a shape invariant of the values tensor of the IndexedSlices. It means

  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],

  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations

  to run in parallel. The maximum number of parallel iterations can be

  controlled by `parallel_iterations`, which gives users some control over

  memory consumption and execution order. For correct programs, `while_loop`

  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow remembers the tensors that are produced in the

  forward inference but needed in back propagation. These tensors can be a

  main source of memory consumption and often cause OOM problems when training

  on GPUs.  When the flag swap_memory is true, we swap out these tensors from

  GPU to CPU.  This for example allows us to train RNN models with very long

  sequences and large batches.

  Args:

    cond: A callable that represents the termination condition of the loop.

    body: A callable that represents the loop body.

    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,

      `Tensor`, and `TensorArray` objects.

    shape_invariants: The shape invariants for the loop variables.

    parallel_iterations: The number of iterations allowed to run in parallel.

      It must be a positive integer.

    back_prop: Whether backprop is enabled for this while loop.

    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.

    name: Optional name prefix for the returned tensors.

  Returns:

    The output tensors for the loop variables after the loop. When the length

    of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when

    the length of `loop_vars` is greater than 1 it returns a list.

  Raises:

    TypeError: if `cond` or `body` is not callable.

    ValueError: if `loop_vars` is empty.

  Example:

  ```python

  i = tf.constant(0)

  c = lambda i: tf.less(i, 10)

  b = lambda i: tf.add(i, 1)

  r = tf.while_loop(c, b, [i])

  ```

  Example with nesting and a namedtuple:

  ```python

  import collections

  Pair = collections.namedtuple('Pair', 'j, k')

  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))

  c = lambda i, p: i < 10

  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))

  ijk_final = tf.while_loop(c, b, ijk_0)

  ```

  Example using shape_invariants:

  ```python

  i0 = tf.constant(0)

  m0 = tf.ones([2, 2])

  c = lambda i, m: i < 10

  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]

  tf.while_loop(

      c, b, loop_vars=[i0, m0],

      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])

  ```

  """

  with ops.name_scope(name, "while", loop_vars) as name:

    if not loop_vars:

      raise ValueError("No loop variables provided")

    if not callable(cond):

      raise TypeError("cond must be callable.")

    if not callable(body):

      raise TypeError("body must be callable.")

    if parallel_iterations < 1:

      raise TypeError("parallel_iterations must be a positive integer.")

    if shape_invariants is not None:

      nest.assert_same_structure(loop_vars, shape_invariants)

    context = WhileContext(parallel_iterations, back_prop, swap_memory, name)

    ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)

    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)

    return result

def _AsTensorList(x, p):

  """Return x as a list of Tensors or IndexedSlices.

  For entries of `x` that are Operations, this returns an Identity of `p`

  with a dependency on the operation.

  Args:

    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.

    p: A Tensor to return for entries in `x` that are Operations.

  Returns:

    A list of Tensors or IndexedSlices.

  """

  if not isinstance(x, (list, _basetuple)):

    x = [x]

  l = []

  for v in x:

    if isinstance(v, ops.Operation):

      v = with_dependencies([v], p)

    v = ops.convert_to_tensor_or_indexed_slices(v)

    if isinstance(v, ops.Tensor):

      l.append(array_ops.identity(v))

    else:

      l.append(ops.IndexedSlices(array_ops.identity(v.values),

                                 array_ops.identity(v.indices)))

  return l

def _CheckResults(a, b):

  assert len(a) == len(b), (

      "Values returned by a() and b() must have the same length.")

  for x, y in zip(a, b):

    assert x.dtype == y.dtype, (

        "Values returned by a() [%s] and b() [%s] must have "

        "the same type: %s, %s." %

        (x.name, y.name, x.dtype.name, y.dtype.name))

def with_dependencies(dependencies, output_tensor, name=None):

  """Produces the content of `output_tensor` only after `dependencies`.

  In some cases, a user may want the output of an operation to be

  consumed externally only after some other dependencies have run

  first. This function ensures returns `output_tensor`, but only after all

  operations in `dependencies` have run. Note that this means that there is

  no guarantee that `output_tensor` will be evaluated after any `dependencies`

  have run.

  See also @{tf.tuple$tuple} and @{tf.group$group}.

  Args:

    dependencies: Iterable of operations to run before this op finishes.

    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.

    name: (Optional) A name for this operation.

  Returns:

    Same as `output_tensor`.

  Raises:

    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.

  """

  with ops.name_scope(name, "control_dependency",

                      list(dependencies) + [output_tensor]) as name:

    with ops.colocate_with(output_tensor):

      with ops.control_dependencies(dependencies):

        output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)

        if isinstance(output_tensor, ops.Tensor):

          return _Identity(output_tensor, name=name)

        else:

          return ops.IndexedSlices(_Identity(output_tensor.values, name=name),

                                   output_tensor.indices,

                                   output_tensor.dense_shape)

def _GroupControlDeps(dev, deps, name=None):

  with ops.control_dependencies(deps):

    if dev is None:

      return no_op(name=name)

    else:

      with ops.device(dev):

        return no_op(name=name)

# TODO(touts): Accept "inputs" as a list.

def group(*inputs, **kwargs):

  """Create an op that groups multiple operations.

  When this op finishes, all ops in `input` have finished. This op has no

  output.

  See also @{tf.tuple$tuple} and

  @{tf.control_dependencies$control_dependencies}.

  Args:

    *inputs: Zero or more tensors to group.

    **kwargs: Optional parameters to pass when constructing the NodeDef.

    name: A name for this operation (optional).

  Returns:

    An Operation that executes all its inputs.

  Raises:

    ValueError: If an unknown keyword argument is provided.

  """

  name = kwargs.pop("name", None)

  if kwargs:

    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))

  with ops.name_scope(name, "group_deps", inputs) as name:

    # Grouping no inputs means do nothing

    if not inputs:

      return no_op(name=name)

    # Sorts *inputs according to their devices.

    ops_on_device = {}  # device -> operations specified on the device.

    for inp in inputs:

      dev = inp.device

      if dev in ops_on_device:

        ops_on_device[dev].append(inp)

      else:

        ops_on_device[dev] = [inp]

    if len(ops_on_device) == 1:

      # 1-level tree. The root node is the returned NoOp node.

      (dev, deps), = ops_on_device.items()

      return _GroupControlDeps(dev, deps, name=name)

    # 2-level tree. The root node is the returned NoOp node.

    # deps contains 1 NoOp node for each device.

    deps = []

    def device_key(dev):

      """A sort key that allows None to be compared to strings."""

      return "" if dev is None else dev

    for dev in sorted(six.iterkeys(ops_on_device), key=device_key):

      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))

    with ops.control_dependencies(deps):

      return no_op(name=name)

def tuple(tensors, name=None, control_inputs=None):

  """Group tensors together.

  This creates a tuple of tensors with the same values as the `tensors`

  argument, except that the value of each tensor is only returned after the

  values of all tensors have been computed.

  `control_inputs` contains additional ops that have to finish before this op

  finishes, but whose outputs are not returned.

  This can be used as a "join" mechanism for parallel computations: all the

  argument tensors can be computed in parallel, but the values of any tensor

  returned by `tuple` are only available after all the parallel computations

  are done.

  See also @{tf.group$group} and

  @{tf.control_dependencies$control_dependencies}.

  Args:

    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.

    name: (optional) A name to use as a `name_scope` for the operation.

    control_inputs: List of additional ops to finish before returning.

  Returns:

    Same as `tensors`.

  Raises:

    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.

    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`

      objects.

  """

  with ops.name_scope(name, "tuple", tensors) as name:

    gating_ops = [t.op for t in tensors if t is not None]

    if control_inputs:

      for c in control_inputs:

        if isinstance(c, ops.Tensor):

          c = c.op

        elif not isinstance(c, ops.Operation):

          raise TypeError("Control input must be Operation or Tensor: %s" % c)

        gating_ops.append(c)

    # Note that in order to ensure ordering in the pbtxt, we must take care to

    # ensure the order here.

    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.

    if not gating_ops:

      raise ValueError("Must have at least one Tensor: %s" % tensors)

    gate = group(*gating_ops)

    tpl = []

    for t in tensors:

      if t is not None:

        tpl.append(with_dependencies([gate], t))

      else:

        tpl.append(None)

    return tpl

def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):

  """Create a case operation.

  The `pred_fn_pairs` parameter is a dict or list of pairs of size N.

  Each pair contains a boolean scalar tensor and a python callable that

  creates the tensors to be returned if the boolean evaluates to True.

  `default` is a callable generating a list of tensors. All the callables

  in `pred_fn_pairs` as well as `default` should return the same number

  and types of tensors.

  If `exclusive==True`, all predicates are evaluated, and an exception is

  thrown if more than one of the predicates evaluates to `True`.

  If `exclusive==False`, execution stops are the first predicate which

  evaluates to True, and the tensors generated by the corresponding function

  are returned immediately. If none of the predicates evaluate to True, this

  operation returns the tensors generated by `default`.

  `tf.case` supports nested structures as implemented in

  `tensorflow.python.util.nest`. All of the callables must return the same

  (possibly nested) value structure of lists, tuples, and/or named tuples.

  Singleton lists and tuples form the only exceptions to this: when returned by

  a callable, they are implicitly unpacked to single values. This

  behavior is disabled by passing `strict=True`.

  If an unordered dictionary is used for `pred_fn_pairs`, the order of the

  conditional tests is not guaranteed. However, the order is guaranteed to be

  deterministic, so that variables created in conditional branches are created

  in fixed order across runs.

  Example 1:

    Pseudocode:

    ```

      if (x < y) return 17;

      else return 23;

    ```

    Expressions:

    ```

      f1 = lambda: tf.constant(17)

      f2 = lambda: tf.constant(23)

      r = case([(tf.less(x, y), f1)], default=f2)

    ```

  Example 2:

    Pseudocode:

    ```

      if (x < y && x > z) raise OpError("Only one predicate may evaluate true");

      if (x < y) return 17;

      else if (x > z) return 23;

      else return -1;

    ```

    Expressions:

    ```

      def f1(): return tf.constant(17)

      def f2(): return tf.constant(23)

      def f3(): return tf.constant(-1)

      r = case({tf.less(x, y): f1, tf.greater(x, z): f2},

               default=f3, exclusive=True)

    ```

  Args:

    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a

                   callable which returns a list of tensors.

    default: A callable that returns a list of tensors.

    exclusive: True iff at most one predicate is allowed to evaluate to `True`.

    strict: A boolean that enables/disables 'strict' mode; see above.

    name: A name for this operation (optional).

  Returns:

    The tensors returned by the first pair whose predicate evaluated to True, or

    those returned by `default` if none does.

  Raises:

    TypeError: If `pred_fn_pairs` is not a list/dictionary.

    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.

    TypeError: If `fns[i]` is not callable for any i, or `default` is not

               callable.

  """

  pfp = pred_fn_pairs  # For readability

  if not (isinstance(pfp, list) or isinstance(pfp, _basetuple)

          or isinstance(pfp, dict)):

    raise TypeError("fns must be a list, tuple, or dict")

  if isinstance(pfp, dict):

    if isinstance(pfp, collections.OrderedDict):

      pfp = pfp.items()

    else:

      pfp = sorted(pfp.items(), key=lambda item: item[0].name)

      if not exclusive:

        logging.warn("%s: An unordered dictionary of predicate/fn pairs was "

                     "provided, but exclusive=False. The order of conditional "

                     "tests is deterministic but not guaranteed.", name)

  for tup in pfp:

    if not isinstance(tup, _basetuple) or len(tup) != 2:

      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")

    pred, fn = tup

    if pred.dtype != dtypes.bool:

      raise TypeError("pred must be of type bool: %s", pred.name)

    if not callable(fn):

      raise TypeError("fn for pred %s must be callable." % pred.name)

  if not callable(default):

    raise TypeError("default must be callable.")

  preds, fns = map(list, zip(*pfp))

  with ops.name_scope(name, "case", [preds]):

    if not preds:

      return default()

    not_preds = []

    for i, p in enumerate(preds):

      with ops.name_scope("not_%d" % i):

        not_preds.append(math_ops.logical_not(p))

    and_not_preds = [constant_op.constant(True, name="always_true")]

    for i, notp in enumerate(not_preds):

      with ops.name_scope("and_not_%d" % i):

        and_not_preds.append(math_ops.logical_and(and_not_preds[-1], notp))

    # preds = [p1, p2, p3]

    # fns = [f1, f2, f3]

    # not_preds = [~p1, ~p2, ~p3]

    # and_not_preds = [True, ~p1, ~p1 & ~p2, ~p1 & ~p2 & ~p3]

    # case_preds = [p1,

    #               p2 & ~p1,

    #               p3 & ~p2 & ~p1,

    #              ~p3 & ~p2 & ~p1]

    case_preds = []

    for i, (p, and_not_p_prev) in enumerate(zip(preds, and_not_preds[:-1])):

      with ops.name_scope("case_%d" % i):

        case_preds.append(math_ops.logical_and(p, and_not_p_prev))

    with ops.name_scope("case_none_are_true"):

      case_preds.append(and_not_preds[-1])

    # Create an empty tensor, or list, with the right type and shape

    with ops.name_scope("case_create_empty"):

      def _create_empty_constant(dtype, shape):

        value = ("" if dtype == dtypes.string else dtype.as_numpy_dtype())

        if shape.ndims is None:

          return array_ops.constant(value, dtype=dtype)

        else:

          temp_shape = [1 if x.value is None else x.value for x in shape]

          result = array_ops.constant(value, shape=temp_shape, dtype=dtype)

          result._shape = shape  # pylint: disable=protected-access

          return result

      def _correct_empty(v):

        if isinstance(v, ops.Operation):

          return no_op()

        elif isinstance(v, tensor_array_ops.TensorArray):

          return v

        elif not hasattr(v, "dtype"):

          return ops.convert_to_tensor(v)

        elif isinstance(v, sparse_tensor.SparseTensor):

          return sparse_tensor.SparseTensor(indices=[[0] * len(v.get_shape())],

                                            values=[v.dtype.as_numpy_dtype()],

                                            dense_shape=v.get_shape())

        else:

          return _create_empty_constant(v.dtype, v.get_shape())

      empty = lambda: nest.map_structure(_correct_empty, default())

    # case_sequence = [

    #   cond(~p3 & ~p2 & ~p1, default, empty),

    #   cond(p3 & ~p2 & ~p1, f3, lambda: case_sequence[0]),

    #   cond(p2 & ~p1, f2, lambda: case_sequence[1]),

    #   cond(p1, f1, lambda: case_sequence[2])

    # ]

    #

    # And the return value will be case_sequence[-1]

    def _build_case():

      all_fns = [fn for fn in fns]

      all_fns.append(default)

      prev_case = None

      for i, (cp, fn) in enumerate(list(zip(case_preds, all_fns))[::-1]):

        prev_case = cond(

            cp, fn,

            empty if i == 0 else lambda: prev_case,

            strict=strict, name="If_%d" % i)

      return prev_case

    if exclusive:

      preds_c = array_ops.stack(preds, name="preds_c")

      num_true_conditions = math_ops.reduce_sum(

          math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")

      at_most_one_true_condition = math_ops.less(

          num_true_conditions, constant_op.constant(2, name="two_true_conds"))

      error_msg = [

          ("More than one condition evaluated as True but "

           "exclusive=True.  Conditions: (%s), Values:"

           % ", ".join([p.name for p in preds])),

          preds_c]

      with ops.control_dependencies([

          Assert(condition=at_most_one_true_condition,

                 data=error_msg, summarize=len(preds))]):

        case_seq = _build_case()

    else:

      case_seq = _build_case()

    if not strict:

      case_seq = _UnpackIfSingleton(case_seq)

    return case_seq

ops.register_proto_function(ops.GraphKeys.COND_CONTEXT,

                            proto_type=control_flow_pb2.CondContextDef,

                            to_proto=CondContext.to_proto,

                            from_proto=CondContext.from_proto)

ops.register_proto_function(ops.GraphKeys.WHILE_CONTEXT,

                            proto_type=control_flow_pb2.WhileContextDef,

                            to_proto=WhileContext.to_proto,

                            from_proto=WhileContext.from_proto)