def _infer_fft_length_for_rfft(input_tensor, fft_rank):
  """Infers the `fft_length` argument for an `fft_rank`-D RFFT of `input_tensor`.

  Args:
    input_tensor: The tensor the RFFT will be applied to.
    fft_rank: How many trailing (innermost) dimensions of `input_tensor`
      the transform covers.

  Returns:
    A tensor holding the sizes of the trailing `fft_rank` dimensions.

    - When all of those sizes are known statically, it is an int32
      constant built from the static shape.
    - Otherwise it is the corresponding slice of the runtime shape
      returned by `_array_ops.shape`.
  """
  # Static shape (a TensorShape) restricted to the trailing fft_rank dims.
  # NOTE(review): Python slicing means fft_rank == 0 would select the whole
  # shape; callers presumably only pass positive ranks -- confirm.
  trailing_dims = input_tensor.get_shape()[-fft_rank:]

  if trailing_dims.is_fully_defined():
    # Every trailing size is known at graph-construction time, so the
    # length can be materialized as a constant.
    return _ops.convert_to_tensor(trailing_dims.as_list(), _dtypes.int32)

  # At least one trailing size is unknown statically; read it from the
  # tensor's dynamic shape instead.
  return _array_ops.shape(input_tensor)[-fft_rank:]