diff --git a/.bazelrc b/.bazelrc index 53a4cf9581f718..fcef170ddedfe5 100644 --- a/.bazelrc +++ b/.bazelrc @@ -17,6 +17,9 @@ # ios_x86_64: # ios_fat: # +# Macosx options +# darwin_arm64: +# # Compiler options: # cuda_clang: Use clang when building CUDA code. # c++17: Build with C++17 options (links with libc++) @@ -35,6 +38,9 @@ # monolithic: Build all TF C++ code into a single shared object. # dynamic_kernels: Try to link all kernels dynamically (experimental). # libc++: Link against libc++ instead of stdlibc++ +# asan: Build with the clang address sanitizer +# msan: Build with the clang memory sanitizer +# ubsan: Build with the clang undefined behavior sanitizer # # # TF version options; @@ -44,12 +50,10 @@ # Feature and Third party library support options: # xla: Build TF with XLA # tpu: Build TF with TPU support -# using_cuda: CUDA is available to build system. # cuda: Build with full cuda support. # rocm: Build with AMD GPU support (rocm). # mkl: Enable full mkl support. # tensorrt: Enable Tensorrt support. -# ngraph: Enable ngraph support. # numa: Enable numa using hwloc. # noaws: Disable AWS S3 storage support # nogcp: Disable GCS support. @@ -80,15 +84,65 @@ # elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support. # # Release build options (for all operating systems) -# release_common: Common options for all builds on all operating systems. -# release_windows_common: Common options for all builds on Windows. -# release_gpu_common: Common options for GPU builds on Linux and Windows. -# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds. -# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds. -# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds. -# release_gpu_linux_cuda_10_1: Toolchain and CUDA options for CUDA 10.1 Linux GPU builds. -# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds. -# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds. 
+# release_base: Common options for all builds on all operating systems. +# release_gpu_base: Common options for GPU builds on Linux and Windows. +# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds. +# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds. +# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds. +# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds. +# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds. + +# Default build options. These are applied first and unconditionally. + +# For projects which use TensorFlow as part of a Bazel build process, putting +# nothing in a bazelrc will default to a monolithic build. The following line +# opts in to modular op registration support by default. +build --define framework_shared_object=true + +# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1 +build --java_toolchain=@tf_toolchains//toolchains/java:tf_java_toolchain +build --host_java_toolchain=@tf_toolchains//toolchains/java:tf_java_toolchain + +build --define=use_fast_cpp_protos=true +build --define=allow_oversize_protos=true + +build --spawn_strategy=standalone +build -c opt + +# Make Bazel print out all options from rc files. +build --announce_rc + +build --define=grpc_no_ares=true + +# See https://github.com/bazelbuild/bazel/issues/7362 for information on what +# --incompatible_remove_legacy_whole_archive flag does. +# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate +# Tensorflow to the default, however test coverage wasn't enough to catch the +# errors. +# There is ongoing work on Bazel team's side to provide support for transitive +# shared libraries. As part of migrating to transitive shared libraries, we +# hope to provide a better mechanism for control over symbol exporting, and +# then tackle this issue again. 
+# +# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library +# archives in -whole_archive -no_whole_archive. +build --noincompatible_remove_legacy_whole_archive + +# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0 +# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC: +# https://github.com/tensorflow/community/pull/179 +build --noincompatible_prohibit_aapt1 + +build --enable_platform_specific_config + +# Enable XLA support by default. +build --define=with_xla_support=true + +build --config=short_logs + +build --config=v2 + +# Default options should come above this line. # Allow builds using libc++ as a linker library # This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file @@ -118,7 +172,13 @@ build:android_x86_64 --cpu=x86_64 build:android_x86_64 --fat_apk_cpu=x86_64 # Sets the default Apple platform to macOS. -build --apple_platform_type=macos +build:macos --apple_platform_type=macos + +# gRPC on MacOS requires this #define +build:macos --copt=-DGRPC_BAZEL_BUILD + +# Settings for MacOS on ARM CPUs. +build:macos_arm64 --cpu=darwin_arm64 # iOS configs for each architecture and the fat binary builds. build:ios --apple_platform_type=ios @@ -141,19 +201,6 @@ build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64 # //tensorflow:libtensorflow_framework.so. build:monolithic --define framework_shared_object=false -# For projects which use TensorFlow as part of a Bazel build process, putting -# nothing in a bazelrc will default to a monolithic build. The following line -# opts in to modular op registration support by default. -build --define framework_shared_object=true - -# Flags for open source build, always set to be true. 
-build --define open_source_build=true -test --define open_source_build=true - -# For workaround https://github.com/bazelbuild/bazel/issues/8772 with Bazel >= 0.29.1 -build --java_toolchain=//third_party/toolchains/java:tf_java_toolchain -build --host_java_toolchain=//third_party/toolchains/java:tf_java_toolchain - # Please note that MKL on MacOS or windows is still not supported. # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. @@ -166,45 +213,28 @@ build:mkl -c opt build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_threadpool --define=build_with_mkl_opensource=true -build:mkl_threadpool --define=build_with_mkldnn_threadpool=true build:mkl_threadpool -c opt -# Config setting to build with oneDNN and without the binary blob -build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true -build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 -build:mkl_opensource_only --define=build_with_mkl_opensource=true -build:mkl_opensource_only --define=build_with_openmp=true -build:mkl_opensource_only -c opt - -# Config setting to build with oneDNN for Arm. +# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). +# This build is for the inference regime only. build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_aarch64 --define=build_with_mkl_opensource=true +build:mkl_aarch64 --define=build_with_openmp=true build:mkl_aarch64 -c opt -# This config refers to building with CUDA available. It does not necessarily -# mean that we build CUDA op kernels. 
-build:using_cuda --define=using_cuda=true -build:using_cuda --action_env TF_NEED_CUDA=1 -build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain - -# Enable the mlir generated GPU kernels only for cuda builds. -build --define=tensorflow_enable_mlir_generated_gpu_kernels=0 -# This is a more specific option, so it takes precedence over the line above for cuda builds. -build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1 - # This config refers to building CUDA op kernels with nvcc. -build:cuda --config=using_cuda -build:cuda --define=using_cuda_nvcc=true +build:cuda --repo_env TF_NEED_CUDA=1 +build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain +build:cuda --@local_config_cuda//:enable_cuda # This config refers to building CUDA op kernels with clang. -build:cuda_clang --config=using_cuda -build:cuda_clang --define=using_cuda_clang=true -build:cuda_clang --define=using_clang=true -build:cuda_clang --action_env TF_CUDA_CLANG=1 +build:cuda_clang --config=cuda +build:cuda_clang --repo_env TF_CUDA_CLANG=1 +build:cuda_clang --@local_config_cuda//:cuda_compiler=clang -# dbg config, as a shorthand for '--config=opt -c dbg' -build:dbg --config=opt -c dbg +# Debug config +build:dbg -c dbg # for now, disable arm_neon. see: https://github.com/tensorflow/tensorflow/issues/33360 build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON # AWS SDK must be compiled in release mode. 
see: https://github.com/tensorflow/tensorflow/issues/37498 @@ -213,14 +243,13 @@ build:dbg --copt -DDEBUG_BUILD # Config to build TPU backend build:tpu --define=with_tpu_support=true -build:tensorrt --action_env TF_NEED_TENSORRT=1 +build:tensorrt --repo_env TF_NEED_TENSORRT=1 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true -build:rocm --action_env TF_NEED_ROCM=1 +build:rocm --repo_env TF_NEED_ROCM=1 # Options extracted from configure script -build:ngraph --define=with_ngraph_support=true build:numa --define=with_numa_support=true # Options to disable default on features @@ -231,37 +260,6 @@ build:nonccl --define=no_nccl_support=true build:stackdriver_support --define=stackdriver_support=true -build --define=use_fast_cpp_protos=true -build --define=allow_oversize_protos=true - -build --spawn_strategy=standalone -build -c opt - -# Make Bazel print out all options from rc files. -build --announce_rc - -# Other build flags. -build --define=grpc_no_ares=true - -# See https://github.com/bazelbuild/bazel/issues/7362 for information on what -# --incompatible_remove_legacy_whole_archive flag does. -# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate -# Tensorflow to the default, however test coverage wasn't enough to catch the -# errors. -# There is ongoing work on Bazel team's side to provide support for transitive -# shared libraries. As part of migrating to transitive shared libraries, we -# hope to provide a better mechanism for control over symbol exporting, and -# then tackle this issue again. -# -# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library -# archives in -whole_archive -no_whole_archive. -build --noincompatible_remove_legacy_whole_archive - -# These are bazel 2.0's incompatible flags. 
Tensorflow needs to use bazel 2.0.0 -# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC: -# https://github.com/tensorflow/community/pull/179 -build --noincompatible_prohibit_aapt1 - # Modular TF build options build:dynamic_kernels --define=dynamic_loaded_kernels=true build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS @@ -273,9 +271,7 @@ build:c++1z --config=c++17 build:c++17_gcc --cxxopt=-std=c++1z build:c++1z_gcc --config=c++17_gcc -# Enable using platform specific build settings, except when cross-compiling for -# mobile platforms. -build --enable_platform_specific_config +# Don't trigger --config= when cross-compiling. build:android --noenable_platform_specific_config build:ios --noenable_platform_specific_config @@ -296,9 +292,11 @@ build:windows --host_copt=/D_USE_MATH_DEFINES build:linux --define=PREFIX=/usr build:linux --define=LIBDIR=$(PREFIX)/lib build:linux --define=INCLUDEDIR=$(PREFIX)/include +build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include build:macos --define=PREFIX=/usr build:macos --define=LIBDIR=$(PREFIX)/lib build:macos --define=INCLUDEDIR=$(PREFIX)/include +build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. # By default, build TF in C++ 14 mode. @@ -345,10 +343,9 @@ build:windows --verbose_failures # On windows, we never cross compile build:windows --distinct_host_configuration=false -# Suppress all warning messages. +# Configure short or long logs build:short_logs --output_filter=DONT_MATCH_ANYTHING build:verbose_logs --output_filter= -build --config=short_logs # Instruction set optimizations # TODO(gunan): Create a feature in toolchains for avx/avx2 to @@ -361,15 +358,13 @@ build:avx_win --copt=/arch=AVX build:avx2_win --copt=/arch=AVX2 # Options to build TensorFlow 1.x or 2.x. 
-build:v1 --define=tf_api_version=1 -build:v2 --define=tf_api_version=2 -build:v1 --action_env=TF2_BEHAVIOR=0 -build:v2 --action_env=TF2_BEHAVIOR=1 -build --config=v2 -test --config=v2 +build:v1 --define=tf_api_version=1 --action_env=TF2_BEHAVIOR=0 +build:v2 --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1 -# Enable XLA -build:xla --define=with_xla_support=true +# Disable XLA on mobile. +build:xla --define=with_xla_support=true # TODO: remove, it's on by default. +build:android --define=with_xla_support=false +build:ios --define=with_xla_support=false # BEGIN TF REMOTE BUILD EXECUTION OPTIONS # Options when using remote execution @@ -378,7 +373,7 @@ build:xla --define=with_xla_support=true # Flag to enable remote config common --experimental_repo_remote_exec -build:rbe --action_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 +build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 build:rbe --google_default_credentials build:rbe --bes_backend=buildeventservice.googleapis.com build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" @@ -403,9 +398,7 @@ build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 # Non-rbe settings we should include because we do not run configure -build:rbe_linux --config=xla build:rbe_linux --config=avx_linux -build:rbe_linux --config=short_logs # TODO(gunan): Check why we need this specified in rbe, but not in other builds. 
build:rbe_linux --linkopt=-lrt build:rbe_linux --host_linkopt=-lrt @@ -413,82 +406,63 @@ build:rbe_linux --linkopt=-lm build:rbe_linux --host_linkopt=-lm build:rbe_cpu_linux --config=rbe_linux -build:rbe_cpu_linux --host_crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" -build:rbe_cpu_linux --crosstool_top="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" -build:rbe_cpu_linux --extra_toolchains="//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8" +build:rbe_cpu_linux --host_crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" +build:rbe_cpu_linux --crosstool_top="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain" +build:rbe_cpu_linux --extra_toolchains="@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:cc-toolchain-k8" build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_cpu_linux --extra_execution_platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_cpu_linux --host_platform="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_cpu_linux --platforms="@ubuntu16.04-manylinux2010-py3_config_platform//:platform" build:rbe_linux_cuda_base --config=rbe_linux -build:rbe_linux_cuda_base --repo_env=TF_NEED_TENSORRT=1 -build:rbe_linux_cuda_base --repo_env=TF_CUDA_VERSION=10 -build:rbe_linux_cuda_base --repo_env=TF_CUDNN_VERSION=7 +build:rbe_linux_cuda_base --config=cuda +build:rbe_linux_cuda_base --config=tensorrt +build:rbe_linux_cuda_base --action_env=TF_CUDA_VERSION=11 +build:rbe_linux_cuda_base --action_env=TF_CUDNN_VERSION=8 build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1 test:rbe_linux_cuda_base 
--test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true -build:rbe_linux_cuda10.1_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda10.1_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" -build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" -build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" -build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7" -build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5" -build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base 
--repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6" -build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" -build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" - -build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true -build:rbe_linux_cuda11.0_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" -build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" -build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" -build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda" -build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt" -build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl" -build:rbe_linux_cuda11.0_nvcc_py2.7 
--config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7" -build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5" -build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6" -build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7" -build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8" - -# Map default to CUDA 11 for PY35 and greater. -build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7 -build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5 -build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6 -build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7 -build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8 +build:rbe_linux_cuda11.2_nvcc_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda11.2_nvcc_base --host_crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.2_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.2_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda11.2_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform" 
+build:rbe_linux_cuda11.2_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.2_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda" +build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt" +build:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl" +build:rbe_linux_cuda11.2_nvcc_py3.6 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.6" +build:rbe_linux_cuda11.2_nvcc_py3.7 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7" +build:rbe_linux_cuda11.2_nvcc_py3.8 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8" +build:rbe_linux_cuda11.2_nvcc_py3.9 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9" + +# Map default to CUDA 11.2. +build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.2_nvcc_py3.6 +build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.2_nvcc_py3.7 +build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.2_nvcc_py3.8 +build:rbe_linux_cuda_nvcc_py39 --config=rbe_linux_cuda11.2_nvcc_py3.9 # Deprecated configs that people might still use. 
build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36 build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_clang_base --host_platform="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_clang_base --platforms="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" -build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" -build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" -build:rbe_linux_cuda_clang_base --define=using_cuda_clang=true -build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7" -build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5" -build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6" -build:rbe_linux_cuda_clang_py37 
--config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" -build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" +build:rbe_linux_cuda_clang_base --repo_env TF_CUDA_CLANG=1 +build:rbe_linux_cuda_clang_base --@local_config_cuda//:cuda_compiler=clang +build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda_clang_base --extra_toolchains="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda_clang_base --extra_execution_platforms="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda_clang_base --host_platform="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda_clang_base --platforms="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda_clang_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda" +build:rbe_linux_cuda_clang_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt" +build:rbe_linux_cuda_clang_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl" +build:rbe_linux_cuda_clang_py27 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python2.7" +build:rbe_linux_cuda_clang_py35 --config=rbe_linux_cuda_clang_base 
--repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.5" +build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.6" +build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7" +build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-clang_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8" # ROCm build:rbe_linux_rocm_base --config=rbe_linux @@ -544,8 +518,6 @@ build:rbe_win_py38 --python_path=C:\\Python38\\python.exe build:tensorflow_testing_rbe --project_id=tensorflow-testing common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe -build:tensorflow_testing_rbe_linux --config=rbe -build:tensorflow_testing_rbe_linux --config=rbe_linux common:tensorflow_testing_rbe_win --remote_instance_name=projects/tensorflow-testing/instances/windows build:tensorflow_testing_rbe_win --config=tensorflow_testing_rbe @@ -559,54 +531,77 @@ build:elinux_armhf --config=elinux build:elinux_armhf --cpu=armhf # END TF REMOTE BUILD EXECUTION OPTIONS -# Default options should come above this line +# Config-specific options should come above this line. -# Options from ./configure +# Load rc file written by ./configure. try-import %workspace%/.tf_configure.bazelrc -# Put user-specific options in .bazelrc.user +# Load rc file with user-specific options. 
try-import %workspace%/.bazelrc.user # Here are bazelrc configs for release builds -build:release_common --config=opt -build:release_common --config=v2 -build:release_common --distinct_host_configuration=false -build:release_common --action_env TF_CONFIGURE_IOS="0" +build:release_base --config=v2 +build:release_base --distinct_host_configuration=false +test:release_base --flaky_test_attempts=3 +test:release_base --test_size_filters=small,medium -build:release_cpu_linux --config=release_common +build:release_cpu_linux --config=release_base build:release_cpu_linux --config=avx_linux -# We use the same toolchain for CPU/GPU packages. -# Did not add this to the defaults in case this changes. -build:release_cpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain +build:release_cpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain +test:release_cpu_linux --test_env=LD_LIBRARY_PATH -build:release_cpu_macos --config=release_common +build:release_cpu_macos --config=release_base build:release_cpu_macos --config=avx_linux -build:release_gpu_common --config=release_common -build:release_gpu_common --config=cuda -build:release_gpu_common --config=tensorrt -build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0" -build:release_gpu_common --action_env=TF_CUDA_VERSION="11" -build:release_gpu_common --action_env=TF_CUDNN_VERSION="8" -build:release_gpu_common --action_env=TF_NEED_TENSORRT="1" -build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt" -build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" -build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" - - 
-build:release_gpu_linux --config=release_gpu_common -build:release_gpu_linux --config=avx_linux -build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain -build:release_windows_common --config=release_common -build:release_windows_common --define=no_tensorflow_py_deps=true -build:release_windows_common --announce_rc - -build:release_cpu_windows --config=release_windows_common - -build:release_gpu_windows --config=release_windows_common - -build:release_gpu_linux_cuda_10_1 --config=release_gpu_linux -build:release_gpu_linux_cuda_10_1 --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1" -build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDA_VERSION="10" -build:release_gpu_linux_cuda_10_1 --action_env=TF_CUDNN_VERSION="7" +build:release_gpu_base --config=cuda +build:release_gpu_base --action_env=TF_CUDA_VERSION="11" +build:release_gpu_base --action_env=TF_CUDNN_VERSION="8" +build:release_gpu_base --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" + +build:release_gpu_linux --config=release_cpu_linux +build:release_gpu_linux --config=release_gpu_base +build:release_gpu_linux --config=tensorrt +build:release_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" +build:release_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build:release_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" +build:release_gpu_linux --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain + +build:release_cpu_windows --config=release_base +build:release_cpu_windows --config=avx_win +build:release_cpu_windows --define=no_tensorflow_py_deps=true +# First available in VS 16.4. Speeds Windows compile times by a lot. 
See +# https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion +build:release_cpu_windows --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions + +build:release_gpu_windows --config=release_cpu_windows +build:release_gpu_windows --config=release_gpu_base + +# Address sanitizer +# CC=clang bazel build --config asan +build:asan --strip=never +build:asan --copt -fsanitize=address +build:asan --copt -DADDRESS_SANITIZER +build:asan --copt -g +build:asan --copt -O3 +build:asan --copt -fno-omit-frame-pointer +build:asan --linkopt -fsanitize=address + +# Memory sanitizer +# CC=clang bazel build --config msan +build:msan --strip=never +build:msan --copt -fsanitize=memory +build:msan --copt -DMEMORY_SANITIZER +build:msan --copt -g +build:msan --copt -O3 +build:msan --copt -fno-omit-frame-pointer +build:msan --linkopt -fsanitize=memory + +# Undefined Behavior Sanitizer +# CC=clang bazel build --config ubsan +build:ubsan --strip=never +build:ubsan --copt -fsanitize=undefined +build:ubsan --copt -g +build:ubsan --copt -O3 +build:ubsan --copt -fno-omit-frame-pointer +build:ubsan --linkopt -fsanitize=undefined +build:ubsan --linkopt -lubsan diff --git a/.bazelversion b/.bazelversion index fd2a01863fdd30..0b2eb36f508590 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -3.1.0 +3.7.2 diff --git a/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md b/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md index 6eab765e84e418..70be52989048c2 100644 --- a/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md +++ b/.github/ISSUE_TEMPLATE/60-tflite-converter-issue.md @@ -1,46 +1,47 @@ --- -name: TensorFlow Lite New Converter Issue +name: TensorFlow Lite Converter Issue about: Use this template for reporting issues during model conversion to TFLite labels: 'TFLiteConverter' --- +### 1. 
System information -**System information** - OS Platform and Distribution (e.g., Linux Ubuntu 16.04): -- TensorFlow installed from (source or binary): -- TensorFlow version (or github SHA if from source): +- TensorFlow installation (pip package or built from source): +- TensorFlow library (version, if pip package or github SHA, if built from source): +### 2. Code -**Command used to run the converter or code if you’re using the Python API** -If possible, please share a link to Colab/Jupyter/any notebook. +Provide code to help us reproduce your issues using one of the following options: -``` -# Copy and paste here the exact command -``` +#### Option A: Reference colab notebooks -**The output from the converter invocation** +1) Reference [TensorFlow Model Colab](https://colab.research.google.com/gist/ymodak/e96a4270b953201d5362c61c1e8b78aa/tensorflow-datasets.ipynb?authuser=1): Demonstrate how to build your TF model. +2) Reference [TensorFlow Lite Model Colab](https://colab.research.google.com/gist/ymodak/0dfeb28255e189c5c48d9093f296e9a8/tensorflow-lite-debugger-colab.ipynb): Demonstrate how to convert your TF model to a TF Lite model (with quantization, if used) and run TFLite Inference (if possible). ``` -# Copy and paste the output here. +(You can paste links or attach files by dragging & dropping them below) +- Provide links to your updated versions of the above two colab notebooks. +- Provide links to your TensorFlow model and (optionally) TensorFlow Lite Model. ``` -**Also, please include a link to the saved model or GraphDef** +#### Option B: Paste your code here or provide a link to a custom end-to-end colab ``` -# Put link here or attach to the issue. +(You can paste links or attach files by dragging & dropping them below) +- Include code to invoke the TFLite Converter Python API and the errors. +- Provide links to your TensorFlow model and (optionally) TensorFlow Lite Model. 
``` -**Failure details** -If the conversion is successful, but the generated model is wrong, -state what is wrong: -- Producing wrong results and/or decrease in accuracy -- Producing correct results, but the model is slower than expected (model generated from old converter) +### 3. Failure after conversion +If the conversion is successful, but the generated model is wrong, then state what is wrong: +- Model produces wrong results and/or has lesser accuracy. +- Model produces correct results, but it is slower than expected. -**RNN conversion support** +### 4. (optional) RNN conversion support If converting TF RNN to TFLite fused RNN ops, please prefix [RNN] in the title. -**Any other info / logs** - +### 5. (optional) Any other info / logs Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. diff --git a/.github/workflows/update-nightly.yml b/.github/workflows/update-nightly.yml index 01b5147d0538f7..0265ffbebe2ec0 100644 --- a/.github/workflows/update-nightly.yml +++ b/.github/workflows/update-nightly.yml @@ -20,6 +20,7 @@ on: name: Set nightly branch to master HEAD jobs: master-to-nightly: + if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks runs-on: ubuntu-latest steps: - uses: zofrex/mirror-branch@v1 diff --git a/.zenodo.json b/.zenodo.json new file mode 100644 index 00000000000000..7161180c51ae3e --- /dev/null +++ b/.zenodo.json @@ -0,0 +1,13 @@ +{ + "description": "TensorFlow is an end-to-end open source platform for machine learning. 
It has a comprehensive, flexible ecosystem of tools, libraries, and community resources that lets researchers push the state-of-the-art in ML and developers easily build and deploy ML-powered applications.", + "license": "Apache-2.0", + "title": "TensorFlow", + "upload_type": "software", + "creators": [ + { + "name": "TensorFlow Developers" + } + ], + "access_right": "open", + "notes": "Specific TensorFlow versions can be found in the \"Versions\" list on the right side of this page.
See the full list of authors on GitHub." +} diff --git a/BUILD b/BUILD index 1200cf5f7103ca..8238d5e1acf065 100644 --- a/BUILD +++ b/BUILD @@ -1,8 +1,6 @@ -exports_files( - [ - "LICENSE", - "ACKNOWLEDGEMENTS", - "configure", - "configure.py", - ], -) +exports_files([ + "configure", + "configure.py", + "ACKNOWLEDGEMENTS", + "LICENSE", +]) diff --git a/CODEOWNERS b/CODEOWNERS index 9de1922a262794..3b0565b3e4acf8 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -3,6 +3,8 @@ /tensorflow/c/eager @qqfish @kkimdev /tensorflow/core/common_runtime/eager @qqfish @kkimdev /tenosrflow/core/debug @caisq +/tensorflow/core/kernels/mkl/ @penpornk +/tensorflow/core/kernels/sparse/ @penpornk /tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mihaimaruseac /tensorflow/lite/experimental/micro @petewarden @advaitjain diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 4992e54e7f60f5..e5203a7cf2286e 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -64,13 +64,7 @@ If you are experiencing or witnessing conflict, we ask you to use the following ## Reporting Violations -Violations of the Code of Conduct can be reported to TensorFlow’s Project -Stewards, Edd Wilder-James (ewj@google.com) and Thea Lamkin -(thealamkin@google.com). The Project Steward will determine whether the Code of -Conduct was violated, and will issue an appropriate sanction, possibly including -a written warning or expulsion from the project, project sponsored spaces, or -project forums. We ask that you make a good-faith effort to resolve your -conflict via the conflict resolution policy before submitting a report. +Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Thea Lamkin (thealamkin@google.com) and Joana Carrasqueira (joanafilipa@google.com). 
The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report. diff --git a/ISSUES.md b/ISSUES.md index aabd3158b39d37..a6c77f76950d39 100644 --- a/ISSUES.md +++ b/ISSUES.md @@ -1,7 +1,9 @@ -If you open a GitHub Issue, here is our policy: 1. It must be a bug/performance -issue or a feature request or a build issue or a documentation issue (for small -doc fixes please send a PR instead). 2. Make sure the Issue Template is filled -out. 3. The issue should be related to the repo it is created in. +If you open a GitHub Issue, here is our policy: + +1. It must be a bug/performance issue or a feature request or a build issue or + a documentation issue (for small doc fixes please send a PR instead). +1. Make sure the Issue Template is filled out. +1. The issue should be related to the repo it is created in. **Here's why we have this policy:** We want to focus on the work that benefits the whole community, e.g., fixing bugs and adding features. Individual support diff --git a/LICENSE b/LICENSE index 40f8c347693afa..fb26962baedc4e 100644 --- a/LICENSE +++ b/LICENSE @@ -201,3 +201,48 @@ Copyright 2019 The TensorFlow Authors. All rights reserved. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +MIT License + +Copyright (c) 2017-2021 Arm Limited + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md index 63d85ce2df4a9a..fb3eddce6f110d 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,6 @@ [![Python](https://img.shields.io/pypi/pyversions/tensorflow.svg?style=plastic)](https://badge.fury.io/py/tensorflow) [![PyPI](https://badge.fury.io/py/tensorflow.svg)](https://badge.fury.io/py/tensorflow) - **`Documentation`** | ------------------- | [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | @@ -61,6 +60,7 @@ commands. 
*Nightly binaries are available for testing using the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and [tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.* + #### *Try your first TensorFlow program* ```shell @@ -114,11 +114,11 @@ Build Type | Status **Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) **Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl) **Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl) -**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) -**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) -**Libtensorflow Linux GPU** | 
[![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) -**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) -**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [Nightly GCS](https://storage.googleapis.com/libtensorflow-nightly) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow MacOS CPU** | Status Temporarily Unavailable | [Nightly Binary](https://storage.googleapis.com/libtensorflow-nightly/prod/tensorflow/release/macos/latest/macos_cpu_libtensorflow_binaries.tar.gz) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux CPU** | Status Temporarily Unavailable | [Nightly Binary](https://storage.googleapis.com/libtensorflow-nightly/prod/tensorflow/release/ubuntu_16/latest/cpu/ubuntu_cpu_libtensorflow_binaries.tar.gz) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Linux GPU** | Status Temporarily Unavailable | [Nightly Binary](https://storage.googleapis.com/libtensorflow-nightly/prod/tensorflow/release/ubuntu_16/latest/gpu/ubuntu_gpu_libtensorflow_binaries.tar.gz) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows CPU** | Status Temporarily Unavailable | [Nightly 
Binary](https://storage.googleapis.com/libtensorflow-nightly/prod/tensorflow/release/windows/latest/cpu/windows_cpu_libtensorflow_binaries.tar.gz) [Official GCS](https://storage.googleapis.com/tensorflow/) +**Libtensorflow Windows GPU** | Status Temporarily Unavailable | [Nightly Binary](https://storage.googleapis.com/libtensorflow-nightly/prod/tensorflow/release/windows/latest/gpu/windows_gpu_libtensorflow_binaries.tar.gz) [Official GCS](https://storage.googleapis.com/tensorflow/) ### Community Supported Builds @@ -132,12 +132,20 @@ Build Type **Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) **Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) **Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) -**Linux aarch64 CPU** Nightly
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) -**Linux aarch64 CPU** Stable Release | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show) +**Linux aarch64 CPU** Nightly (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow-nightly)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow-nightly/) | [Nightly](http://snapshots.linaro.org/ldcg/python/tensorflow-nightly/latest/) +**Linux aarch64 CPU** Stable Release (Linaro) | [![Build Status](https://ci.linaro.org/jenkins/buildStatus/icon?job=ldcg-python-tensorflow)](https://ci.linaro.org/jenkins/job/ldcg-python-tensorflow/) | Release [1.x & 2.x](http://snapshots.linaro.org/ldcg/python/tensorflow/latest/) +**Linux aarch64 CPU** Nightly (OpenLab)
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) +**Linux aarch64 CPU** Stable Release (OpenLab) | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) **Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) **Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) +### Community Supported Containers + +Container Type | Status | Artifacts +----------------------------------------------------------------- | ------ | --------- +**TensorFlow aarch64 Neoverse-N1 CPU** Stable (Linaro)
Debian | Static | Release [2.3](https://hub.docker.com/r/linaro/tensorflow-arm-neoverse-n1) + ## Resources * [TensorFlow.org](https://www.tensorflow.org) @@ -147,12 +155,12 @@ Build Type * [DeepLearning.AI TensorFlow Developer Professional Certificate](https://www.coursera.org/specializations/tensorflow-in-practice) * [TensorFlow: Data and Deployment from Coursera](https://www.coursera.org/specializations/tensorflow-data-and-deployment) * [Getting Started with TensorFlow 2 from Coursera](https://www.coursera.org/learn/getting-started-with-tensor-flow2) +* [TensorFlow: Advanced Techniques from Coursera](https://www.coursera.org/specializations/tensorflow-advanced-techniques) +* [Intro to TensorFlow for A.I, M.L, and D.L from Coursera](https://www.coursera.org/learn/introduction-tensorflow) * [Intro to TensorFlow for Deep Learning from Udacity](https://www.udacity.com/course/intro-to-tensorflow-for-deep-learning--ud187) * [Introduction to TensorFlow Lite from Udacity](https://www.udacity.com/course/intro-to-tensorflow-lite--ud190) * [Machine Learning with TensorFlow on GCP](https://www.coursera.org/specializations/machine-learning-tensorflow-gcp) * [TensorFlow Codelabs](https://codelabs.developers.google.com/?cat=TensorFlow) -* [TensorFlow Chat Room on StackOverflow (not actively monitored by the - TensorFlow team)](https://chat.stackoverflow.com/rooms/216694/tensorflow) * [TensorFlow Blog](https://blog.tensorflow.org) * [Learn ML with TensorFlow](https://www.tensorflow.org/resources/learn-ml) * [TensorFlow Twitter](https://twitter.com/tensorflow) diff --git a/RELEASE.md b/RELEASE.md index f2d3c3c6efe5b6..257b822306443c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,405 +1,1029 @@ -# Release 2.4.0 +# Release 2.5.1 + +This release introduces several vulnerability fixes: + +* Fixes a heap out of bounds access in sparse reduction operations ([CVE-2021-37635](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37635)) +* Fixes a floating point exception in 
`SparseDenseCwiseDiv` ([CVE-2021-37636](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37636)) +* Fixes a null pointer dereference in `CompressElement` ([CVE-2021-37637](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37637)) +* Fixes a null pointer dereference in `RaggedTensorToTensor` ([CVE-2021-37638](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37638)) +* Fixes a null pointer dereference and a heap OOB read arising from operations restoring tensors ([CVE-2021-37639](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37639)) +* Fixes an integer division by 0 in sparse reshaping ([CVE-2021-37640](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37640)) +* Fixes a division by 0 in `ResourceScatterDiv` ([CVE-2021-37642](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37642)) +* Fixes a heap OOB in `RaggedGather` ([CVE-2021-37641](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37641)) +* Fixes a `std::abort` raised from `TensorListReserve` ([CVE-2021-37644](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37644)) +* Fixes a null pointer dereference in `MatrixDiagPartOp` ([CVE-2021-37643](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37643)) +* Fixes an integer overflow due to conversion to unsigned ([CVE-2021-37645](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37645)) +* Fixes a bad allocation error in `StringNGrams` caused by integer conversion ([CVE-2021-37646](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37646)) +* Fixes a null pointer dereference in `SparseTensorSliceDataset` ([CVE-2021-37647](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37647)) +* Fixes an incorrect validation of `SaveV2` inputs ([CVE-2021-37648](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37648)) +* Fixes a null pointer dereference in `UncompressElement` ([CVE-2021-37649](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37649)) +* Fixes a segfault and a heap buffer 
overflow in `{Experimental,}DatasetToTFRecord` ([CVE-2021-37650](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37650)) +* Fixes a heap buffer overflow in `FractionalAvgPoolGrad` ([CVE-2021-37651](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37651)) +* Fixes a use after free in boosted trees creation ([CVE-2021-37652](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37652)) +* Fixes a division by 0 in `ResourceGather` ([CVE-2021-37653](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37653)) +* Fixes a heap OOB and a `CHECK` fail in `ResourceGather` ([CVE-2021-37654](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37654)) +* Fixes a heap OOB in `ResourceScatterUpdate` ([CVE-2021-37655](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37655)) +* Fixes an undefined behavior arising from reference binding to nullptr in `RaggedTensorToSparse` ([CVE-2021-37656](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37656)) +* Fixes an undefined behavior arising from reference binding to nullptr in `MatrixDiagV*` ops ([CVE-2021-37657](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37657)) +* Fixes an undefined behavior arising from reference binding to nullptr in `MatrixSetDiagV*` ops ([CVE-2021-37658](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37658)) +* Fixes an undefined behavior arising from reference binding to nullptr and heap OOB in binary cwise ops ([CVE-2021-37659](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37659)) +* Fixes a division by 0 in inplace operations ([CVE-2021-37660](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37660)) +* Fixes a crash caused by integer conversion to unsigned ([CVE-2021-37661](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37661)) +* Fixes an undefined behavior arising from reference binding to nullptr in boosted trees ([CVE-2021-37662](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37662)) +* Fixes a heap OOB in boosted 
trees ([CVE-2021-37664](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37664)) +* Fixes vulnerabilities arising from incomplete validation in `QuantizeV2` ([CVE-2021-37663](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37663)) +* Fixes vulnerabilities arising from incomplete validation in MKL requantization ([CVE-2021-37665](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37665)) +* Fixes an undefined behavior arising from reference binding to nullptr in `RaggedTensorToVariant` ([CVE-2021-37666](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37666)) +* Fixes an undefined behavior arising from reference binding to nullptr in unicode encoding ([CVE-2021-37667](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37667)) +* Fixes an FPE in `tf.raw_ops.UnravelIndex` ([CVE-2021-37668](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37668)) +* Fixes a crash in NMS ops caused by integer conversion to unsigned ([CVE-2021-37669](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37669)) +* Fixes a heap OOB in `UpperBound` and `LowerBound` ([CVE-2021-37670](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37670)) +* Fixes an undefined behavior arising from reference binding to nullptr in map operations ([CVE-2021-37671](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37671)) +* Fixes a heap OOB in `SdcaOptimizerV2` ([CVE-2021-37672](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37672)) +* Fixes a `CHECK`-fail in `MapStage` ([CVE-2021-37673](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37673)) +* Fixes a vulnerability arising from incomplete validation in `MaxPoolGrad` ([CVE-2021-37674](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37674)) +* Fixes an undefined behavior arising from reference binding to nullptr in shape inference ([CVE-2021-37676](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37676)) +* Fixes a division by 0 in most convolution operators 
([CVE-2021-37675](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37675)) +* Fixes vulnerabilities arising from missing validation in shape inference for `Dequantize` ([CVE-2021-37677](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37677)) +* Fixes an arbitrary code execution due to YAML deserialization ([CVE-2021-37678](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37678)) +* Fixes a heap OOB in nested `tf.map_fn` with `RaggedTensor`s ([CVE-2021-37679](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37679)) +* Fixes a division by zero in TFLite ([CVE-2021-37680](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37680)) +* Fixes an NPE in TFLite ([CVE-2021-37681](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37681)) +* Fixes a vulnerability arising from use of uninitialized value in TFLite ([CVE-2021-37682](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37682)) +* Fixes an FPE in TFLite division operations ([CVE-2021-37683](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37683)) +* Fixes an FPE in TFLite pooling operations ([CVE-2021-37684](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37684)) +* Fixes an infinite loop in TFLite ([CVE-2021-37686](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37686)) +* Fixes a heap OOB in TFLite ([CVE-2021-37685](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37685)) +* Fixes a heap OOB in TFLite's `Gather*` implementations ([CVE-2021-37687](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37687)) +* Fixes an undefined behavior arising from null pointer dereference in TFLite ([CVE-2021-37688](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37688)) +* Fixes an undefined behavior arising from null pointer dereference in TFLite MLIR optimizations ([CVE-2021-37689](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37689)) +* Fixes an FPE in LSH in TFLite 
([CVE-2021-37691](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37691)) +* Fixes a segfault on strings tensors with mismatched dimensions, arising in Go code ([CVE-2021-37692](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37692)) +* Fixes a use after free and a potential segfault in shape inference functions ([CVE-2021-37690](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-37690)) +* Updates `curl` to `7.77.0` to handle [CVE-2021-22876](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-22876), [CVE-2021-22897](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-22897), [CVE-2021-22898](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-22898), and [CVE-2021-22901](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-22901). + +# Release 2.5.0 - +## Major Features and Improvements +* Support for Python3.9 has been added. +* TPU embedding support + * Added `profile_data_directory` to `EmbeddingConfigSpec` in + `_tpu_estimator_embedding.py`. This allows embedding lookup statistics + gathered at runtime to be used in embedding layer partitioning decisions. +* `tf.keras.metrics.AUC` now support logit predictions. +* Creating `tf.random.Generator` under `tf.distribute.Strategy` scopes is now allowed (except for `tf.distribute.experimental.CentralStorageStrategy` and `tf.distribute.experimental.ParameterServerStrategy`). Different replicas will get different random-number streams. +* `tf.data`: + * tf.data service now supports strict round-robin reads, which is useful + for synchronous training workloads where example sizes vary. With strict + round robin reads, users can guarantee that consumers get similar-sized + examples in the same step. + * tf.data service now supports optional compression. Previously data would + always be compressed, but now you can disable compression by passing + `compression=None` to `tf.data.experimental.service.distribute(...)`. 
+ * `tf.data.Dataset.batch()` now supports `num_parallel_calls` and + `deterministic` arguments. `num_parallel_calls` is used to indicate that + multiple input batches should be computed in parallel. With + `num_parallel_calls` set, `deterministic` is used to indicate that + outputs can be obtained in the non-deterministic order. + * Options returned by `tf.data.Dataset.options()` are no longer mutable. + * tf.data input pipelines can now be executed in debug mode, which + disables any asynchrony, parallelism, or non-determinism and forces + Python execution (as opposed to trace-compiled graph execution) of + user-defined functions passed into transformations such as `map`. The + debug mode can be enabled through `tf.data.experimental.enable_debug_mode()`. +* `tf.lite` + * Enabled the new MLIR-based quantization backend by default + * The new backend is used for 8 bits full integer post-training quantization + * The new backend removes the redundant rescales and fixes some bugs (shared weight/bias, extremely small scales, etc) + * Set `experimental_new_quantizer` in tf.lite.TFLiteConverter to False to disable this change +* `tf.keras` + * Enabled a new supported input type in `Model.fit`, + `tf.keras.utils.experimental.DatasetCreator`, which takes a + callable, `dataset_fn`. + `DatasetCreator` is intended to work across all `tf.distribute` + strategies, and is the only input type supported for Parameter Server + strategy. +* `tf.distribute` + * `tf.distribute.experimental.ParameterServerStrategy` now supports + training with Keras `Model.fit` when used with `DatasetCreator`. +* PluggableDevice + * Third-party devices can now connect to TensorFlow as plug-ins through + [StreamExecutor C API](https://github.com/tensorflow/community/blob/master/rfcs/20200612-stream-executor-c-api.md). + and [PluggableDevice](https://github.com/tensorflow/community/blob/master/rfcs/20200624-pluggable-device-for-tensorflow.md) interface. 
+ * Add custom ops and kernels through + [kernel and op registration C API](https://github.com/tensorflow/community/blob/master/rfcs/20190814-kernel-and-op-registration.md). + * Register custom graph optimization passes with + [graph optimization C API](https://github.com/tensorflow/community/blob/master/rfcs/20201027-modular-tensorflow-graph-c-api.md). +* [oneAPI Deep Neural Network Library (oneDNN)](https://github.com/oneapi-src/oneDNN) + CPU performance optimizations from + [Intel-optimized TensorFlow](https://software.intel.com/content/www/us/en/develop/articles/intel-optimization-for-tensorflow-installation-guide.html) + are now available in the official x86-64 Linux and Windows builds. + * They are off by default. Enable them by setting the environment variable + `TF_ENABLE_ONEDNN_OPTS=1`. + * We do not recommend using them in GPU systems, as they have not been + sufficiently tested with GPUs yet. +* TensorFlow pip packages are now built with CUDA11.2 and cuDNN 8.1.0 + ## Breaking Changes -* -* -* Certain float32 ops run in lower precsion on Ampere based GPUs, including - matmuls and convolutions, due to the use of - [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/). - Specifically, inputs to such ops are rounded from 23 bits of precision to 10 - bits of precision. This is unlikely to cause issues in practice for deep - learning models. In some cases, TensorFloat-32 is also used for complex64 ops. - TensorFloat-32 can be disabled by running - `config.experimental.enable_tensor_float_32_execution(False)`. The "Major - Features and Improvements" section has more details. -* The byte layout for string tensors across the C-API has been updated to match - TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s. 
-* C-API functions `TF_StringDecode`, `TF_StringEncode`, and - `TF_StringEncodedSize` are no longer relevant and have been removed; see - core/platform/ctstring.h for string access/modification in C. -* Removed `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2. -* `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are - now hidden. These modules are not part of TensorFlow public API. -* A major refactoring of the internals of the Keras Functional API may affect code that is relying on certain internal details: - * Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`. - * Code that is overly dependent on the exact names attached to symbolic tensors (e.g. assumes there will be ":0" at the end of the inputs, treats names as unique identifiers instead of using `tensor.ref()`, etc.) - * Code that uses `get_concrete_function` to trace Keras symbolic inputs directly should switch to building matching `tf.TensorSpec`s directly and tracing the `TensorSpec` objects. - * Code that relies on the exact number and names of the op layers that TensorFlow operations were converted into. These may have changed. - * Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy. - * Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values. - * Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now. 
- * Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead. - * Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient. - * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know. - * Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed. -* Start enforcing input shape assumptions when calling Functional API Keras - models. This may potentially break some users, in case there is a mismatch - between the shape used when creating `Input` objects in a Functional model, - and the shape of the data passed to that model. You can fix this mismatch by - either calling the model with correctly-shaped data, or by relaxing `Input` - shape assumptions (note that you can pass shapes with `None` entries for axes - that are meant to be dynamic). You can also disable the input checking - entirely by setting `model.input_spec = None`. -* TF pip packages now use CUDA11 and cuDNN 8.0.2. -* XLA:CPU and XLA:GPU devices are no longer registered by default. Use - `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be - removed). -* `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type - `tf.complex64` or `tf.complex128`, because the behavior of these ops is not - well defined for complex types. -* `tf.data.experimental.service.DispatchServer` now takes a config tuple - instead of individual arguments. Usages should be updated to - `tf.data.experimental.service.DispatchServer(dispatcher_config)`. -* `tf.data.experimental.service.WorkerServer` now takes a config tuple - instead of individual arguments. 
Usages should be updated to - `tf.data.experimental.service.WorkerServer(worker_config)`. -* `tf.quantization.quantize_and_dequantize_v2` has been introduced, which - updates the gradient definition for quantization which is outside the range - to be 0. To simulate the V1 the behavior of - tf.quantization.quantize_and_dequantize(...) use - tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...). -* `tf.distribute.Strategy.experimental_make_numpy_dataset` is removed. Please - use `tf.data.Dataset.from_tensor_slices` instead. -* `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`, - `tf.distribute.StrategyExtended.batch_reduce_to`, - `tf.distribute.ReplicaContext.all_reduce` are renamed to `options`. - `tf.distribute.experimental.CollectiveHints` is renamed - `tf.distribute.experimental.CommunicationOptions`. - `tf.distribute.experimental.CollectiveCommunication` is renamed - `tf.distribute.experimental.CommunicationImplementation`. +* The `TF_CPP_MIN_VLOG_LEVEL` environment variable has been renamed to + `TF_CPP_MAX_VLOG_LEVEL` which correctly describes its effect. -## Known Caveats +## Bug Fixes and Other Changes -* +* `tf.keras`: + * Preprocessing layers API consistency changes: + * `StringLookup` added `output_mode`, `sparse`, and + `pad_to_max_tokens` arguments with same semantics as + `TextVectorization`. + * `IntegerLookup` added `output_mode`, `sparse`, and + `pad_to_max_tokens` arguments with same semantics as + `TextVectorization`. Renamed `max_values`, `oov_value` and + `mask_value` to `max_tokens`, `oov_token` and `mask_token` to align + with `StringLookup` and `TextVectorization`. + * `TextVectorization` default for `pad_to_max_tokens` switched to + False. + * `CategoryEncoding` no longer supports `adapt`, `IntegerLookup` + now supports equivalent functionality. `max_tokens` argument renamed + to `num_tokens`. 
+ * `Discretization` added `num_bins` argument for learning bins + boundaries through calling `adapt` on a dataset. Renamed `bins` + argument to `bin_boundaries` for specifying bins without `adapt`. + * Improvements to model saving/loading: + * `model.load_weights` now accepts paths to saved models. + * Keras inputs can now be created directly from arbitrary `tf.TypeSpecs`. + * Two new learning rate schedules added: + `tf.keras.optimizers.schedules.CosineDecay` and + `tf.keras.optimizers.schedules.CosineDecayRestarts`. -## Major Features and Improvements +* `tf.data`: + * Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used + to control how external state should be handled during dataset + serialization or iterator checkpointing. + * Changing `tf.data.experimental.save` to store the type specification of + the dataset elements. This avoids the need for explicitly specifying the + `element_spec` argument of `tf.data.experimental.load` when loading the + previously saved dataset. + * Add `.element_spec` property to `tf.data.DatasetSpec` to access the + inner spec. This can be used to extract the structure of nested + datasets. + * Add `tf.data.experimental.AutoShardingPolicy.HINT` which can be used + to provide hints to tf.distribute-based auto-sharding as to where in + the input pipeline to insert sharding transformations. + * Make tf.data.Options persistent across `tf.function` and `GraphDef` + boundaries. + +* XLA compilation: + * `tf.function(experimental_compile=True)` has become a stable API, + renamed `tf.function(jit_compile=True)`. + * XLA can now compile MirroredStrategy: the step function passed to + `strategy.run` can now be annotated with `jit_compile=True`. -* -* -* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. 
`numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) for details of what operations are supported and what are the differences from NumPy. -* A major refactoring of the internals of the Keras Functional API has been completed, that should improve the reliability, stability, and performance of constructing Functional models. -* Support for - [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) - on Ampere based GPUs has been added. TensorFloat-32, or TF32 for short, is a - math mode for NVIDIA Ampere GPUs which causes certain float32 ops, such as - matrix multiplications and convolutions, to run much faster on Ampere GPUs but - with reduced precision. This reduced precision has not been found to effect - convergence quality of deep learning models in practice. TensorFloat-32 is - enabled by default, but can be disabled with - `tf.config.experimental.enable_tensor_float_32_execution`. +* `tf.distribute`: + * Rename `experimental_prefetch_to_device` in `tf.distribute.InputOptions` + to `experimental_fetch_to_device` to better reflect the purpose. -* `tf.distribute`: - * `MultiWorkerMirroredStrategy` is graduated out of experimental. - * Peer failure will no longer cause the cluster to hang. - * Major issues with saving are fixed. - * See [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for a tutorial. - * Deprecated `experimental_distribute_datasets_from_function` method and renamed it to `distribute_datasets_from_function` as it is no longer experimental. +* `tf.lite`: + * class `tflite::Subgraph`: + * Removed the `tensors()` method and the non-const overload of the + `nodes_and_registration()` method, both of which were previously + documented as temporary and to be removed. 
+ * Uses of `tensors()` can be replaced by calling the existing + methods `tensors_size()` and `tensor(int)`. + * Uses of the non-const overload of `nodes_and_registration` + can be replaced by calling the existing methods `nodes_size()` + and `context()`, and then calling the `GetNodeAndRegistration` + method in the `TfLiteContext` returned by `context()`. + * NNAPI + * Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API. + * Use `NnApiDelegate()` and related delegate configuration methods + directly. + * Replaced the model cache key for models computation algorithm with + one guaranteed to be stable across runs. + * 16 bits quantization + * Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators. + * Additional tests and fixes for ADD and SUB operators. + * Added support for saved model's session initializer through + `TFLiteConverter.from_saved_model`. + * Added DEPTH_TO_SPACE support in Post training quantization. + * Added dynamic range quantization support for the BatchMatMul op. + * Both symmetric and asymmetric quantized input tensors are supported. + * Add `RFFT2D` as builtin op. (`RFFT2D` also supports `RFFTD`.) Currently + only supports float32 input. + * Add 5D support to `SLICE` op. + * TFLite Supports SignatureDef: + * TFLiteConverter exports models with SignatureDef + * Interpreter supports getting a list of signatures and getting callable + function for a given SignatureDef. + * Add int8 support for `ReshapeV2`. + * Add experimental support for optimization with sparsity. + * Add nominal support for unsigned 32-bit integer tensor types. Note that + very few TFLite kernels support this type natively, so its use in mobile + ML authoring is generally discouraged. + * Add support for static hash tables through + `TFLiteConverter.from_saved_model`. + * The Python TF Lite Interpreter bindings now has an option + `experimental_preserve_all_tensors` to aid in debugging conversion. 
+ * Quantized x86 execution defaults to Ruy GEMM library for platforms with + AVX support. + * Deprecate `tf.compat.v1.lite.experimental.get_potentially_supported_ops`. + Use `tf.lite.TFLiteConverter` directly to check whether a model is + convertible. + * Add support to select one of three different built-in op resolvers to be + used in Python Interpreter API. + * Enabled post training with calibrations for models that require user + provided TensorFlow Lite custom op libraries via + `converter.target_spec._experimental_custom_op_registerers`. +* TF Core: + * Corrected higher-order gradients of control flow constructs (`tf.cond`, + `tf.while_loop`, and compositions like `tf.foldl`) computed with + `tf.GradientTape` inside a `tf.function`. + * Changed the default step size in `gradient_checker_v2.compute_gradients` to be exactly representable as binary floating point numbers. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests. + * Added `tf.config.experimental.get_memory_info`, returning a dict with the + current and peak memory usage. Deprecated + `tf.config.experimental.get_memory_usage` in favor of this new function. + * Extended `tf.config.experimental.enable_tensor_float_32_execution` to + control Tensor-Float-32 evaluation in RNNs. + * Added an 'experimental_payloads' field to tf.errors.OpError and + its subclasses to support more detailed error reporting. + This is inspired by Abseil Status payloads: + https://github.com/abseil/abseil-cpp/blob/master/absl/status/status.h + +* `tf.summary`: + * New `tf.summary.graph` allows manual write of TensorFlow graph + (`tf.Graph` or `tf.compat.v1.GraphDef`) as a summary. This is not a + replacement for the trace-based API. + +* Set `/d2ReducedOptimizeHugeFunctions` by default for Windows builds. This + provides a big compile-time speedup, and effectively raises the minimum + supported MSVC version to 16.4 (current: 16.8). 
+ * See: https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion + +* TensorRT + * Removed the deprecated `session_config` parameter for the TF1-TRT + converter `TrtGraphConverter`. Previously, we issued a warning when the + value of the parameter is not None. + * The TF2-TRT converter `TrtGraphConverterV2` takes an object of class + TrtConversionParams as a parameter. Removed three deprecated fields from + this class: `rewriter_config_template`, `is_dynamic_op`, and + `max_batch_size`. Previously, we issued a warning when the value of + `rewriter_config_template` is not None. We issued an error when the + value of `is_dynamic_op` is not True. We didn't use the value for + `max_batch_size` for building TensorRT engines. Add parameter + `use_dynamic_shape` to enable dynamic shape support. The default is to + disable dynamic shape support. Add `dynamic_shape_profile_strategy` + for selecting a dynamic shape profile strategy. The default profile + strategy is `Range`. + * Issue a warning when function get_tensorrt_rewriter_config is used. + +* TF XLA + * Add new enum value `MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED` to + `tf.config.experimental.mlir_bridge_rollout` to enable a \"safe\" mode. + This runs the MLIR bridge only when + an analysis of the graph determines that it is safe to run. + * Add new enum value 'MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED' to + `tf.config.experimental.mlir_bridge_rollout` to enable a fallback for + the MLIR bridge in a \"safe\" mode. This runs the MLIR bridge in a + FallbackEnabled mode when an analysis of the graph determines + that the graph does not have unsupported features. +* Deterministic Op Functionality: + * Add determinism-unimplemented exception-throwing to the segment-sum ops. 
+ When the environment variable `TF_DETERMINISTIC_OPS` is set to `"true"` + or `"1"` (when op-determinism is expected), an attempt to run the + following ops on a GPU will throw `tf.errors.UnimplementedError` (with an + understandable message) when `data` is a floating-point type, including + complex types (if supported): `tf.math.segment_prod`, + `tf.math.segment_sum`, `tf.math.unsorted_segment_mean`, + `tf.math.unsorted_segment_sqrt_n`, `tf.math.unsorted_segment_prod`, + `tf.math.unsorted_segment_sum`, and therefore also + `tf.convert_to_tensor` when `value` is of type `tf.IndexedSlices` (such + as in the backprop through `tf.gather` into a dense embedding). See + issue [39751](https://github.com/tensorflow/tensorflow/issues/39751) + which this change addresses, but does not solve. This exception-throwing + behavior can be disabled by setting the environment variable + `TF_DISABLE_SEGMENT_REDUCTION_OP_DETERMINISM_EXCEPTIONS` to `"true"` or + `"1"`. For more information about these changes, see the description in + pull request + [47772](https://github.com/tensorflow/tensorflow/pull/47772). + * In previous versions of TensorFlow, when a GPU was available, + `tf.sparse.sparse_dense_matmul` introduced truly random noise in the + forward path for data of type `tf.float32` but not for data of type + `tf.float64` (for which there was no GPU implementation). In this + current release, GPU support for other floating-point types + (`tf.float16`, `tf.float64`, `tf.complex64`, and `tf.complex128`) has + been added for this op. If you were relying on the determinism of the + `tf.float64` CPU implementation being automatically selected because of + the absence of the `tf.float64` GPU implementation, you will either + need to force the op to run on the CPU or use a different data type. 
+* Security + * Fixes a heap buffer overflow in `RaggedBinCount` ([CVE-2021-29512](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29512)) + * Fixes a heap out of bounds write in `RaggedBinCount` ([CVE-2021-29514](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29514)) + * Fixes a type confusion during tensor casts which leads to dereferencing null pointers ([CVE-2021-29513](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29513)) + * Fixes a reference binding to null pointer in `MatrixDiag*` ops ([CVE-2021-29515](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29515)) + * Fixes a null pointer dereference via invalid Ragged Tensors ([CVE-2021-29516](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29516)) + * Fixes a division by zero in `Conv3D` ([CVE-2021-29517](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29517)) + * Fixes vulnerabilities where session operations in eager mode lead to null pointer dereferences ([CVE-2021-29518](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29518)) + * Fixes a `CHECK`-fail in `SparseCross` caused by type confusion ([CVE-2021-29519](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29519)) + * Fixes a segfault in `SparseCountSparseOutput` ([CVE-2021-29521](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29521)) + * Fixes a heap buffer overflow in `Conv3DBackprop*` ([CVE-2021-29520](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29520)) + * Fixes a division by 0 in `Conv3DBackprop*` ([CVE-2021-29522](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29522)) + * Fixes a `CHECK`-fail in `AddManySparseToTensorsMap` ([CVE-2021-29523](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29523)) + * Fixes a division by 0 in `Conv2DBackpropFilter` ([CVE-2021-29524](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29524)) + * Fixes a division by 0 in `Conv2DBackpropInput` 
([CVE-2021-29525](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29525)) + * Fixes a division by 0 in `Conv2D` ([CVE-2021-29526](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29526)) + * Fixes a division by 0 in `QuantizedConv2D` ([CVE-2021-29527](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29527)) + * Fixes a division by 0 in `QuantizedMul` ([CVE-2021-29528](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29528)) + * Fixes vulnerabilities caused by invalid validation in `SparseMatrixSparseCholesky` ([CVE-2021-29530](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29530)) + * Fixes a heap buffer overflow caused by rounding ([CVE-2021-29529](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29529)) + * Fixes a `CHECK`-fail in `tf.raw_ops.EncodePng` ([CVE-2021-29531](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29531)) + * Fixes a heap out of bounds read in `RaggedCross` ([CVE-2021-29532](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29532)) + * Fixes a `CHECK`-fail in `DrawBoundingBoxes` ([CVE-2021-29533](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29533)) + * Fixes a heap buffer overflow in `QuantizedMul` ([CVE-2021-29535](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29535)) + * Fixes a `CHECK`-fail in `SparseConcat` ([CVE-2021-29534](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29534)) + * Fixes a heap buffer overflow in `QuantizedResizeBilinear` ([CVE-2021-29537](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29537)) + * Fixes a heap buffer overflow in `QuantizedReshape` ([CVE-2021-29536](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29536)) + * Fixes a division by zero in `Conv2DBackpropFilter` ([CVE-2021-29538](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29538)) + * Fixes a heap buffer overflow in `Conv2DBackpropFilter` ([CVE-2021-29540](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29540)) + * Fixes a heap buffer 
overflow in `StringNGrams` ([CVE-2021-29542](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29542)) + * Fixes a null pointer dereference in `StringNGrams` ([CVE-2021-29541](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29541)) + * Fixes a `CHECK`-fail in `QuantizeAndDequantizeV4Grad` ([CVE-2021-29544](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29544)) + * Fixes a `CHECK`-fail in `CTCGreedyDecoder` ([CVE-2021-29543](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29543)) + * Fixes a heap buffer overflow in `SparseTensorToCSRSparseMatrix` ([CVE-2021-29545](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29545)) + * Fixes a division by 0 in `QuantizedBiasAdd` ([CVE-2021-29546](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29546)) + * Fixes a heap out of bounds in `QuantizedBatchNormWithGlobalNormalization` ([CVE-2021-29547](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29547)) + * Fixes a division by 0 in `QuantizedBatchNormWithGlobalNormalization` ([CVE-2021-29548](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29548)) + * Fixes a division by 0 in `QuantizedAdd` ([CVE-2021-29549](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29549)) + * Fixes a division by 0 in `FractionalAvgPool` ([CVE-2021-29550](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29550)) + * Fixes an OOB read in `MatrixTriangularSolve` ([CVE-2021-29551](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29551)) + * Fixes a heap OOB in `QuantizeAndDequantizeV3` ([CVE-2021-29553](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29553)) + * Fixes a `CHECK`-failure in `UnsortedSegmentJoin` ([CVE-2021-29552](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29552)) + * Fixes a division by 0 in `DenseCountSparseOutput` ([CVE-2021-29554](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29554)) + * Fixes a division by 0 in `FusedBatchNorm` 
([CVE-2021-29555](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29555)) + * Fixes a division by 0 in `SparseMatMul` ([CVE-2021-29557](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29557)) + * Fixes a division by 0 in `Reverse` ([CVE-2021-29556](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29556)) + * Fixes a heap buffer overflow in `SparseSplit` ([CVE-2021-29558](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29558)) + * Fixes a heap OOB access in unicode ops ([CVE-2021-29559](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29559)) + * Fixes a heap buffer overflow in `RaggedTensorToTensor` ([CVE-2021-29560](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29560)) + * Fixes a `CHECK`-fail in `LoadAndRemapMatrix` ([CVE-2021-29561](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29561)) + * Fixes a `CHECK`-fail in `tf.raw_ops.IRFFT` ([CVE-2021-29562](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29562)) + * Fixes a `CHECK`-fail in `tf.raw_ops.RFFT` ([CVE-2021-29563](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29563)) + * Fixes a null pointer dereference in `EditDistance` ([CVE-2021-29564](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29564)) + * Fixes a null pointer dereference in `SparseFillEmptyRows` ([CVE-2021-29565](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29565)) + * Fixes a heap OOB access in `Dilation2DBackpropInput` ([CVE-2021-29566](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29566)) + * Fixes a reference binding to null in `ParameterizedTruncatedNormal` ([CVE-2021-29568](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29568)) + * Fixes a set of vulnerabilities caused by lack of validation in `SparseDenseCwiseMul` ([CVE-2021-29567](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29567)) + * Fixes a heap out of bounds read in `MaxPoolGradWithArgmax` 
([CVE-2021-29570](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29570)) + * Fixes a heap out of bounds read in `RequantizationRange` ([CVE-2021-29569](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29569)) + * Fixes a memory corruption in `DrawBoundingBoxesV2` ([CVE-2021-29571](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29571)) + * Fixes a reference binding to nullptr in `SdcaOptimizer` ([CVE-2021-29572](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29572)) + * Fixes an overflow and a denial of service in `tf.raw_ops.ReverseSequence` ([CVE-2021-29575](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29575)) + * Fixes a division by 0 in `MaxPoolGradWithArgmax` ([CVE-2021-29573](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29573)) + * Fixes an undefined behavior in `MaxPool3DGradGrad` ([CVE-2021-29574](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29574)) + * Fixes a heap buffer overflow in `MaxPool3DGradGrad` ([CVE-2021-29576](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29576)) + * Fixes a heap buffer overflow in `AvgPool3DGrad` ([CVE-2021-29577](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29577)) + * Fixes an undefined behavior and a `CHECK`-fail in `FractionalMaxPoolGrad` ([CVE-2021-29580](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29580)) + * Fixes a heap buffer overflow in `FractionalAvgPoolGrad` ([CVE-2021-29578](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29578)) + * Fixes a heap buffer overflow in `MaxPoolGrad` ([CVE-2021-29579](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29579)) + * Fixes a segfault in `CTCBeamSearchDecoder` ([CVE-2021-29581](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29581)) + * Fixes a heap OOB read in `tf.raw_ops.Dequantize` ([CVE-2021-29582](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29582)) + * Fixes a `CHECK`-fail due to integer overflow 
([CVE-2021-29584](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29584)) + * Fixes a heap buffer overflow and undefined behavior in `FusedBatchNorm` ([CVE-2021-29583](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29583)) + * Fixes a division by zero in padding computation in TFLite ([CVE-2021-29585](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29585)) + * Fixes a division by zero in optimized pooling implementations in TFLite ([CVE-2021-29586](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29586)) + * Fixes a division by zero in TFLite's implementation of `SpaceToDepth` ([CVE-2021-29587](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29587)) + * Fixes a division by zero in TFLite's implementation of `GatherNd` ([CVE-2021-29589](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29589)) + * Fixes a division by zero in TFLite's implementation of `TransposeConv` ([CVE-2021-29588](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29588)) + * Fixes a heap OOB read in TFLite's implementation of `Minimum` or `Maximum` ([CVE-2021-29590](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29590)) + * Fixes a null pointer dereference in TFLite's `Reshape` operator ([CVE-2021-29592](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29592)) + * Fixes a stack overflow due to looping TFLite subgraph ([CVE-2021-29591](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29591)) + * Fixes a division by zero in TFLite's implementation of `DepthToSpace` ([CVE-2021-29595](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29595)) + * Fixes a division by zero in TFLite's convolution code ([CVE-2021-29594](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29594)) + * Fixes a division by zero in TFLite's implementation of `EmbeddingLookup` ([CVE-2021-29596](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29596)) + * Fixes a division by zero in TFLite's implementation of `BatchToSpaceNd` 
([CVE-2021-29593](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29593)) + * Fixes a division by zero in TFLite's implementation of `SpaceToBatchNd` ([CVE-2021-29597](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29597)) + * Fixes a division by zero in TFLite's implementation of `SVDF` ([CVE-2021-29598](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29598)) + * Fixes a division by zero in TFLite's implementation of `Split` ([CVE-2021-29599](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29599)) + * Fixes a division by zero in TFLite's implementation of `OneHot` ([CVE-2021-29600](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29600)) + * Fixes a division by zero in TFLite's implementation of `DepthwiseConv` ([CVE-2021-29602](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29602)) + * Fixes a division by zero in TFLite's implementation of hashtable lookup ([CVE-2021-29604](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29604)) + * Fixes an integer overflow in TFLite concatenation ([CVE-2021-29601](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29601)) + * Fixes an integer overflow in TFLite memory allocation ([CVE-2021-29605](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29605)) + * Fixes a heap OOB write in TFLite ([CVE-2021-29603](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29603)) + * Fixes a heap OOB read in TFLite ([CVE-2021-29606](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29606)) + * Fixes a heap OOB and null pointer dereference in `RaggedTensorToTensor` ([CVE-2021-29608](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29608)) + * Fixes vulnerabilities caused by incomplete validation in `SparseAdd` ([CVE-2021-29609](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29609)) + * Fixes vulnerabilities caused by incomplete validation in `SparseSparseMinimum` ([CVE-2021-29607](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29607)) + *
Fixes vulnerabilities caused by incomplete validation in `SparseReshape` ([CVE-2021-29611](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29611)) + * Fixes vulnerabilities caused by invalid validation in `QuantizeAndDequantizeV2` ([CVE-2021-29610](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29610)) + * Fixes a heap buffer overflow in `BandedTriangularSolve` ([CVE-2021-29612](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29612)) + * Fixes vulnerabilities caused by incomplete validation in `tf.raw_ops.CTCLoss` ([CVE-2021-29613](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29613)) + * Fixes an interpreter crash from vulnerabilities in `tf.io.decode_raw` ([CVE-2021-29614](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29614)) + * Fixes a stack overflow in `ParseAttrValue` with nested tensors ([CVE-2021-29615](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29615)) + * Fixes a null dereference in Grappler's `TrySimplify` ([CVE-2021-29616](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29616)) + * Fixes a crash in `tf.transpose` with complex inputs ([CVE-2021-29618](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29618)) + * Fixes a crash in `tf.strings.substr` due to `CHECK`-fail ([CVE-2021-29617](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29617)) + * Fixes a segfault in `tf.raw_ops.SparseCountSparseOutput` ([CVE-2021-29619](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29619)) + * Fixes a segfault in `tf.raw_ops.ImmutableConst` ([CVE-2021-29539](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2021-29539)) + * Updates `curl` to `7.76.0` to handle [CVE-2020-8169](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-8169), [CVE-2020-8177](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-8177), [CVE-2020-8231](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-8231), [CVE-2020-8284](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-8284), 
[CVE-2020-8285](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-8285) and [CVE-2020-8286](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-8286). + +* Other + * Added `show_debug_info` to `mlir.convert_graph_def` and + `mlir.convert_function`. + * Added [Arm Compute Library (ACL)](https://github.com/ARM-software/ComputeLibrary) + support to `--config=mkl_aarch64` build. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +8bitmp3, Aaron S. Mondal, Abhilash Mahendrakar, Abhinav Upadhyay, Abhishek Kulkarni, Abolfazl Shahbazi, Adam Hillier, Aditya Kane, Ag Ramesh, ahmedsabie, Albert Villanova Del Moral, Aleksey Vitebskiy, Alex Hoffman, Alexander Bayandin, Alfie Edwards, Aman Kishore, Amogh Joshi, andreABbauer, Andrew Goodbody, Andrzej Pomirski, Artemiy Ryabinkov, Ashish Jha, ather, Ayan Moitra, Bairen Yi, Bart Ribbers, Bas Aarts, Behzad Abghari, Ben Arnao, Ben Barsdell, Benjamin Klimczak, bhack, Brendan Collins, Can Wang, Cheng Ren, Chris Leary, Chris Olivier, Clemens Giuliani, Cloud Han, Corey Cole, Cui, Yifeng, Cuong V. 
Nguyen, Daniel Moore, Dawid Wojciechowski, Ddavis-2015, Dean Wyatte, Denisa Roberts, dependabot[bot], Dmitry Volodin, Dominic Jack, Duncan Riach, dushuai, Elena Zhelezina, Eli Osherovich, Erik Smistad, ewsn1593, Felix Fent, fo40225, François Chollet, Frederic Bastien, Freedom" Koan-Sin Tan, fsx950223, ganand1, gbaned, Georgiy Manuilov, gerbauz, Guillaume Klein, Guozhong Zhuang, Harry Slatyer, Harsh188, henri, Henri Woodcock, Hiran Sarkar, Hollow Man, Håkon Sandsmark, I Wayan Dharmana, icysapphire, Ikko Ashimine, Jab Hofmeier, Jack Hessel, Jacob Valdez, Jakub Jatczak, James Bernardi, Jared Smolens, Jason Zaman, jedlimlx, Jenny Plunkett, Jens Elofsson, Jerry Shih, jgehw, Jia Fu Low, Jim Fisher, jpodivin, Julien Stephan, Jungsub Lim, Junha Park, Junhyuk So, justkw, Kaixi Hou, kashyapraval, Kasra Bigdeli, Kazuaki Ishizaki, Keith Mok, Kevin Cheng, kopytjuk, Kristian Hartikainen, ksood12345, Kulin Seth, kushanam, latyas, Lequn Chen, Leslie-Fang, Long M. Lưu, Lukas Geiger, machineko, Mahmoud Abuzaina, Manish, Mao Yunfei, Maozhou, Ge, Marcin Juszkiewicz, Marcin Owsiany, Marconi Jiang, Marcos Pereira, Maria Romanenko Vexlard, Maria Vexlard, Marius Brehler, marload, Martin Kubovčík, Matej, Mateusz Holenko, Maxiwell S. Garcia, Mazhar, mazharul, mbhuiyan, mdfaijul, Michael Gielda, Michael Kuchnik, Michal Szutenberg, Mikhail Stepanov, Milan Straka, Mitchel Humpherys, Mohamed Moselhy, Mohamed Nour Abouelseoud, Måns Bermell, Måns Nilsson, Nathan Luehr, Nico Jahn, Niroop Ammbashankar, Oceania2018, Omri Steiner, Orivej Desh, Oskar Flordal, oujiafan, Patrik Laurell, Paul B. 
Isaac'S, Paul Klinger, Pawel Piskorski, Pedro Marques, Phat Tran, Piotr Zierhoffer, piyushdatta, Pnikam-Cad, Prashant Kumar, Prateek Gupta, PratsBhatt, Pravin Karandikar, qqq.jq, QQ喵, Quintin, Rama Ketineni, ravikyram, Rehan Guha, rhdong, rmothukuru, Roger Cheng, Rohit Santhanam, rposts, Rsanthanam-Amd, rsun, Rsun-Bdti, Ryan Kuester, ryanking13, Saduf2019, Sami Kama, Samuel Marks, Scott Tseng, Sean Moriarity, Sergey Popov, Sergii Khomenko, Sheng, Yang, shwetaoj, Sidong-Wei, Simon Maurer, Simrit Kaur, Srini511, Srinivasan Narayanamoorthy, Stephan, Stephen Matthews, Sungmann Cho, Sunoru, Suraj Sudhir, Suraj Upadhyay, Taebum Kim, Takayoshi Koizumi, Tamas Bela Feher, Teng Lu, Thibaut Goetghebuer-Planchon, Tomwildenhain-Microsoft, Tony, Traun Leyden, Trent Lo, TVLIgnacy, Tzu-Wei Sung, vaibhav, Vignesh Kothapalli, Vikram Dattu, viktprog, Vinayaka Bandishti, Vincent Abriou, Vishakha Agrawal, Vivek Panyam, Vladimir Silyaev, Võ Văn Nghĩa, wamuir, Wang, Yanzhang, wangsiyu, Waqar Hameed, wxinix, Xiao Yang, xiaohong1031, Xiaoming (Jason) Cui, Xinan Jiang, Yair Ehrenwald, Yajush Vyas, Yasir Modak, Yimei Sun, Yong Tang, Yosshi999, youshenmebutuo, yqtianust, Yuan Tang, yuanbopeng, Yuriy Chernyshov, Yuta Fukasawa, Zachary Deane-Mayer, Zeno Gantner, Zhoulong Jiang, zhuyie, zilinzhu, 彭震东 + +# Release 2.4.1 + +* This release removes the AVX2 requirement from TF 2.4.0. 
+ +# Release 2.3.2 ## Bug Fixes and Other Changes +* Fixes an access to uninitialized memory in Eigen code + ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266)) +* Fixes a security vulnerability caused by lack of validation in + `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap` + ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267)) +* Fixes a vulnerability caused by attempting to write to immutable memory region in + `tf.raw_ops.ImmutableConst` + ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268)) +* Fixes a `CHECK`-fail in LSTM with zero-length input + ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270)) +* Fixes a security vulnerability caused by accessing heap data outside of bounds + when loading a specially crafted `SavedModel` + ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271)) +* Solves an OOM issue on TPUs when XLA contexts use fused average updates +* Updates `libjpeg-turbo` to `2.0.5` to handle + [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790). +* Updates `junit` to `4.13.1` to handle + [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250). +* Updates `PCRE` to `8.44` to handle + [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838) + and + [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155). +* Updates `sqlite3` to `3.44.0` to keep in sync with master branch. 
-* -* -* -* Security: - * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch` - ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) - * Fixes three vulnerabilities in conversion to DLPack format - ([CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), - [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), - [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193)) - * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` - ([CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), - [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195)) - * Fixes several vulnerabilities in `RaggedCountSparseOutput` and - `SparseCountSparseOutput` operations - ([CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), - [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), - [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), - [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), - [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), - [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201)) - * Fixes an integer truncation vulnerability in code using the work sharder - API - ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) - * Fixes a format string vulnerability in `tf.strings.as_string` - ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) - * Fixes segfault raised by calling session-only ops in eager mode - ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) - * Fixes data leak and potential ASLR violation from - `tf.raw_ops.StringNGrams` - ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) - * Fixes segfaults caused by incomplete 
`SavedModel` validation - ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) - * Fixes a data corruption due to a bug in negative indexing support in - TFLite - ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) - * Fixes a data corruption due to dimension mismatch in TFLite - ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) - * Fixes several vulnerabilities in TFLite saved model format - ([CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), - [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), - [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211)) - * Fixes several vulnerabilities in TFLite implementation of segment sum - ([CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), - [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), - [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214)) - * Fixes a segfault in `tf.quantization.quantize_and_dequantize` - ([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265)) - * Fixes an undefined behavior float cast causing a crash - ([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266)) -* TF Core: - * `tf.types.experimental.TensorLike` is a new `Union` type that can be - used as type annotation for variables representing a Tensor or a value - that can be converted to Tensor by `tf.convert_to_tensor`. - * Calling ops with a python constants or numpy values is now consistent - with tf.convert_to_tensor behavior. This avoids operations like - tf.reshape truncating inputs such as from int64 to int32. - * Added `tf.sparse.map_values` to apply a function to the `.value`s of - `SparseTensor` arguments. 
- * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, - `__xor__` and `__invert__` now support non-`bool` arguments and apply - the corresponding bitwise ops. `bool` arguments continue to be supported - and dispatch to logical ops. This brings them more in line with Python - and NumPy behavior. - * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor - with the same sparsity pattern, but with new provided values. It is - similar to the `with_values` function of `RaggedTensor`. - * Added `StatelessCase` op, and uses it if none of case branches has - stateful ops. - * Added `tf.config.experimental.get_memory_usage` to return total memory - usage of the device. - * Added gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`. - * Improve shape inference of nested function calls by supporting constant folding across Arg nodes which makes more static values available to shape inference functions. -* `tf.data`: - * tf.data service: - * Added new `tf.data.experimental.service.register_dataset` and - `tf.data.experimental.service.from_dataset_id` APIs to enable one - process to register a dataset with the tf.data service, and another - process to consume data from the dataset. - * Added support for dispatcher fault tolerance. To enable fault tolerance, - configure a `work_dir` when running your dispatcher server and set - `dispatcher_fault_tolerance=True`. The dispatcher will store its state - to `work_dir`, so that on restart it can continue from its previous - state after restart. - * Added support for sharing dataset graphs via shared filesystem instead - of over RPC. This reduces load on the dispatcher, improving performance - of distributing datasets. For this to work, the dispatcher's `work_dir` - must be accessible from workers. If the worker fails to read from the - `work_dir`, it falls back to using RPC for dataset graph transfer. - * Added support for a new "distributed_epoch" processing mode. 
This - processing mode distributes a dataset across all tf.data workers, - instead of having each worker process the full dataset. See - [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode) - to learn more. - * Added optional `exclude_cols` parameter to CsvDataset. This parameter is - the complement of `select_cols`; at most one of these should be - specified. - * We have implemented an optimization which reorders data-discarding - transformations such as `take` and `shard` to happen earlier in the - dataset when it is safe to do so. The optimization can be disabled via - the `experimental_optimization.reorder_data_discarding_ops` dataset - option. - * `tf.data.Options` were previously immutable and can now be overridden. - * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors - with a new `output_signature` argument, which allows `from_generator` to - produce any type describable by a `tf.TypeSpec`. - * `tf.data.experimental.AUTOTUNE` is now available in the core API as - `tf.data.AUTOTUNE`. -* `tf.image`: - * Added deterministic `tf.image.stateless_random_*` functions for each - `tf.image.random_*` function. Added a new op - `stateless_sample_distorted_bounding_box` which is a deterministic - version of `sample_distorted_bounding_box` op. Given the same seed, - these stateless functions/ops produce the same results independent of - how many times the function is called, and independent of global seed - settings. -* `tf.distribute`: - * (Experimental) Parameter server training: - * Replaced the existing - `tf.distribute.experimental.ParameterServerStrategy` symbol with - a new class that is for parameter server training in TF2. Usage with - the old symbol, usually with Estimator, should be replaced with - `tf.compat.v1.distribute.experimental.ParameterServerStrategy`. 
- * Added `tf.distribute.experimental.coordinator.*` namespace, - including the main API `ClusterCoordinator` for coordinating the - training cluster, the related data structure `RemoteValue` - and `PerWorkerValue`. -* `tf.keras`: - * Improvements from the functional API refactoring: - * Functional model construction does not need to maintain a global - workspace graph, removing memory leaks especially when building many - models or very large models. - * Functional model construction should be ~8-10% faster on average. - * Functional models can now contain non-symbolic values in their call - inputs inside of the first positional argument. - * Several classes of TF ops that were not reliably converted to Keras - layers during functional API construction should now work, e.g. - `tf.image.ssim_multiscale` - * Error messages when Functional API construction goes wrong (and when - ops cannot be converted to Keras layers automatically) should be - clearer and easier to understand. - * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` - as an alternative to accepting a `callable` loss. - * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) - to match FTRL paper - (https://research.google.com/pubs/archive/41159.pdf). - * Added `mobilenet_v3` to keras application model. - * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for - customization of how gradients are aggregated across devices, as well as - `gradients_transformers` to allow for custom gradient transformations - (such as gradient clipping). - * The `steps_per_execution` argument in `compile()` is no longer - experimental; if you were passing `experimental_steps_per_execution`, - rename it to `steps_per_execution` in your code. This argument controls - the number of batches to run during each `tf.function` call when calling - `fit()`. 
Running multiple batches inside a single `tf.function` call can - greatly improve performance on TPUs or small models with a large Python - overhead. - * Improvements to Keras preprocessing layers: - * TextVectorization can now accept a vocabulary list or file as an - init arg. - * Normalization can now accept mean and variance values as init args. - * In `Attention` and `AdditiveAttention` layers, the `call()` method now - accepts a `return_attention_scores` argument. When set to - True, the layer returns the attention scores as an additional output - argument. - * Added `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints - with the same implementation as their `tf.losses` equivalent. - * For Keras model, the individual call of `Model.evaluate` uses no cached - data for evaluation, while `Model.fit` uses cached data when - `validation_data` arg is provided for better performance. - * Added a `save_traces` argument to `model.save`/ - `tf.keras.models.save_model` which determines whether the SavedModel - format stores the Keras model/layer call functions. The traced functions - allow Keras to revive custom models and layers without the original - class definition, but if this isn't required the tracing can be - disabled with the added option. -* `tf.function` / AutoGraph: - * Added `experimental_follow_type_hints` argument for `tf.function`. When - True, the function may use type annotations to optimize the tracing - performance. - * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. - * AutoGraph now allows creating new symbols inside a TensorFLow loop, if - the values of these symbols at an iteration does not depend on the - previous iteration. These types of loops must run at least one - iteration, and will raise a runtime error otherwise. 
- - Example: - - ``` - for batch in data: - outputs = train_step(batch) - tf.print('final outputs', outputs) - ``` - - See tensorflow/python/autograph/g3doc/reference/limitations.md for more - info. +# Release 2.2.2 -* `tf.lite`: +## Bug Fixes and Other Changes +* Fixes an access to uninitialized memory in Eigen code + ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266)) +* Fixes a security vulnerability caused by lack of validation in + `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap` + ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267)) +* Fixes a vulnerability caused by attempting to write to immutable memory region in + `tf.raw_ops.ImmutableConst` + ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268)) +* Fixes a `CHECK`-fail in LSTM with zero-length input + ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270)) +* Fixes a security vulnerability caused by accessing heap data outside of bounds + when loading a specially crafted `SavedModel` + ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271)) +* Prevents memory leaks in loading `SavedModel`s that import functions +* Updates `libjpeg-turbo` to `2.0.5` to handle + [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790). +* Updates `junit` to `4.13.1` to handle + [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250). +* Updates `PCRE` to `8.44` to handle + [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838) + and + [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155). +* Updates `sqlite3` to `3.44.0` to keep in sync with master branch. - * `TFLiteConverter`: - * Support optional flags `inference_input_type` and - `inference_output_type` for full integer quantized models. 
This - allows users to modify the model input and output type to integer - types (`tf.int8`, `tf.uint8`) instead of defaulting to float type - (`tf.float32`). - * TFLite Profiler for Android is available. See the detailed - [guide](https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android). - * NNAPI - * Added NNAPI Delegation support for requantization use cases by - converting the operation into a dequantize-quantize pair. - * Removed deprecated `Interpreter.setUseNNAPI(boolean)` Java API. - * Use `Interpreter.Options.setUseNNAPI` instead. - * Deprecate `Interpreter::UseNNAPI(bool)` C++ API. - * Use `NnApiDelegate()` and related delegate configuration methods - directly. - * Deprecate `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API - * Prefer controlling this via delegate options, e.g. - `tflite::StatefulNnApiDelegate::Options::allow_fp16' or - `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. - * `DynamicBuffer::AddJoinedString()` will now add a separator if the first - string to be joined is empty. - * Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion. 
- * +# Release 2.1.3 -* `tf.random`: +## Bug Fixes and Other Changes +* Fixes an access to uninitialized memory in Eigen code + ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266)) +* Fixes a security vulnerability caused by lack of validation in + `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap` + ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267)) +* Fixes a vulnerability caused by attempting to write to immutable memory region in + `tf.raw_ops.ImmutableConst` + ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268)) +* Fixes a `CHECK`-fail in LSTM with zero-length input + ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270)) +* Fixes a security vulnerability caused by accessing heap data outside of bounds + when loading a specially crafted `SavedModel` + ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271)) +* Updates `libjpeg-turbo` to `2.0.5` to handle + [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790). +* Updates `junit` to `4.13.1` to handle + [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250). +* Updates `PCRE` to `8.44` to handle + [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838) + and + [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155). +* Updates `sqlite3` to `3.44.0` to keep in sync with master branch. +* Newer ROCm versions are supported on the 2.1 branch. - * +# Release 2.0.4 -* Math and Linear Algebra: +Note that this is the last patch release for the TensorFlow 2.0.x series. - * Add `tf.math.erfcinv`, the inverse to `tf.math.erfc`. 
+## Bug Fixes and Other Changes +* Fixes an access to uninitialized memory in Eigen code + ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266)) +* Fixes a security vulnerability caused by lack of validation in + `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap` + ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267)) +* Fixes a vulnerability caused by attempting to write to immutable memory region in + `tf.raw_ops.ImmutableConst` + ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268)) +* Fixes a `CHECK`-fail in LSTM with zero-length input + ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270)) +* Fixes a security vulnerability caused by accessing heap data outside of bounds + when loading a specially crafted `SavedModel` + ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271)) +* Updates `libjpeg-turbo` to `2.0.5` to handle + [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790). +* Updates `junit` to `4.13.1` to handle + [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250). +* Updates `PCRE` to `8.44` to handle + [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838) + and + [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155). +* Updates `sqlite3` to `3.44.0` to keep in sync with master branch. -* TPU Enhancements: +# Release 1.15.5 - * Added support for the `beta` parameter of the FTRL optimizer for TPU - embeddings. Users of other TensorFlow platforms can implement equivalent - behavior by adjusting the `l2` parameter. - * +Note that this is the last patch release for the TensorFlow 1.x series. 
-* XLA Support: +## Bug Fixes and Other Changes +* Fixes an access to uninitialized memory in Eigen code + ([CVE-2020-26266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26266)) +* Fixes a security vulnerability caused by lack of validation in + `tf.raw_ops.DataFormatVecPermute` and `tf.raw_ops.DataFormatDimMap` + ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267)) +* Fixes a vulnerability caused by attempting to write to immutable memory region in + `tf.raw_ops.ImmutableConst` + ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268)) +* Fixes a `CHECK`-fail in LSTM with zero-length input + ([CVE-2020-26270](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26270)) +* Fixes a security vulnerability caused by accessing heap data outside of bounds + when loading a specially crafted `SavedModel` + ([CVE-2020-26271](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26271)) +* Updates `libjpeg-turbo` to `2.0.5` to handle + [CVE-2020-13790](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-13790). +* Updates `junit` to `4.13.1` to handle + [CVE-2020-15250](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15250). +* Updates `PCRE` to `8.44` to handle + [CVE-2019-20838](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2019-20838) + and + [CVE-2020-14155](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14155). +* Updates `sqlite3` to `3.44.0` to keep in sync with master branch. - * xla.experimental.compile is deprecated, use - `tf.function(experimental_compile=True)` instead - * Added `tf.function.experimental_get_compiler_ir` which returns compiler - IR (currently 'hlo' and 'optimized_hlo') for given input for given - function. 
- * +# Release 2.4.0 -* Tracing and Debugging: + ## Major Features and Improvements - * +* `tf.distribute` introduces experimental support for asynchronous training of + models via the [`tf.distribute.experimental.ParameterServerStrategy`] + (https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/ParameterServerStrategy) + API. Please see the [tutorial](https://www.tensorflow.org/tutorials/distribute/parameter_server_training) + to learn more. -* `tf.train.Checkpoint`: +* [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy) + is now a stable API and is no longer considered experimental. Some of the + major improvements involve handling peer failure and many bug fixes. Please + check out the detailed tutorial on [Multi-worker training with Keras] + (https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras). - * Now accepts a `root` argument in the initialization, which generates a - checkpoint with a root object. This allows users to create a - `Checkpoint` object that is compatible with Keras `model.save_weights()` - and `model.load_weights`. The checkpoint is also compatible with the - checkpoint saved in the `variables/` folder in the SavedModel. - * When restoring, `save_path` can be a path to a SavedModel. The function - will automatically find the checkpoint in the SavedModel. +* Introduces experimental support for a new module named [`tf.experimental.numpy`] + (https://www.tensorflow.org/api_docs/python/tf/experimental/numpy) which is a + NumPy-compatible API for writing TF programs. See the [detailed guide] + (https://www.tensorflow.org/guide/tf_numpy) to learn more. Additional details below. -* `tf.nn`: +* Adds Support for + [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) + on Ampere based GPUs. TensorFloat-32, or TF32 for short, is a math mode for + NVIDIA Ampere based GPUs and is enabled by default. 
- * `tf.nn.max_pool2d` now supports explicit padding. +* A major refactoring of the internals of the Keras Functional API has been + completed, that should improve the reliability, stability, and performance of + constructing Functional models. -* `tf.debugging`: +* Keras mixed precision API [`tf.keras.mixed_precision`] + (https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision?version=nightly) + is no longer experimental and allows the use of 16-bit floating point formats + during training, improving performance by up to 3x on GPUs and 60% on TPUs. + Please see below for additional details. - * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (#36268). +* TensorFlow Profiler now supports profiling `MultiWorkerMirroredStrategy` and + tracing multiple workers using the [sampling mode API] + (https://www.tensorflow.org/guide/profiler#profiling_apis). -* `tf.print`: +* TFLite Profiler for Android is available. See the detailed [guide] + (https://www.tensorflow.org/lite/performance/measurement#trace_tensorflow_lite_internals_in_android) + to learn more. - * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict` - didn't have the keys sorted, the keys and values were not being printed - in accordance with their correct mapping. +* TensorFlow pip packages are now built with CUDA11 and cuDNN 8.0.2. -* `TensorRT` +## Breaking Changes - * We now issue a warning when the `session_config` parameter for the TF1 - converter is used or the `rewrite_config_template` field in the TF2 - converter parameter object is used. +* TF Core: + * Certain float32 ops run in lower precision on Ampere based GPUs, including + matmuls and convolutions, due to the use of [TensorFloat-32] + (https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/). + Specifically, inputs to such ops are rounded from 23 bits of precision to 10 + bits of precision. This is unlikely to cause issues in practice for deep learning + models. 
In some cases, TensorFloat-32 is also used for complex64 ops. + TensorFloat-32 can be disabled by running `tf.config.experimental.enable_tensor_float_32_execution(False)`. + * The byte layout for string tensors across the C-API has been updated to match + TF Core/C++; i.e., a contiguous array of `tensorflow::tstring`/`TF_TString`s. + * C-API functions `TF_StringDecode`, `TF_StringEncode`, and `TF_StringEncodedSize` + are no longer relevant and have been removed; see `core/platform/ctstring.h` for + string access/modification in C. + * `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are + now hidden. These modules are not part of TensorFlow public API. + * `tf.raw_ops.Max` and `tf.raw_ops.Min` no longer accept inputs of type + `tf.complex64` or `tf.complex128`, because the behavior of these ops is not + well defined for complex types. + * XLA:CPU and XLA:GPU devices are no longer registered by default. Use + `TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them, but this + flag will eventually be removed in subsequent releases. -* Other: +* `tf.keras`: + * The `steps_per_execution` argument in `model.compile()` is no longer experimental; + if you were passing `experimental_steps_per_execution`, rename it to + `steps_per_execution` in your code. This argument controls the number of batches + to run during each `tf.function` call when calling `model.fit()`. Running multiple + batches inside a single `tf.function` call can greatly improve performance on + TPUs or small models with a large Python overhead. + * A **major refactoring** of the internals of the Keras Functional API may affect code that + is relying on certain internal details: + * Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when + checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`. + * Code that is overly dependent on the exact names attached to symbolic tensors + (e.g. 
assumes there will be ":0" at the end of the inputs, treats names as + unique identifiers instead of using `tensor.ref()`, etc.) may break. + * Code that uses full path for `get_concrete_function` to trace Keras symbolic + inputs directly should switch to building matching `tf.TensorSpec`s directly and + tracing the `TensorSpec` objects. + * Code that relies on the exact number and names of the op layers that TensorFlow + operations were converted into may have changed. + * Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers + and happens to work before TF 2.4. These will explicitly be unsupported now. + Converting these ops to Functional API op layers was unreliable before TF 2.4, + and prone to erroring incomprehensibly or being silently buggy. + * Code that directly asserts on a Keras symbolic value in cases where ops + like `tf.rank` used to return a static or symbolic value depending on if the + input had a fully static shape or not. Now these ops always return symbolic values. + * Code already susceptible to leaking tensors outside of graphs becomes slightly + more likely to do so now. + * Code that tries directly getting gradients with respect to symbolic Keras + inputs/outputs. Use `GradientTape` on the actual Tensors passed to the already-constructed + model instead. + * Code that requires very tricky shape manipulation via converted op layers + in order to work, where the Keras symbolic shape inference proves insufficient. + * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes + layers only ever have one positional argument. This assumption doesn't hold + true before TF 2.4 either, but is more likely to cause issues now. + * Code that manually enters `keras.backend.get_graph()` before building a + functional model is no longer needed. + * Start enforcing input shape assumptions when calling Functional API Keras + models. 
This may potentially break some users, in case there is a mismatch + between the shape used when creating `Input` objects in a Functional model, + and the shape of the data passed to that model. You can fix this mismatch by + either calling the model with correctly-shaped data, or by relaxing `Input` shape + assumptions (note that you can pass shapes with `None` entries for axes that + are meant to be dynamic). You can also disable the input checking entirely by + setting `model.input_spec = None`. + * Several changes have been made to `tf.keras.mixed_precision.experimental`. + Note that it is now recommended to use the non-experimental + `tf.keras.mixed_precision` API. + * `AutoCastVariable.dtype` now refers to the actual variable dtype, not the + dtype it will be casted to. + * When mixed precision is enabled, `tf.keras.layers.Embedding` now outputs a + float16 or bfloat16 tensor instead of a float32 tensor. + * The property `tf.keras.mixed_precision.experimental.LossScaleOptimizer.loss_scale` + is now a tensor, not a `LossScale` object. This means to get a loss scale + of a `LossScaleOptimizer` as a tensor, you must now call `opt.loss_scale` instead of `opt.loss_scale()`. + * The property `should_cast_variables` has been removed from `tf.keras.mixed_precision.experimental.Policy`. + * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, + the `DynamicLossScale`'s multiplier must be 2. + * When passing a `tf.mixed_precision.experimental.DynamicLossScale` to + `tf.keras.mixed_precision.experimental.LossScaleOptimizer`, the weights of + the `DynamicLossScale` are copied into the `LossScaleOptimizer` instead of being reused. + This means modifying the weights of the `DynamicLossScale` will no longer affect the weights of the `LossScaleOptimizer`, and vice versa. 
+ * The global policy can no longer be set to a non-floating point policy in `tf.keras.mixed_precision.experimental.set_policy` + * In `Layer.call`, `AutoCastVariable`s will no longer be casted within + `MirroredStrategy.run` or `ReplicaContext.merge_call`. This is because a thread local + variable is used to determine whether `AutoCastVariable`s are casted, and those + two functions run with a different thread. Note this only applies if one of + these two functions is called within `Layer.call`; if one of those two functions calls `Layer.call`, `AutoCastVariable`s will still be casted. + +* `tf.data`: + * `tf.data.experimental.service.DispatchServer` now takes a config tuple + instead of individual arguments. Usages should be updated to + `tf.data.experimental.service.DispatchServer(dispatcher_config)`. + * `tf.data.experimental.service.WorkerServer` now takes a config tuple instead + of individual arguments. Usages should be updated to `tf.data.experimental.service.WorkerServer(worker_config)`. + +* `tf.distribute`: + * Removes `tf.distribute.Strategy.experimental_make_numpy_dataset`. Please use + `tf.data.Dataset.from_tensor_slices` instead. + * Renames `experimental_hints` in `tf.distribute.StrategyExtended.reduce_to`, + `tf.distribute.StrategyExtended.batch_reduce_to`, `tf.distribute.ReplicaContext.all_reduce` + to `options`. + * Renames `tf.distribute.experimental.CollectiveHints` to `tf.distribute.experimental.CommunicationOptions`. + * Renames `tf.distribute.experimental.CollectiveCommunication` to `tf.distribute.experimental.CommunicationImplementation`. + * Renames `tf.distribute.Strategy.experimental_distribute_datasets_from_function` to `distribute_datasets_from_function` as it is no longer experimental. + * Removes `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2. 
+ +* `tf.lite`: + * `tf.quantization.quantize_and_dequantize_v2` has been introduced, which updates the gradient definition for quantization which is outside the range + to be 0. To simulate the V1 the behavior of `tf.quantization.quantize_and_dequantize(...)` use + `tf.grad_pass_through(tf.quantization.quantize_and_dequantize_v2)(...)`. + +* Building TensorFlow: + * Windows platform builds: TensorFlow on Windows under MSVC is now built with + `--copt=/experimental:preprocessor --host_copt=/experimental:preprocessor` + (see `.bazelrc` for more details). Builds including TensorFlow may fail with + unexpected syntax errors if these flags are absent. See also + [this thread on SIG Build](https://groups.google.com/a/tensorflow.org/g/build/c/LbAw8RILvTg/m/ttnuhYU2BgAJ). + +## Known Caveats + * `tf.keras.mixed_precision` + * When using mixed precision, calling `RMSprop.apply_gradients` or + `Nadam.apply_gradients` outside a `tf.function` does not work and will raise + the AttributeError "Tensor.op is meaningless when eager execution is enabled". + See this [issue](https://github.com/tensorflow/tensorflow/issues/45536) for details and a workaround. + +## Bug Fixes and Other Changes + +### TF Core: + * Introduces experimental support for a new module named [`tf.experimental.numpy`] + (https://www.tensorflow.org/api_docs/python/tf/experimental/numpy), which is a + NumPy-compatible API for writing TF programs. This module provides class + `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable + `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are + provided. Their inter-operation with TF facilities is seamless in most cases. + See [tensorflow/python/ops/numpy_ops/README.md](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/numpy_ops/README.md) + for details of what operations are supported and what are the differences + from NumPy. 
+ * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as + type annotation for variables representing a Tensor or a value + that can be converted to Tensor by `tf.convert_to_tensor`. + * Calling ops with python constants or numpy values is now consistent with + tf.convert_to_tensor behavior. This avoids operations like + tf.reshape truncating inputs such as from int64 to int32. + * Adds `tf.sparse.map_values` to apply a function to the `.value`s of + `SparseTensor` arguments. + * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__` and `__invert__`) now support non-`bool` + arguments and apply the corresponding bitwise ops. `bool` arguments continue + to be supported and dispatch to logical ops. This brings them more in line with + Python and NumPy behavior. + * Adds `tf.SparseTensor.with_values`. This returns a new SparseTensor with the same sparsity pattern, but with new provided values. It is + similar to the `with_values` function of `RaggedTensor`. + * Adds `StatelessCase` op, and uses it if none of the case branches has stateful ops. + * Adds `tf.config.experimental.get_memory_usage` to return total memory usage of the device. + * Adds gradients for `RaggedTensorToVariant` and `RaggedTensorFromVariant`. + * Improve shape inference of nested function calls by supporting constant + folding across Arg nodes which makes more static values available to shape + inference functions. +* `tf.debugging`: + * `tf.debugging.assert_shapes()` now works on `SparseTensor`s (Fixes [#36268](https://github.com/tensorflow/tensorflow/issues/36268)). +* GPU + * Adds Support for [TensorFloat-32](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) + on Ampere based GPUs. TensorFloat-32, or TF32 for short, is a math mode for + NVIDIA Ampere based GPUs which causes certain float32 ops, such as matrix + multiplications and convolutions, to run much faster on Ampere GPUs but with + reduced precision. 
This reduced precision has not been found to affect + convergence quality of deep learning models in practice. TensorFloat-32 is + enabled by default, but can be disabled with `tf.config.experimental.enable_tensor_float_32_execution`. +* `tf.math`: + * Adds `tf.math.erfcinv`, the inverse to `tf.math.erfc`. +* `tf.nn`: + * `tf.nn.max_pool2d` now supports explicit padding. +* `tf.image`: + * Adds deterministic `tf.image.stateless_random_*` functions for each + `tf.image.random_*` function. Added a new op `stateless_sample_distorted_bounding_box` + which is a deterministic version of `sample_distorted_bounding_box` op. + Given the same seed, these stateless functions/ops produce the same results + independent of how many times the function is called, and independent of global seed settings. + * Adds deterministic `tf.image.resize` backprop CUDA kernels for + `method=ResizeMethod.BILINEAR` (the default method). Enable by setting the environment + variable `TF_DETERMINISTIC_OPS` to `"true"` or `"1"`. +* `tf.print`: + * Bug fix in `tf.print()` with `OrderedDict` where if an `OrderedDict` + didn't have the keys sorted, the keys and values were not being printed + in accordance with their correct mapping. +* `tf.train.Checkpoint`: + * Now accepts a `root` argument in the initialization, which generates a + checkpoint with a root object. This allows users to create a `Checkpoint` + object that is compatible with Keras `model.save_weights()` and + `model.load_weights`. The checkpoint is also compatible with the checkpoint + saved in the `variables/` folder in the SavedModel. + * When restoring, `save_path` can be a path to a SavedModel. The function will + automatically find the checkpoint in the SavedModel. + +### `tf.data`: + * Adds new `tf.data.experimental.service.register_dataset` and + `tf.data.experimental.service.from_dataset_id` APIs to enable one process to + register a dataset with the tf.data service, and another process to consume + data from the dataset. 
+ * Adds support for dispatcher fault tolerance. To enable fault tolerance, + configure a `work_dir` when running your dispatcher server and set + `dispatcher_fault_tolerance=True`. The dispatcher will store its state to + `work_dir`, so that on restart it can continue from its previous state after restart. + * Adds support for sharing dataset graphs via shared filesystem instead of + over RPC. This reduces load on the dispatcher, improving performance + of distributing datasets. For this to work, the dispatcher's `work_dir` + must be accessible from workers. If the worker fails to read from the `work_dir`, + it falls back to using RPC for dataset graph transfer. + * Adds support for a new "distributed_epoch" processing mode. + This processing mode distributes a dataset across all tf.data workers, + instead of having each worker process the full dataset. See + [the tf.data service docs](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#understand_processing_mode) + to learn more. + * Adds optional `exclude_cols` parameter to CsvDataset. This parameter is the + complement of `select_cols`; at most one of these should be specified. + * We have implemented an optimization which reorders data-discarding + transformations such as `take` and `shard` to happen earlier in the dataset + when it is safe to do so. The optimization can be disabled via the + `experimental_optimization.reorder_data_discarding_ops` dataset option. + * `tf.data.Options` were previously immutable and can now be overridden. + * `tf.data.Dataset.from_generator` now supports Ragged and Sparse tensors with + a new `output_signature` argument, which allows `from_generator` to produce any + type describable by a `tf.TypeSpec`. + * `tf.data.experimental.AUTOTUNE` is now available in the core API as `tf.data.AUTOTUNE`. - * We have replaced uses of "whitelist" and "blacklist" with "allowlist" - and "denylist" where possible. 
Please see - https://developers.google.com/style/word-list#blacklist for more - context. - * Add `tf.config.experimental.mlir_bridge_rollout` which will help us - rollout the new MLIR TPU bridge. - * Added `tf.experimental.register_filesystem_plugin` to load modular - filesystem plugins from Python - * +### `tf.distribute`: + * Introduces experimental support for asynchronous training of models via + `tf.distribute.experimental.ParameterServerStrategy`: + * Replaces the existing `tf.distribute.experimental.ParameterServerStrategy` + symbol with a new class that is for parameter server training in TF2. Usage of + the old symbol, usually with Estimator API, should be **replaced** with + [`tf.compat.v1.distribute.experimental.ParameterServerStrategy`]. + * Added `tf.distribute.experimental.coordinator.*` namespace, including the + main API `ClusterCoordinator` for coordinating the training cluster, the + related data structures `RemoteValue` and `PerWorkerValue`. + * [`MultiWorkerMirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy) + is now a stable API and is no longer considered experimental. Some of the major + improvements involve handling peer failure and many bug fixes. Please check out + the detailed tutorial on [Multi-worker training with Keras](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras). + * Adds `tf.distribute.Strategy.gather` and `tf.distribute.ReplicaContext.all_gather` + APIs to support gathering dense distributed values. + * Fixes various issues with saving a distributed model. + +### `tf.keras`: + * Improvements from the Functional API refactoring: + * Functional model construction does not need to maintain a global workspace + graph, removing memory leaks especially when building many models or very large models. + * Functional model construction should be ~8-10% faster on average. 
+ * Functional models can now contain non-symbolic values in their call inputs + inside of the first positional argument. + * Several classes of TF ops that were not reliably converted to Keras layers + during functional API construction should now work, e.g.`tf.image.ssim_multiscale` + * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be + clearer and easier to understand. + * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` + as an alternative to accepting a `callable` loss. + * Adds `beta` hyperparameter to [FTRL](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl) + optimizer classes (Keras and others) to match [FTRL paper](https://research.google.com/pubs/archive/41159.pdf). + * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for customization + of how gradients are aggregated across devices, as well as `gradients_transformers` + to allow for custom gradient transformations (such as gradient clipping). + * Improvements to Keras preprocessing layers: + * TextVectorization can now accept a vocabulary list or file as an init arg. + * Normalization can now accept mean and variance values as init args. + * In `Attention` and `AdditiveAttention` layers, the `call()` method now accepts a `return_attention_scores` argument. When set to + True, the layer returns the attention scores as an additional output argument. + * Adds `tf.metrics.log_cosh` and `tf.metrics.logcosh` API entrypoints with the + same implementation as their `tf.losses` equivalent. + * For Keras model, the individual call of `Model.evaluate` uses no cached data + for evaluation, while `Model.fit` uses cached data when `validation_data` arg + is provided for better performance. + * Adds a `save_traces` argument to `model.save`/ `tf.keras.models.save_model` + which determines whether the SavedModel format stores the Keras model/layer call + functions. 
The traced functions allow Keras to revive custom models and layers + without the original class definition, but if this isn't required the tracing + can be disabled with the added option. + * The `tf.keras.mixed_precision` API is now non-experimental. + The non-experimental API differs from the experimental API in several ways. + * `tf.keras.mixed_precision.Policy` no longer takes in a `tf.mixed_precision. + experimental.LossScale` in the constructor, and no longer has a `LossScale` + associated with it. Instead, `Model.compile` will automatically wrap the optimizer + with a `LossScaleOptimizer` using dynamic loss scaling if `Policy.name` + is "mixed_float16". + * `tf.keras.mixed_precision.LossScaleOptimizer`'s constructor takes in different + arguments. In particular, it no longer takes in a `LossScale`, and there is + no longer a `LossScale` associated with the `LossScaleOptimizer`. Instead, + `LossScaleOptimizer` directly implements fixed or dynamic loss scaling. See the + documentation of [`tf.keras.mixed_precision.experimental.LossScaleOptimizer`] + (https://www.tensorflow.org/api_docs/python/tf/keras/mixed_precision/experimental/LossScaleOptimizer?version=nightly) + for details on the differences between the experimental `LossScaleOptimizer` + and the new non-experimental `LossScaleOptimizer`. + * `tf.mixed_precision.experimental.LossScale` and its subclasses are + deprecated, as all of its functionality now exists within `tf.keras.mixed_precision.LossScaleOptimizer` + +### `tf.lite`: + * `TFLiteConverter`: + * Support optional flags `inference_input_type` and `inference_output_type` + for full integer quantized models. This allows users to modify the model input + and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting + to float type (`tf.float32`). + * NNAPI + * Adds NNAPI Delegation support for requantization use cases by converting + the operation into a dequantize-quantize pair. 
+ * Removes deprecated `Interpreter.setUseNNAPI(boolean)` Java API. Use + `Interpreter.Options.setUseNNAPI` instead. + * Deprecates `Interpreter::UseNNAPI(bool)` C++ API. Use `NnApiDelegate()` + and related delegate configuration methods directly. + * Deprecates `Interpreter::SetAllowFp16PrecisionForFp32(bool)` C++ API. + Prefer controlling this via delegate options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` + or `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. + * GPU + * GPU acceleration now supports quantized models by default + * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty. + * Adds support for cumulative sum (cumsum), both as builtin op and MLIR conversion. + +### `TensorRT` + * Issues a warning when the `session_config` parameter for the TF1 converter + is used or the `rewrite_config_template` field in the TF2 converter parameter + object is used. + +### TPU Enhancements: + * Adds support for the `beta` parameter of the FTRL optimizer for TPU + embeddings. Users of other TensorFlow platforms can implement equivalent + behavior by adjusting the `l2` parameter. + +### XLA Support: + * xla.experimental.compile is deprecated, use `tf.function(experimental_compile=True)` instead. + * Adds `tf.function.experimental_get_compiler_ir` which returns compiler IR + (currently 'hlo' and 'optimized_hlo') for given input for given function. 
+ +### Security: + * Fixes an undefined behavior causing a segfault in `tf.raw_ops.Switch`, + ([CVE-2020-15190](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15190)) + * Fixes three vulnerabilities in conversion to DLPack format + * [CVE-2020-15191](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15191), + * [CVE-2020-15192](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15192), + * [CVE-2020-15193](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15193) + * Fixes two vulnerabilities in `SparseFillEmptyRowsGrad` + * [CVE-2020-15194](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15194), + * [CVE-2020-15195](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15195) + * Fixes several vulnerabilities in `RaggedCountSparseOutput` and `SparseCountSparseOutput` operations + * [CVE-2020-15196](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15196), + * [CVE-2020-15197](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15197), + * [CVE-2020-15198](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15198), + * [CVE-2020-15199](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15199), + * [CVE-2020-15200](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15200), + * [CVE-2020-15201](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15201) + * Fixes an integer truncation vulnerability in code using the work sharder API, + ([CVE-2020-15202](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15202)) + * Fixes a format string vulnerability in `tf.strings.as_string`, + ([CVE-2020-15203](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15203)) + * Fixes segfault raised by calling session-only ops in eager mode, + ([CVE-2020-15204](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15204)) + * Fixes data leak and potential ASLR violation from `tf.raw_ops.StringNGrams`, + ([CVE-2020-15205](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15205)) + * Fixes segfaults caused by 
incomplete `SavedModel` validation, + ([CVE-2020-15206](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15206)) + * Fixes a data corruption due to a bug in negative indexing support in TFLite, + ([CVE-2020-15207](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15207)) + * Fixes a data corruption due to dimension mismatch in TFLite, + ([CVE-2020-15208](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15208)) + * Fixes several vulnerabilities in TFLite saved model format + * [CVE-2020-15209](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15209), + * [CVE-2020-15210](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15210), + * [CVE-2020-15211](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15211) + * Fixes several vulnerabilities in TFLite implementation of segment sum + * [CVE-2020-15212](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15212), + * [CVE-2020-15213](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15213), + * [CVE-2020-15214](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15214) + * Fixes a segfault in `tf.quantization.quantize_and_dequantize`, + ([CVE-2020-15265](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15265)) + * Fixes an undefined behavior float cast causing a crash, + ([CVE-2020-15266](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-15266)) + * Fixes a lack of validation in `tf.raw_ops.DataFormatVecPermute` and + `tf.raw_ops.DataFormatDimMap` which can cause uninitialized memory access, + read outside bounds of arrays, data corruption and segmentation faults + ([CVE-2020-26267](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26267)) + * Fixes a crash caused by writing to read only memory region + ([CVE-2020-26268](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26268)) + * Fixes a heap out of bounds access in filesystem globbing implementation + ([CVE-2020-26269](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26269)) + +### 
Other: + * We have replaced uses of "whitelist" and "blacklist" with "allowlist" and + "denylist" where possible. Please see [this list](https://developers.google.com/style/word-list#blacklist) for more context. + * Adds `tf.config.experimental.mlir_bridge_rollout` which will help us rollout the new MLIR TPU bridge. + * Adds `tf.experimental.register_filesystem_plugin` to load modular filesystem plugins from Python ## Thanks to our Contributors -This release contains contributions from many people at Google, as well as: +This release contains contributions from many people at Google as well as the following external contributors: -stjohnso98, , , , , +8bitmp3, aaa.jq, Abhineet Choudhary, Abolfazl Shahbazi, acxz, Adam Hillier, Adrian Garcia Badaracco, Ag Ramesh, ahmedsabie, Alan Anderson, Alexander Grund, Alexandre Lissy, Alexey Ivanov, Amedeo Cavallo, anencore94, Aniket Kumar Singh, Anthony Platanios, Ashwin Phadke, Balint Cristian, Basit Ayantunde, bbbboom, Ben Barsdell, Benjamin Chetioui, Benjamin Peterson, bhack, Bhanu Prakash Bandaru Venkata, Biagio Montaruli, Brent M. 
Spell, bubblebooy, bzhao, cfRod, Cheng Chen, Cheng(Kit) Chen, Chris Tessum, Christian, chuanqiw, codeadmin_peritiae, COTASPAR, CuiYifeng, danielknobe, danielyou0230, dannyfriar, daria, DarrenZhang01, Denisa Roberts, dependabot[bot], Deven Desai, Dmitry Volodin, Dmitry Zakharov, drebain, Duncan Riach, Eduard Feicho, Ehsan Toosi, Elena Zhelezina, emlaprise2358, Eugene Kuznetsov, Evaderan-Lab, Evgeniy Polyakov, Fausto Morales, Felix Johnny, fo40225, Frederic Bastien, Fredrik Knutsson, fsx950223, Gaurav Singh, Gauri1 Deshpande, George Grzegorz Pawelczak, gerbauz, Gianluca Baratti, Giorgio Arena, Gmc2, Guozhong Zhuang, Hannes Achleitner, Harirai, HarisWang, Harsh188, hedgehog91, Hemal Mamtora, Hideto Ueno, Hugh Ku, Ian Beauregard, Ilya Persky, jacco, Jakub Beránek, Jan Jongboom, Javier Montalt Tordera, Jens Elofsson, Jerry Shih, jerryyin, jgehw, Jinjing Zhou, jma, jmsmdy, Johan Nordström, John Poole, Jonah Kohn, Jonathan Dekhtiar, jpodivin, Jung Daun, Kai Katsumata, Kaixi Hou, Kamil Rakoczy, Kaustubh Maske Patil, Kazuaki Ishizaki, Kedar Sovani, Koan-Sin Tan, Koki Ibukuro, Krzysztof Laskowski, Kushagra Sharma, Kushan Ahmadian, Lakshay Tokas, Leicong Li, levinxo, Lukas Geiger, Maderator, Mahmoud Abuzaina, Mao Yunfei, Marius Brehler, markf, Martin Hwasser, Martin Kubovčík, Matt Conley, Matthias, mazharul, mdfaijul, Michael137, MichelBr, Mikhail Startsev, Milan Straka, Ml-0, Myung-Hyun Kim, Måns Nilsson, Nathan Luehr, ngc92, nikochiko, Niranjan Hasabnis, nyagato_00, Oceania2018, Oleg Guba, Ongun Kanat, OscarVanL, Patrik Laurell, Paul Tanger, Peter Sobot, Phil Pearl, PlusPlusUltra, Poedator, Prasad Nikam, Rahul-Kamat, Rajeshwar Reddy T, redwrasse, Rickard, Robert Szczepanski, Rohan Lekhwani, Sam Holt, Sami Kama, Samuel Holt, Sandeep Giri, sboshin, Sean Settle, settle, Sharada Shiddibhavi, Shawn Presser, ShengYang1, Shi,Guangyong, Shuxiang Gao, Sicong Li, Sidong-Wei, Srihari Humbarwadi, Srinivasan Narayanamoorthy, Steenu Johnson, Steven Clarkson, stjohnso98, Tamas Bela Feher, 
Tamas Nyiri, Tarandeep Singh, Teng Lu, Thibaut Goetghebuer-Planchon, Tim Bradley, Tomasz Strejczek, Tongzhou Wang, Torsten Rudolf, Trent Lo, Ty Mick, Tzu-Wei Sung, Varghese, Jojimon, Vignesh Kothapalli, Vishakha Agrawal, Vividha, Vladimir Menshakov, Vladimir Silyaev, VoVAllen, Võ Văn Nghĩa, wondertx, xiaohong1031, Xiaoming (Jason) Cui, Xinan Jiang, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yimei Sun, Yiwen Li, Yixing, Yoav Ramon, Yong Tang, Yong Wu, yuanbopeng, Yunmo Koo, Zhangqiang, Zhou Peng, ZhuBaohe, zilinzhu, zmx # Release 2.3.1 diff --git a/WORKSPACE b/WORKSPACE index fa39cedae9bacc..1286ef9ac034e7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -1,140 +1,23 @@ workspace(name = "org_tensorflow") -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +# Initialize the TensorFlow repository and all dependencies. +# +# The cascade of load() statements and tf_workspace?() calls works around the +# restriction that load() statements need to be at the top of .bzl files. +# E.g. we can not retrieve a new repository with http_archive and then load() +# a macro from that repository in the same file. +load("@//tensorflow:workspace3.bzl", "tf_workspace3") -http_archive( - name = "io_bazel_rules_closure", - sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", - strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13 - ], -) +tf_workspace3() -# Load tf_repositories() before loading dependencies for other repository so -# that dependencies like com_google_protobuf won't be overridden. -load("//tensorflow:workspace.bzl", "tf_repositories") -# Please add all new TensorFlow dependencies in workspace.bzl. 
-tf_repositories() +load("@//tensorflow:workspace2.bzl", "tf_workspace2") -register_toolchains("@local_config_python//:py_toolchain") +tf_workspace2() -load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories") +load("@//tensorflow:workspace1.bzl", "tf_workspace1") -closure_repositories() +tf_workspace1() -load("//third_party/toolchains/preconfig/generate:archives.bzl", - "bazel_toolchains_archive") - -bazel_toolchains_archive() - -load( - "@bazel_toolchains//repositories:repositories.bzl", - bazel_toolchains_repositories = "repositories", -) - -bazel_toolchains_repositories() - -# Use `swift_rules_dependencies` to fetch the toolchains. With the -# `git_repository` rules above, the following call will skip redefining them. -load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") -swift_rules_dependencies() - -# We must check the bazel version before trying to parse any other BUILD -# files, in case the parsing of those build files depends on the bazel -# version we require here. -load("//tensorflow:version_check.bzl", "check_bazel_version_at_least") -check_bazel_version_at_least("1.0.0") - -load("//third_party/android:android_configure.bzl", "android_configure") -android_configure(name="local_config_android") -load("@local_config_android//:android.bzl", "android_workspace") -android_workspace() - -# If a target is bound twice, the later one wins, so we have to do tf bindings -# at the end of the WORKSPACE file. 
-load("//tensorflow:workspace.bzl", "tf_bind") -tf_bind() - -http_archive( - name = "inception_v1", - build_file = "//:models.BUILD", - sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", - ], -) - -http_archive( - name = "mobile_ssd", - build_file = "//:models.BUILD", - sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", - ], -) - -http_archive( - name = "mobile_multibox", - build_file = "//:models.BUILD", - sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", - ], -) - -http_archive( - name = "stylize", - build_file = "//:models.BUILD", - sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", - ], -) - -http_archive( - name = "speech_commands", - build_file = "//:models.BUILD", - sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", - ], -) - -http_archive( - name = "person_detect_data", - sha256 = "170542270da256994ce24d1e357f6e84a54fdaf7d28ff2b74725a40b70b082cf", - urls = [ - "https://storage.googleapis.com/download.tensorflow.org/data/tf_lite_micro_person_data_grayscale_2020_05_24.zip", - ], -) - -# Required for dependency @com_github_grpc_grpc - -load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") - -grpc_deps() - -load( - "@build_bazel_rules_apple//apple:repositories.bzl", - "apple_rules_dependencies", -) - -apple_rules_dependencies() - -load( - "@build_bazel_apple_support//lib:repositories.bzl", - 
"apple_support_dependencies", -) - -apple_support_dependencies() - -load("@upb//bazel:repository_defs.bzl", "bazel_version_repository") - -bazel_version_repository(name = "bazel_version") - -load("//third_party/googleapis:repository_rules.bzl", "config_googleapis") - -config_googleapis() +load("@//tensorflow:workspace0.bzl", "tf_workspace0") +tf_workspace0() diff --git a/configure.py b/configure.py index e381c8c20dbf70..5207db40cfd709 100644 --- a/configure.py +++ b/configure.py @@ -20,6 +20,7 @@ import argparse import errno +import glob import os import platform import re @@ -46,7 +47,7 @@ _TF_WORKSPACE_ROOT = '' _TF_BAZELRC = '' _TF_CURRENT_BAZEL_VERSION = None -_TF_MIN_BAZEL_VERSION = '3.1.0' +_TF_MIN_BAZEL_VERSION = '3.7.2' _TF_MAX_BAZEL_VERSION = '3.99.0' NCCL_LIB_PATHS = [ @@ -55,16 +56,15 @@ # List of files to configure when building Bazel on Apple platforms. APPLE_BAZEL_FILES = [ - 'tensorflow/lite/experimental/ios/BUILD', - 'tensorflow/lite/experimental/objc/BUILD', - 'tensorflow/lite/experimental/swift/BUILD', + 'tensorflow/lite/ios/BUILD', 'tensorflow/lite/objc/BUILD', + 'tensorflow/lite/swift/BUILD', 'tensorflow/lite/tools/benchmark/experimental/ios/BUILD' ] # List of files to move when building for iOS. 
IOS_FILES = [ - 'tensorflow/lite/experimental/objc/TensorFlowLiteObjC.podspec', - 'tensorflow/lite/experimental/swift/TensorFlowLiteSwift.podspec', + 'tensorflow/lite/objc/TensorFlowLiteObjC.podspec', + 'tensorflow/lite/swift/TensorFlowLiteSwift.podspec', ] @@ -184,6 +184,8 @@ def get_python_path(environ_cp, python_bin_path): ] all_paths = set(python_paths + library_paths) + # Sort set so order is deterministic + all_paths = sorted(all_paths) paths = [] for path in all_paths: @@ -526,7 +528,12 @@ def set_cc_opt_flags(environ_cp): elif is_windows(): default_cc_opt_flags = '/arch:AVX' else: - default_cc_opt_flags = '-march=native -Wno-sign-compare' + # On all other platforms, no longer use `-march=native` as this can result + # in instructions that are too modern being generated. Users that want + # maximum performance should compile TF in their environment and can pass + # `-march=native` there. + # See https://github.com/tensorflow/tensorflow/issues/45744 and duplicates + default_cc_opt_flags = '-Wno-sign-compare' question = ('Please specify optimization flags to use during compilation when' ' bazel option "--config=opt" is specified [Default is %s]: ' ) % default_cc_opt_flags @@ -534,10 +541,7 @@ def set_cc_opt_flags(environ_cp): question, default_cc_opt_flags) for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) - # It should be safe on the same build host. 
- if not is_ppc64le() and not is_windows(): - write_to_bazelrc('build:opt --host_copt=-march=native') - write_to_bazelrc('build:opt --define with_default_optimizations=true') + write_to_bazelrc('build:opt --host_copt=%s' % opt) def set_tf_cuda_clang(environ_cp): @@ -1163,49 +1167,20 @@ def set_system_libs_flag(environ_cp): syslibs = ','.join(sorted(syslibs.split())) write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) - if 'PREFIX' in environ_cp: - write_to_bazelrc('build --define=PREFIX=%s' % environ_cp['PREFIX']) - if 'LIBDIR' in environ_cp: - write_to_bazelrc('build --define=LIBDIR=%s' % environ_cp['LIBDIR']) - if 'INCLUDEDIR' in environ_cp: - write_to_bazelrc('build --define=INCLUDEDIR=%s' % environ_cp['INCLUDEDIR']) - - -def is_reduced_optimize_huge_functions_available(environ_cp): - """Check to see if the system supports /d2ReducedOptimizeHugeFunctions. - - The above compiler flag is a new compiler flag introduced to the Visual Studio - compiler in version 16.4 (available in Visual Studio 2019, Preview edition - only, as of 2019-11-19). TensorFlow needs this flag to massively reduce - compile times, but until 16.4 is officially released, we can't depend on it. - - See also - https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion - - Because it's very annoying to check this manually (to check the MSVC installed - versions, you need to use the registry, and it's not clear if Bazel will be - using that install version anyway), we expect enviroments who know they may - use this flag to export TF_VC_VERSION=16.4 - - TODO(angerson, gunan): Remove this function when TensorFlow's minimum VS - version is upgraded to 16.4. - - Arguments: - environ_cp: Environment of the current execution - - Returns: - boolean, whether or not /d2ReducedOptimizeHugeFunctions is available on this - machine. 
- """ - return float(environ_cp.get('TF_VC_VERSION', '0')) >= 16.4 + for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'): + if varname in environ_cp: + write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname])) def set_windows_build_flags(environ_cp): """Set Windows specific build options.""" - if is_reduced_optimize_huge_functions_available(environ_cp): - write_to_bazelrc( - 'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions' - ) + + # First available in VS 16.4. Speeds up Windows compile times by a lot. See + # https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion + # pylint: disable=line-too-long + write_to_bazelrc( + 'build --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions' + ) if get_var( environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline', @@ -1226,13 +1201,12 @@ def config_info_line(name, help_text): print('\t--config=%-12s\t# %s' % (name, help_text)) -def configure_ios(): - """Configures TensorFlow for iOS builds. - - This function will only be executed if `is_macos()` is true. 
- """ +def configure_ios(environ_cp): + """Configures TensorFlow for iOS builds.""" if not is_macos(): return + if not get_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False): + return for filepath in APPLE_BAZEL_FILES: existing_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath + '.apple') renamed_filepath = os.path.join(_TF_WORKSPACE_ROOT, filepath) @@ -1265,9 +1239,12 @@ def maybe_encode_env(env): if environ_cp.get('TF_NCCL_VERSION', None): cuda_libraries.append('nccl') + paths = glob.glob('**/third_party/gpus/find_cuda_config.py', recursive=True) + if not paths: + raise FileNotFoundError( + "Can't find 'find_cuda_config.py' script inside working directory") proc = subprocess.Popen( - [environ_cp['PYTHON_BIN_PATH'], 'third_party/gpus/find_cuda_config.py'] + - cuda_libraries, + [environ_cp['PYTHON_BIN_PATH'], paths[0]] + cuda_libraries, stdout=subprocess.PIPE, env=maybe_encode_env(environ_cp)) @@ -1348,11 +1325,11 @@ def main(): if is_macos(): environ_cp['TF_NEED_TENSORRT'] = '0' - else: - environ_cp['TF_CONFIGURE_IOS'] = '0' - if environ_cp.get('TF_ENABLE_XLA', '1') == '1': - write_to_bazelrc('build --config=xla') + with_xla_support = environ_cp.get('TF_ENABLE_XLA', None) + if with_xla_support is not None: + write_to_bazelrc('build --define=with_xla_support=%s' % ( + 'true' if int(with_xla_support) else 'false')) set_action_env_var( environ_cp, 'TF_NEED_ROCM', 'ROCm', False, bazel_config_name='rocm') @@ -1364,12 +1341,6 @@ def main(): if (environ_cp.get('TF_NEED_ROCM') == '1' and environ_cp.get('ROCM_PATH')): write_action_env_to_bazelrc('ROCM_PATH', environ_cp.get('ROCM_PATH')) - write_action_env_to_bazelrc('ROCM_ROOT', environ_cp.get('ROCM_PATH')) - - if ((environ_cp.get('TF_NEED_ROCM') == '1') and - (environ_cp.get('TF_ENABLE_MLIR_GENERATED_GPU_KERNELS') == '1')): - write_to_bazelrc( - 'build:rocm --define tensorflow_enable_mlir_generated_gpu_kernels=1') environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) @@ -1477,17 
+1448,16 @@ def main(): system_specific_test_config(environ_cp) - set_action_env_var(environ_cp, 'TF_CONFIGURE_IOS', 'iOS', False) - if environ_cp.get('TF_CONFIGURE_IOS') == '1': - configure_ios() + configure_ios(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See .bazelrc for more ' 'details.') config_info_line('mkl', 'Build with MKL support.') - config_info_line('mkl_aarch64', 'Build with oneDNN support for Aarch64.') + config_info_line( + 'mkl_aarch64', + 'Build with oneDNN and Compute Library for the Arm Architecture (ACL).') config_info_line('monolithic', 'Config for mostly static monolithic build.') - config_info_line('ngraph', 'Build with Intel nGraph support.') config_info_line('numa', 'Build with NUMA support.') config_info_line( 'dynamic_kernels', diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 274a829f57526c..3ef74d742efb13 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -3,8 +3,16 @@ # learning applications. load("@bazel_skylib//lib:selects.bzl", "selects") -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") -load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") +load( + "//tensorflow:tensorflow.bzl", + "VERSION", + "if_google", + "if_oss", + "tf_cc_shared_object", + "tf_custom_op_library_additional_deps_impl", + "tf_native_cc_binary", +) load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_binary_deps", @@ -23,10 +31,6 @@ load( "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1", # @unused ) -load( - "//third_party/ngraph:build_defs.bzl", - "if_ngraph", -) load( "//third_party/mkl:build_defs.bzl", "if_mkl_ml", @@ -72,47 +76,110 @@ TENSORFLOW_API_INIT_FILES_V1 = ( # which requires restricted licenses to be avoided. 
config_setting( name = "no_lgpl_deps", - values = {"define": "__TENSORFLOW_NO_LGPL_DEPS__=1"}, + define_values = {"__TENSORFLOW_NO_LGPL_DEPS__": "1"}, + visibility = ["//visibility:public"], +) + +# Config setting that disables the default logger, only logging +# to registered TFLogSinks +config_setting( + name = "no_default_logger", + define_values = {"no_default_logger": "true"}, visibility = ["//visibility:public"], ) # Config setting for determining if we are building for Android. config_setting( name = "android", - values = {"crosstool_top": "//external:android/crosstool"}, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "android"}, + {}, + ), + values = if_oss( + {"crosstool_top": "//external:android/crosstool"}, + {}, + ), visibility = ["//visibility:public"], ) config_setting( name = "android_x86", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "x86", - }, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "android"}, + {}, + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "x86", + ), visibility = ["//visibility:public"], ) config_setting( name = "android_x86_64", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "x86_64", - }, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "android"}, + {}, + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "x86_64", + ), visibility = ["//visibility:public"], ) config_setting( name = "android_armeabi", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "armeabi", - }, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "android"}, + {}, + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "armeabi", + ), visibility = ["//visibility:public"], ) +# copybara:uncomment_begin(google-only) +# config_setting( +# name = "chromiumos_x86_64", +# flag_values = {"//tools/cpp:cc_target_os": 
"chromiumos"}, +# values = {"cpu": "k8"}, +# visibility = ["//visibility:public"], +# ) +# +# config_setting( +# name = "chromiumos_arm64", +# flag_values = {"//tools/cpp:cc_target_os": "chromiumos"}, +# values = {"cpu": "arm"}, +# visibility = ["//visibility:public"], +# ) +# +# config_setting( +# name = "chromiumos_armv7", +# flag_values = {"//tools/cpp:cc_target_os": "chromiumos"}, +# values = {"cpu": "armeabi-v7a"}, +# visibility = ["//visibility:public"], +# ) +# copybara:uncomment_end + config_setting( name = "emscripten", - values = {"crosstool_top": "//external:android/emscripten"}, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "emscripten"}, + {}, + ), + values = if_oss( + {"crosstool_top": "//external:android/emscripten"}, + {}, + ), visibility = ["//visibility:public"], ) @@ -127,19 +194,31 @@ config_setting( config_setting( name = "android_arm", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "armeabi-v7a", - }, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "android"}, + {}, + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "armeabi-v7a", + ), visibility = ["//visibility:public"], ) config_setting( name = "android_arm64", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "arm64-v8a", - }, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "android"}, + {}, + ), + values = dict( + if_oss( + {"crosstool_top": "//external:android/crosstool"}, + ), + cpu = "arm64-v8a", + ), visibility = ["//visibility:public"], ) @@ -163,7 +242,25 @@ config_setting( config_setting( name = "windows", - values = {"cpu": "x64_windows"}, + # Internal builds query the target OS. + flag_values = if_google( + {"//tools/cpp:cc_target_os": "windows"}, + {}, + ), + # OSS builds query the CPU type. 
+ values = if_oss( + {"cpu": "x64_windows"}, + {}, + ), + visibility = ["//visibility:public"], +) + +config_setting( + name = "msvc_cl_debug", + values = { + "compiler": "msvc-cl", + "compilation_mode": "dbg", + }, visibility = ["//visibility:public"], ) @@ -174,38 +271,92 @@ config_setting( ) config_setting( - name = "macos", + name = "macos_x86_64", + flag_values = if_google( + {"//tools/cpp:cc_target_os": "apple"}, + {}, + ), values = { "apple_platform_type": "macos", - "cpu": "darwin", + "cpu": if_google("darwin_x86_64", "darwin"), }, visibility = ["//visibility:public"], ) +config_setting( + name = "macos_arm64", + flag_values = if_google( + {"//tools/cpp:cc_target_os": "apple"}, + {}, + ), + values = { + "apple_platform_type": "macos", + "cpu": "darwin_arm64", + }, + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "macos", + match_any = [ + ":macos_x86_64", + ":macos_arm64", + ], + visibility = ["//visibility:public"], +) + config_setting( name = "ios", - values = {"apple_platform_type": "ios"}, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "apple"}, + {}, + ), + values = if_oss( + {"apple_platform_type": "ios"}, + {}, + ), visibility = ["//visibility:public"], ) config_setting( name = "fuchsia", - values = {"cpu": "fuchsia"}, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "fuchsia"}, + {}, + ), + values = if_oss( + # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. 
+ {"cpu": "fuchsia"}, + {}, + ), visibility = ["//visibility:public"], ) config_setting( name = "ios_x86_64", - values = { - "crosstool_top": "//tools/osx/crosstool:crosstool", - "cpu": "ios_x86_64", - }, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "apple"}, + {}, + ), + values = dict( + if_oss( + {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + ), + cpu = "ios_x86_64", + ), visibility = ["//visibility:public"], ) config_setting( name = "chromiumos", - values = {"crosstool_top": "//external:android/chromiumos"}, + flag_values = if_google( + {"//tools/cpp:cc_target_os": "chromiumos"}, + {}, + ), + values = if_oss( + {"crosstool_top": "//external:android/chromiumos"}, + {}, + ), visibility = ["//visibility:public"], ) @@ -245,6 +396,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "linux_riscv64", + values = {"cpu": "riscv64"}, + visibility = ["//visibility:public"], +) + config_setting( name = "debug", values = { @@ -303,12 +460,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "with_default_optimizations", - define_values = {"with_default_optimizations": "true"}, - visibility = ["//visibility:public"], -) - # Features that are default ON are handled differently below. # config_setting( @@ -342,15 +493,6 @@ config_setting( visibility = ["//visibility:public"], ) -# Crosses between platforms and file system libraries not supported on those -# platforms due to limitations in nested select() statements. -config_setting( - name = "with_cuda_support_windows_override", - define_values = {"using_cuda_nvcc": "true"}, - values = {"cpu": "x64_windows"}, - visibility = ["//visibility:public"], -) - config_setting( name = "with_xla_support", define_values = {"with_xla_support": "true"}, @@ -376,14 +518,12 @@ config_setting( # due to limitations in nested select() statements. 
config_setting( name = "framework_shared_object", - define_values = { - "framework_shared_object": "true", - }, + define_values = {"framework_shared_object": "true"}, visibility = ["//visibility:public"], ) config_setting( - name = "macos_with_framework_shared_object", + name = "macos_x86_64_with_framework_shared_object", define_values = { "framework_shared_object": "true", }, @@ -395,88 +535,109 @@ config_setting( ) config_setting( - name = "using_cuda_clang", + name = "macos_arm64_with_framework_shared_object", define_values = { - "using_cuda_clang": "true", + "framework_shared_object": "true", }, -) - -# Flag to indicate open source build, .bazelrc always has it set to be true -config_setting( - name = "oss", - define_values = { - "open_source_build": "true", + values = { + "apple_platform_type": "macos", + "cpu": "darwin_arm64", }, visibility = ["//visibility:public"], ) -config_setting( - name = "using_cuda_clang_with_dynamic_build", - define_values = { - "using_cuda_clang": "true", - "framework_shared_object": "true", - }, +selects.config_setting_group( + name = "macos_with_framework_shared_object", + match_any = [ + ":macos_x86_64_with_framework_shared_object", + ":macos_arm64_with_framework_shared_object", + ], ) -config_setting( - name = "build_oss_using_cuda_clang", - define_values = { - "using_cuda_clang": "true", - "open_source_build": "true", - }, +# Config setting that is satisfied when TensorFlow is being built with CUDA +# support through e.g. `--config=cuda` (or `--config=cuda_clang` in OSS). 
+alias( + name = "is_cuda_enabled", + actual = if_oss( + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//cuda:using_clang", + ), ) -# Setting to use when loading kernels dynamically -config_setting( - name = "dynamic_loaded_kernels", - define_values = { - "dynamic_loaded_kernels": "true", - "framework_shared_object": "true", - }, - visibility = ["//visibility:public"], +# Config setting that is satisfied when CUDA device code should be compiled +# with clang. It does not imply that CUDA support has been enabled. +alias( + name = "is_cuda_compiler_clang", + actual = if_oss( + "@local_config_cuda//:is_cuda_compiler_clang", + "@local_config_cuda//cuda:TRUE", + ), ) -config_setting( - name = "using_cuda_nvcc", - define_values = { - "using_cuda_nvcc": "true", - }, +# Config setting that is satisfied when CUDA device code should be compiled +# with nvcc. It does not imply that CUDA support has been enabled. +alias( + name = "is_cuda_compiler_nvcc", + actual = if_oss( + "@local_config_cuda//:is_cuda_compiler_nvcc", + "@local_config_cuda//cuda:FALSE", + ), ) -config_setting( - name = "using_cuda_nvcc_with_dynamic_build", - define_values = { - "using_cuda_nvcc": "true", - "framework_shared_object": "true", - }, +# Config setting that is satisfied when building with --config=cuda in OSS. +selects.config_setting_group( + name = "is_cuda_enabled_and_oss", + match_all = [ + ":is_cuda_enabled", + ":oss", + ], ) +# Config setting that is satisfied when building with --config=cuda for Windows +selects.config_setting_group( + name = "is_cuda_enabled_and_windows", + match_all = [ + ":is_cuda_enabled", + ":windows", + ], +) + +# Config setting to use in select()s to distinguish open source build from +# google internal build on configurable attributes. +# +# For non-configurable distinction between OSS and Google builds, see +# `if_oss()` and `if_google()` macros in tensorflow.bzl. 
config_setting( - name = "build_oss_using_cuda_nvcc", - define_values = { - "using_cuda_nvcc": "true", - "open_source_build": "true", - }, + name = "oss", + flag_values = {":oss_setting": "True"}, + visibility = ["//visibility:public"], ) +# Non-configurable setting to indicate open source build. +bool_setting( + name = "oss_setting", + build_setting_default = if_oss(True, False), + visibility = ["//visibility:private"], +) + +# Setting to use when loading kernels dynamically config_setting( - name = "using_rocm_hipcc", + name = "dynamic_loaded_kernels", define_values = { - "using_rocm_hipcc": "true", + "dynamic_loaded_kernels": "true", + "framework_shared_object": "true", }, + visibility = ["//visibility:public"], ) config_setting( - name = "override_eigen_strong_inline", - values = {"define": "override_eigen_strong_inline=true"}, - visibility = ["//visibility:public"], + name = "using_rocm_hipcc", + define_values = {"using_rocm_hipcc": "true"}, ) -# This flag is set from the configure step when the user selects with nGraph option. -# By default it should be false config_setting( - name = "with_ngraph_support", - values = {"define": "with_ngraph_support=true"}, + name = "override_eigen_strong_inline", + define_values = {"override_eigen_strong_inline": "true"}, visibility = ["//visibility:public"], ) @@ -488,40 +649,31 @@ config_setting( visibility = ["//visibility:public"], ) -# This flag is defined for select statements that match both -# on 'windows' and 'api_version_2'. In this case, bazel requires -# having a flag which is a superset of these two. -config_setting( - name = "windows_and_api_version_2", - define_values = {"tf_api_version": "2"}, - values = {"cpu": "x64_windows"}, -) - # This flag enables experimental MLIR support. 
config_setting( name = "with_mlir_support", - values = {"define": "with_mlir_support=true"}, + define_values = {"with_mlir_support": "true"}, visibility = ["//visibility:public"], ) # This flag forcibly enables experimental MLIR bridge support. config_setting( name = "enable_mlir_bridge", - values = {"define": "enable_mlir_bridge=true"}, + define_values = {"enable_mlir_bridge": "true"}, visibility = ["//visibility:public"], ) # This flag forcibly disables experimental MLIR bridge support. config_setting( name = "disable_mlir_bridge", - values = {"define": "enable_mlir_bridge=false"}, + define_values = {"enable_mlir_bridge": "false"}, visibility = ["//visibility:public"], ) # This flag enables experimental TPU support config_setting( name = "with_tpu_support", - values = {"define": "with_tpu_support=true"}, + define_values = {"with_tpu_support": "true"}, visibility = ["//visibility:public"], ) @@ -537,15 +689,34 @@ selects.config_setting_group( ], ) +# This flag disables all google production dependencies, intended for +# applications run with non-prod environment. +# TODO(timshen): Currently this option only disables some dependencies. +# See b/122528503. 
+# copybara:uncomment_begin(google-only) +# config_setting( +# name = "no_prod_deps", +# define_values = {"tf_no_prod_deps": "1"}, +# ) +# +# config_setting( +# name = "no_prod_deps_cuda", +# define_values = { +# "tf_no_prod_deps": "1", +# "GOOGLE_CUDA_COMPILER": "clang", +# }, +# ) +# copybara:uncomment_end + config_setting( name = "lite_protos_legacy", - values = {"define": "TENSORFLOW_PROTOS=lite"}, + define_values = {"TENSORFLOW_PROTOS": "lite"}, visibility = ["//visibility:private"], ) config_setting( name = "full_protos", - values = {"define": "TENSORFLOW_PROTOS=full"}, + define_values = {"TENSORFLOW_PROTOS": "full"}, visibility = ["//visibility:public"], ) @@ -570,6 +741,14 @@ selects.config_setting_group( ], ) +# copybara:uncomment_begin(google-only) +# config_setting( +# name = "portable_proto_force_third_party", +# define_values = {"PORTABLE_PROTO_TRANSITION_MODE": "third_party"}, +# visibility = ["//visibility:public"], +# ) +# copybara:uncomment_end + # 'enable_registration_v2' opts-in to a different implementation of op and # kernel registration - REGISTER_OP, REGISTER_KERNEL_BUILDER, etc. # @@ -600,11 +779,16 @@ config_setting( # DO NOT ADD ANY NEW EXCEPTIONS TO THIS LIST! # Instead, please use public APIs or public build rules TF provides. # If you need functionality that is not exposed, we will work with you to expand our public APIs. 
+# TODO(b/173549186): Move Google-internal TF code out of learning/brain package_group( name = "internal", packages = [ + "//learning/brain/keras/...", + "//learning/brain/mlir/...", "//learning/lib/ami/simple_ml/...", "//tensorflow/...", + "//tensorflow_decision_forests/...", + "//third_party/cloud_tpu/inference_converter/...", ], ) @@ -639,7 +823,7 @@ bzl_library( "//tensorflow/core/platform/default:cuda_build_defs_bzl", "//third_party/mkl:build_defs_bzl", "//third_party/mkl_dnn:build_defs_bzl", - "//third_party/ngraph:build_defs_bzl", + "@bazel_skylib//lib:new_sets", "@bazel_skylib//rules:common_settings", "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", @@ -725,7 +909,9 @@ tf_cc_shared_object( name = "tensorflow_framework", framework_so = [], linkopts = select({ - "//tensorflow:macos": [], + "//tensorflow:macos": [ + "-Wl,-rename_section,__TEXT,text_env,__TEXT,__text", + ], "//tensorflow:windows": [], "//tensorflow:freebsd": [ "-Wl,--version-script,$(location //tensorflow:tf_framework_version_script.lds)", @@ -741,13 +927,16 @@ tf_cc_shared_object( visibility = ["//visibility:public"], deps = [ "//tensorflow/c/experimental/filesystem:filesystem_interface", - "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", - "//tensorflow/c:kernels_hdrs", + "//tensorflow/c/experimental/stream_executor:stream_executor", + "//tensorflow/c:env", + "//tensorflow/c:kernels", + "//tensorflow/c:logging", "//tensorflow/c:ops_hdrs", "//tensorflow/cc/saved_model:loader_lite_impl", "//tensorflow/core/common_runtime:core_cpu_impl", "//tensorflow/core:framework_internal_impl", "//tensorflow/core/common_runtime/gpu:gpu_runtime_impl", + "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_runtime_impl", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl", "//tensorflow/core:lib_internal_impl", "//tensorflow/core/profiler:profiler_impl", @@ -798,12 +987,12 @@ tf_cc_shared_object( per_os_targets = 
True, soversion = VERSION, visibility = ["//visibility:public"], - # add win_def_file for tensorflow + # copybara:comment_begin(OSS Windows only: DEF file for exported symbols) win_def_file = select({ - # We need this DEF file to properly export symbols on Windows "//tensorflow:windows": ":tensorflow_filtered_def_file", "//conditions:default": None, }), + # copybara:comment_end deps = [ "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", @@ -827,16 +1016,21 @@ tf_cc_shared_object( "-z defs", "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], + }) + select({ + "//tensorflow:msvc_cl_debug": [ + "/DEBUG:FASTLINK", + ], + "//conditions:default": [], }), per_os_targets = True, soversion = VERSION, visibility = ["//visibility:public"], - # add win_def_file for tensorflow_cc + # copybara:comment_begin(OSS Windows only: DEF file for exported symbols) win_def_file = select({ - # We need this DEF file to properly export symbols on Windows "//tensorflow:windows": ":tensorflow_filtered_def_file", "//conditions:default": None, }), + # copybara:comment_end deps = [ "//tensorflow:tf_exported_symbols.lds", "//tensorflow:tf_version_script.lds", @@ -845,9 +1039,8 @@ tf_cc_shared_object( "//tensorflow/cc:cc_ops", "//tensorflow/cc:client_session", "//tensorflow/cc:scope", - "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", - ] + if_ngraph(["@ngraph_tf//:ngraph_tf"]), + ], ) # ** Targets for Windows build (start) ** @@ -1053,7 +1246,7 @@ gen_api_init_files( py_library( name = "tensorflow_py", - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = select({ "api_version_2": [], @@ -1075,7 +1268,7 @@ py_library( "//tensorflow/python/keras/api:keras_python_api_gen_compat_v1", "//tensorflow/python/keras/api:keras_python_api_gen_compat_v2", ], - srcs_version = "PY2AND3", + srcs_version = "PY3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], ) diff --git 
a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 99a278a14a4b37..1e6b0e1f1d0fe5 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -84,13 +84,21 @@ _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "estimator", estimator) -try: - from .python.keras.api._v2 import keras - _current_module.__path__ = ( - [_module_util.get_parent_dir(keras)] + _current_module.__path__) +if _os.environ.get("_PREFER_OSS_KERAS", False): + _keras_module = "keras.api._v2.keras" + keras = _LazyLoader("keras", globals(), _keras_module) + _module_dir = _module_util.get_parent_dir_for_name(_keras_module) + if _module_dir: + _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "keras", keras) -except ImportError: - pass +else: + try: + from .python.keras.api._v2 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) + except ImportError: + pass # Explicitly import lazy-loaded modules to support autocompletion. # pylint: disable=g-import-not-at-top @@ -116,7 +124,8 @@ # Get sitepackages directories for the python installation. 
_site_packages_dirs = [] -_site_packages_dirs += [] if _site.USER_SITE is None else [_site.USER_SITE] +if _site.ENABLE_USER_SITE and _site.USER_SITE is not None: + _site_packages_dirs += [_site.USER_SITE] _site_packages_dirs += [_p for _p in _sys.path if 'site-packages' in _p] if 'getsitepackages' in dir(_site): _site_packages_dirs += _site.getsitepackages() @@ -145,17 +154,38 @@ def _running_from_pip_package(): _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) + # Load Pluggable Device Library + _ll.load_pluggable_device_library(_plugin_dir) # Add module aliases if hasattr(_current_module, 'keras'): - losses = keras.losses - metrics = keras.metrics - optimizers = keras.optimizers - initializers = keras.initializers - setattr(_current_module, "losses", losses) - setattr(_current_module, "metrics", metrics) - setattr(_current_module, "optimizers", optimizers) - setattr(_current_module, "initializers", initializers) + # It is possible that keras is a lazily loaded module, which might break when + # actually trying to import it. Have a Try-Catch to make sure it doesn't break + # when it doing some very initial loading, like tf.compat.v2, etc. + if _os.environ.get("_PREFER_OSS_KERAS", False): + try: + _keras_package = "keras.api._v2.keras." 
+ losses = _LazyLoader("losses", globals(), _keras_package + "losses") + metrics = _LazyLoader("metrics", globals(), _keras_package + "metrics") + optimizers = _LazyLoader( + "optimizers", globals(), _keras_package + "optimizers") + initializers = _LazyLoader( + "initializers", globals(), _keras_package + "initializers") + setattr(_current_module, "losses", losses) + setattr(_current_module, "metrics", metrics) + setattr(_current_module, "optimizers", optimizers) + setattr(_current_module, "initializers", initializers) + except ImportError: + pass + else: + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers + initializers = keras.initializers + setattr(_current_module, "losses", losses) + setattr(_current_module, "metrics", metrics) + setattr(_current_module, "optimizers", optimizers) + setattr(_current_module, "initializers", initializers) # pylint: enable=undefined-variable # Delete modules that should be hidden from dir(). diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index ae82f7b4792adc..115c7a41519a8f 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -75,13 +75,21 @@ _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "estimator", estimator) -try: - from .python.keras.api._v1 import keras - _current_module.__path__ = ( - [_module_util.get_parent_dir(keras)] + _current_module.__path__) +if _os.environ.get("_PREFER_OSS_KERAS", False): + _keras_module = "keras.api._v1.keras" + keras = _LazyLoader("keras", globals(), _keras_module) + _module_dir = _module_util.get_parent_dir_for_name(_keras_module) + if _module_dir: + _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "keras", keras) -except ImportError: - pass +else: + try: + from .python.keras.api._v1 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + 
_current_module.__path__) + setattr(_current_module, "keras", keras) + except ImportError: + pass # Explicitly import lazy-loaded modules to support autocompletion. # pylint: disable=g-import-not-at-top @@ -155,6 +163,8 @@ def _running_from_pip_package(): _plugin_dir = _os.path.join(_s, 'tensorflow-plugins') if _os.path.exists(_plugin_dir): _ll.load_library(_plugin_dir) + # Load Pluggable Device Library + _ll.load_pluggable_device_library(_plugin_dir) # Delete modules that should be hidden from dir(). # Don't fail if these modules are not available. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 3f4d70ed60eea6..429589ba0c74c4 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", + "check_deps", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -78,7 +79,7 @@ cc_library( ], visibility = [ "//tensorflow/core:__pkg__", - "//tensorflow/python:__pkg__", + "//tensorflow/python:__subpackages__", ], ) @@ -155,16 +156,19 @@ tf_cuda_library( "tf_file_statistics.h", "tf_status.h", "tf_tensor.h", + "tf_tstring.h", ], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ + "//tensorflow/core/platform:tstring", ":c_api_no_xla", ":c_api_internal", ":tf_attrtype", ":tf_status_internal", ":tf_file_statistics", ":tf_tensor_internal", + ":tf_tstring", ] + select({ "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", @@ -174,6 +178,13 @@ tf_cuda_library( }), ) +# Check that c_api_no_xla does not depend on xla. 
+check_deps( + name = "c_api_no_xla_check_deps", + disallowed_deps = ["//tensorflow/compiler/jit:xla_kernel_creator"], + deps = [":c_api_no_xla"], +) + tf_cuda_library( name = "c_api_no_xla", srcs = [ @@ -199,6 +210,8 @@ tf_cuda_library( "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ + ":env", + ":logging", ":tf_status", ":tf_tensor", "@com_google_absl//absl/strings", @@ -304,6 +317,24 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "tf_tstring", + srcs = [ + "tf_tstring.cc", + ], + hdrs = [ + "c_api_macros.h", + "tf_datatype.h", + "tf_status.h", + "tf_tensor.h", + "tf_tstring.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform:tstring", + ], +) + cc_library( name = "tf_file_statistics", hdrs = ["tf_file_statistics.h"], @@ -419,6 +450,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/common_runtime/pluggable_device:pluggable_device_plugin_init", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/platform", "//tensorflow/core/platform:blocking_counter", @@ -504,11 +536,13 @@ tf_cuda_library( "//tensorflow/core:framework", ], }) + [ - ":c_api", + ":c_api_macros", + ":tf_status", ":tf_status_helper", - ":c_api_internal", ":tf_file_statistics", - "//tensorflow/core:lib", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:types", ], ) @@ -521,6 +555,7 @@ cc_library( ":tf_datatype", ":tf_status", ":tf_tensor", + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", ], ) @@ -541,13 +576,17 @@ tf_cuda_library( ] + select({ "//tensorflow:android": [ ":c_api_internal", + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ ":c_api_internal", 
":tf_tensor", + "//tensorflow/stream_executor:stream", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", + "//tensorflow/c/experimental/stream_executor:stream_executor", + "//tensorflow/c/experimental/stream_executor:stream_executor_internal", ], }), ) @@ -634,6 +673,7 @@ tf_cuda_cc_test( "//conditions:default": [], }), tags = [ + "no_cuda_asan", # TODO(b/181771536) "no_windows", # TODO(b/155444728) "noasan", ], @@ -642,6 +682,7 @@ tf_cuda_cc_test( # linkstatic = tf_kernel_tests_linkstatic(), deps = [ ":c_api", + ":c_api_internal", ":c_test_util", ":test_op_kernel", "//tensorflow/cc:cc_ops", @@ -678,7 +719,10 @@ tf_cc_test( name = "c_api_experimental_test", size = "medium", srcs = ["c_api_experimental_test.cc"], - data = ["testdata/tf_record"], + data = [ + "testdata/tf_record", + "//tensorflow/c/experimental/stream_executor/test:test_pluggable_device.so", + ], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -698,6 +742,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", "@com_google_absl//absl/types:optional", ], ) @@ -792,6 +837,7 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:ops_testutil", "//third_party/eigen3", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 9579efab94d807..f3bf7b98a1e6b5 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -238,6 +238,15 @@ Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, return Status::OK(); } +Status BufferToMessage(const TF_Buffer* in, + tensorflow::protobuf::MessageLite* out) { + if (in == nullptr || !out->ParseFromArray(in->data, in->length)) { + return errors::InvalidArgument("Unparseable ", out->GetTypeName(), + " proto"); + } + return Status::OK(); +} + void RecordMutation(TF_Graph* graph, 
const TF_Operation& op, const char* mutation_type) { // If any session has already run this node_id, mark this session as diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index f550b690e27d1c..705cf85e0512fa 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1205,7 +1205,7 @@ typedef struct TF_Session TF_Session; // Return a new execution session with the associated graph, or NULL on // error. Does not take ownership of any input parameters. // -// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be // kept alive for the lifetime of the returned TF_Session. New nodes can still // be added to `graph` after this call. TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 0d188aa5ee0f4b..2b8bd5178afa4b 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -37,7 +38,9 @@ limitations under the License. 
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/strcat.h" @@ -494,51 +497,6 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type, return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(tensor)); } -namespace { -tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, - TFE_Context* ctx) { - // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the - // server object (which currently CHECK-fails) and we miss the error, instead, - // we log the error, and then return to allow the user to see the error - // message. -#define LOG_AND_RETURN_IF_ERROR(...) \ - do { \ - const ::tensorflow::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - LOG(ERROR) << _status.error_message(); \ - return _status; \ - } \ - } while (0); - - // New server created for new server_def. Unused if updating server_def. 
- tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server = - dynamic_cast<tensorflow::GrpcServer*>(context->GetServer()); - if (grpc_server == nullptr) { - std::unique_ptr<tensorflow::ServerInterface> new_server; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); - grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get()); - if (grpc_server == nullptr) { - LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal( - "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); - } - LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - - LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer( - std::move(new_server), grpc_server->worker_env()->device_mgr, - grpc_server->worker_env()->collective_executor_mgr.get())); - } else { - LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def)); - LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer( - /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr, - grpc_server->worker_env()->collective_executor_mgr.get())); - } - return tensorflow::Status::OK(); -#undef LOG_AND_RETURN_IF_ERROR -} -} // namespace - // Set server_def on the context, possibly updating it. TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, const void* proto, @@ -550,7 +508,9 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, "Invalid tensorflow.ServerDef protocol buffer"); return; } - status->status = EnableCollectiveOps(server_def, ctx); + status->status = + tensorflow::unwrap(ctx)->GetDistributedManager()->EnableCollectiveOps( + server_def); } TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, @@ -630,6 +590,9 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array, namespace tensorflow { Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); + +// Helpers for loading a TensorFlow PluggableDevice plugin (a .so file)
+Status LoadPluggableDeviceLibrary(const char* library_filename, void** result); } // namespace tensorflow void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, @@ -696,6 +659,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, c.SetInput(i, c.UnknownShape()); continue; } + dims.reserve(input_shape.num_dims); for (int j = 0; j < input_shape.num_dims; ++j) { dims.push_back(c.MakeDim(input_shape.dims[j])); } @@ -743,3 +707,48 @@ void TF_ImportGraphDefOptionsSetValidateColocationConstraints( TF_ImportGraphDefOptions* opts, unsigned char enable) { opts->opts.validate_colocation_constraints = enable; } + +// Load a Pluggable Device library. +// On success, returns the handle to library in result and return OK from the +// function. Otherwise return nullptr in result and error Status from the +// function. +// +// If `library_filename` has already been loaded, we return a cached handle. +// Device and Kernels/Ops are registered as globals when a library is loaded +// for the first time. 
+TF_Library* TF_LoadPluggableDeviceLibrary(const char* library_filename, + TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "PluggableDevice plugin functionality is not supported on mobile"); + return nullptr; +#else + TF_Library* lib_handle = new TF_Library; + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_map* loaded_libs = + new std::unordered_map(); + tensorflow::Env* env = tensorflow::Env::Default(); + { + tensorflow::mutex_lock lock(mu); + auto it = loaded_libs->find(library_filename); + if (it != loaded_libs->end()) { + lib_handle->lib_handle = it->second; + } else { + status->status = + env->LoadDynamicLibrary(library_filename, &lib_handle->lib_handle); + if (status->status.ok()) { + TF_CHECK_OK( + tensorflow::RegisterPluggableDevicePlugin(lib_handle->lib_handle)); + } else { + delete lib_handle; + return nullptr; + } + } + return lib_handle; + } +#endif +} + +void TF_DeletePluggableDeviceLibraryHandle(TF_Library* lib_handle) { + delete lib_handle; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 90e074d232fbf5..d4132153641808 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -86,7 +86,7 @@ TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled( // Create a serialized tensorflow.ConfigProto proto, where: // -// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if +// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if // `enable_xla_compilation` is non-zero, and OFF otherwise. // b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. // c) ConfigProto.device_count is set to `num_cpu_devices`. 
@@ -304,6 +304,27 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetValidateColocationConstraints( TF_ImportGraphDefOptions* opts, unsigned char enable); +// Load the library specified by library_filename and register the pluggable +// device and related kernels present in that library. This function is not +// supported on mobile and embedded platforms and will fail if +// called. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, returns the newly created library handle and places OK in status. +// The caller owns the library handle. +// +// On failure, returns nullptr and places an error status in status. +TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary( + const char* library_filename, TF_Status* status); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle( + TF_Library* lib_handle); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index cfeba345f8122f..e47b7d0b0f798d 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" @@ -234,5 +235,25 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) { TF_DeleteTensor(tensor_1X6); } +TEST(CAPI_EXPERIMENTAL, LibraryPluggableDeviceLoadFunctions) { + // TODO(penpornk): Enable this test on Windows.
+#if !defined(PLATFORM_WINDOWS) +#if !defined(TENSORFLOW_NO_SHARED_OBJECTS) + // Load the library. + TF_Status* status = TF_NewStatus(); + string lib_path = + tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath( + "tensorflow", "c", "experimental", "stream_executor", "test", + "test_pluggable_device.so")); + TF_Library* lib = TF_LoadPluggableDeviceLibrary(lib_path.c_str(), status); + TF_Code code = TF_GetCode(status); + string status_msg(TF_Message(status)); + TF_DeleteStatus(status); + ASSERT_EQ(TF_OK, code) << status_msg; + TF_DeletePluggableDeviceLibraryHandle(lib); +#endif // !defined(TENSORFLOW_NO_SHARED_OBJECTS) +#endif // !defined(PLATFORM_WINDOWS) +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index a0fa9613e7fce0..f03e52a937a518 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -155,7 +155,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( int ncontrol_outputs, const TF_Operation* const* control_outputs, const char* const* control_output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status) { - tensorflow::mutex_lock l(*const_cast(&fn_body->mu)); + tensorflow::mutex_lock l(fn_body->mu); // Process inputs. std::vector input_tensors; @@ -196,6 +196,7 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( // Compute body nodes. 
std::vector control_output_nodes; + control_output_nodes.reserve(ncontrol_outputs); for (int i = 0; i < ncontrol_outputs; ++i) { control_output_nodes.push_back(&control_outputs[i]->node); } @@ -213,6 +214,11 @@ TF_Function* TF_GraphToFunctionWithControlOutputs( TF_DeleteFunction(tf_function); return nullptr; } + + for (const Node* n : fn_body->graph.nodes()) { + tf_function->stack_traces[n->name()] = n->GetStackTrace(); + } + return tf_function; } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 0d128b23e329cf..76345cf068ce87 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -70,7 +70,7 @@ struct TF_Library { struct TF_Graph { TF_Graph(); - tensorflow::mutex mu; + mutable tensorflow::mutex mu; tensorflow::Graph graph TF_GUARDED_BY(mu); // Runs shape inference. @@ -157,6 +157,7 @@ struct TF_DeviceList { struct TF_Function { tensorflow::FunctionDef fdef; + tensorflow::StackTracesMap stack_traces; }; struct TF_ApiDefMap { @@ -189,6 +190,9 @@ namespace tensorflow { Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out); +Status BufferToMessage(const TF_Buffer* in, + tensorflow::protobuf::MessageLite* out); + // Set the shapes and types of the output's handle. // // The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index fc1fdccee162aa..e0b16da84c9c37 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/cc/saved_model/signature_constants.h" @@ -44,6 +45,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" @@ -2576,6 +2578,20 @@ TEST(CAPI, TestTensorIsNotAligned) { TF_DeleteTensor(a); } +TEST(CAPI, MessageBufferConversion) { + NodeDef node_in, node_out; + node_in.set_name("Test name"); + node_in.set_op("Test op"); + + TF_Buffer* buffer = TF_NewBuffer(); + TF_CHECK_OK(MessageToBuffer(node_in, buffer)); + TF_CHECK_OK(BufferToMessage(buffer, &node_out)); + TF_DeleteBuffer(buffer); + + protobuf::util::MessageDifferencer differencer; + EXPECT_TRUE(differencer.Compare(node_in, node_out)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c44d0ee6873038..9a65db0b2d1a5b 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,7 +3,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", - "if_libtpu", "tf_cc_test", "tf_copts", "tf_cuda_cc_test", @@ -51,7 +50,7 @@ tf_cuda_library( ":immediate_execution_context", ":immediate_execution_operation", ":immediate_execution_tensor_handle", - ":abstract_tensor_handle", + ":immediate_execution_distributed_manager", ":tfe_context_internal", ":tfe_cancellation_manager_internal", ":tfe_executor_internal", @@ -70,10 +69,13 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:core", + "//tensorflow/core/common_runtime/eager:custom_device", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", 
"//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:placement_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -84,6 +86,7 @@ tf_cuda_library( ], }) + [ "@com_google_absl//absl/memory", + ":abstract_tensor_handle", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/eager:remote_mgr", "//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime", @@ -108,8 +111,10 @@ filegroup( srcs = [ "abstract_context.h", "abstract_function.h", + "abstract_op_attrs.h", "abstract_operation.h", "abstract_tensor_handle.h", + "c_api.h", "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", @@ -118,6 +123,7 @@ filegroup( "gradients.h", "gradients_internal.h", "immediate_execution_context.h", + "immediate_execution_distributed_manager.h", "immediate_execution_operation.h", "immediate_execution_tensor_handle.h", "tape.h", @@ -127,6 +133,7 @@ filegroup( "tfe_monitoring_internal.h", "tfe_op_attrs_internal.h", "tfe_tensor_debug_info_internal.h", + "tfe_tensorhandle_internal.h", ], visibility = [ "//tensorflow/core:__pkg__", @@ -140,10 +147,7 @@ cc_library( "c_api_experimental.h", "c_api_internal.h", ], - visibility = [ - "//learning/deepmind/courier:__subpackages__", - "//tensorflow:internal", - ], + visibility = ["//tensorflow:internal"], deps = [ ":c_api", ":tfe_cancellation_manager_internal", @@ -175,6 +179,7 @@ cc_library( "//tensorflow/c:c_api_internal", "//tensorflow/c:conversion_macros", "//tensorflow/c:tf_status", + "//tensorflow/core:framework", "//tensorflow/core/platform:casts", "//tensorflow/core/platform:types", ], @@ -212,17 +217,46 @@ cc_library( ], deps = [ ":abstract_context", - ":abstract_operation", ":abstract_tensor_handle", ":c_api_unified_internal", ":tape", "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/lib/llvm_rtti", + 
"//tensorflow/core/platform:errors", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) +cc_library( + name = "unified_api_testutil", + testonly = 1, + srcs = [ + "unified_api_testutil.cc", + ], + hdrs = [ + "unified_api_testutil.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_test_util", + ":c_api_unified_internal", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c:tf_tensor", + "//tensorflow/core:framework", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + tf_cuda_cc_test( name = "gradients_test", size = "small", @@ -239,14 +273,15 @@ tf_cuda_cc_test( ":c_api_test_util", ":c_api_unified_internal", ":gradients_internal", + ":unified_api_testutil", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", "//tensorflow/c/experimental/gradients:array_grad", "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:not_differentiable", "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/c/experimental/ops", - "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -259,77 +294,32 @@ tf_cuda_cc_test( ], ) -cc_library( - name = "gradients_util", +tf_cuda_cc_test( + name = "unified_api_test", + size = "small", srcs = [ - "gradients_util.cc", - ], - hdrs = [ - "gradients_util.h", - ], - visibility = [ - "//tensorflow:internal", + "unified_api_test.cc", ], + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156 deps = [ - ":abstract_context", - ":abstract_operation", - ":abstract_tensor_handle", - 
":c_api", ":c_api_experimental", ":c_api_unified_internal", - ":gradients_internal", - ":tape", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "//tensorflow/c:c_api", + ":unified_api_testutil", "//tensorflow/c:tf_status_helper", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/lib/llvm_rtti", - ] + if_libtpu( - if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], - if_true = [], - ), -) - -cc_library( - name = "mnist_gradients_testutil", - srcs = [ - "mnist_gradients_testutil.cc", - ], - hdrs = [ - "mnist_gradients_testutil.h", - ], - visibility = [ - "//tensorflow:internal", - ], - deps = [ - ":abstract_tensor_handle", - ":c_api_experimental", - ":c_api_unified_internal", - ":gradients_internal", - ":gradients_util", - ":tape", - "//tensorflow/c/experimental/gradients/tape:tape_context", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", "//tensorflow/core/lib/llvm_rtti", - "//tensorflow/core/platform:status", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/types:span", + "//tensorflow/core/platform:errors", ], ) cc_library( name = "gradient_checker", + testonly = 1, srcs = [ "gradient_checker.cc", ], @@ -341,120 +331,62 @@ cc_library( ], deps = [ ":abstract_tensor_handle", - ":c_api_experimental", - ":c_api_unified_internal", - ":gradients_internal", - ":gradients_util", - "@com_google_absl//absl/strings", - 
"@com_google_absl//absl/types:span", - "//tensorflow/c:c_api", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/experimental/gradients:math_grad", - "//tensorflow/c/experimental/gradients:nn_grad", - "//tensorflow/c/experimental/ops:array_ops", + ":unified_api_testutil", + "//tensorflow/c:tf_tensor_internal", "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/cc/profiler", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/lib/llvm_rtti", - ] + if_libtpu( - if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], - if_true = [], - ), -) - -tf_cuda_cc_test( - name = "gradient_checker_test", - size = "small", - srcs = [ - "gradient_checker_test.cc", - ], - args = ["--heap_check=local"], - linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags() + ["nomac"], - deps = [ - ":abstract_tensor_handle", - ":c_api_experimental", - ":c_api_test_util", - ":c_api_unified_internal", - ":gradient_checker", - ":gradients_internal", - ":gradients_util", - ":mnist_gradients_testutil", - "//tensorflow/c:c_api", - "//tensorflow/c:c_test_util", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/experimental/gradients:math_grad", - "//tensorflow/c/experimental/gradients:nn_grad", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/cc/profiler", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/lib/llvm_rtti", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) tf_cuda_cc_test( - name = "mnist_gradients_test", + name = "gradient_checker_test", size = "small", srcs = [ - "mnist_gradients_test.cc", + "gradient_checker_test.cc", ], args = ["--heap_check=local"], 
linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ - "nomac", + "no_cuda_asan", # b/175330074 ], deps = [ ":abstract_tensor_handle", ":c_api_experimental", - ":c_api_unified_internal", - ":gradients_internal", - ":gradients_util", - ":mnist_gradients_testutil", - "//tensorflow/c:c_api", - "//tensorflow/c:c_test_util", + ":gradient_checker", + ":unified_api_testutil", "//tensorflow/c:tf_status_helper", - "//tensorflow/c/experimental/gradients:math_grad", - "//tensorflow/c/experimental/gradients:nn_grad", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/cc/profiler", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", + "//tensorflow/c:tf_tensor_internal", + "//tensorflow/c/experimental/ops", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/lib/llvm_rtti", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", + "//tensorflow/core/platform:tensor_float_32_utils", "@com_google_absl//absl/types:span", ], ) cc_library( name = "abstract_tensor_handle", + srcs = ["abstract_tensor_handle.cc"], hdrs = ["abstract_tensor_handle.h"], visibility = [ "//tensorflow:internal", ], - deps = [ - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:refcount", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/platform:status", + ], + }), ) cc_library( name = "immediate_execution_tensor_handle", + srcs = ["immediate_execution_tensor_handle.cc"], hdrs = ["immediate_execution_tensor_handle.h"], visibility = [ "//tensorflow:internal", @@ -468,6 +400,21 @@ cc_library( ], ) +cc_library( + name = 
"abstract_op_attrs", + hdrs = ["abstract_op_attrs.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c:tensor_interface", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "abstract_operation", hdrs = ["abstract_operation.h"], @@ -498,7 +445,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/util:abstract_stack_trace", + "//tensorflow/core/util:managed_stack_trace", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -528,6 +475,19 @@ cc_library( ], ) +cc_library( + name = "immediate_execution_distributed_manager", + hdrs = ["immediate_execution_distributed_manager.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "immediate_execution_context", hdrs = ["immediate_execution_context.h"], @@ -536,12 +496,14 @@ cc_library( ], deps = [ ":abstract_context", + ":immediate_execution_distributed_manager", ":immediate_execution_operation", ":immediate_execution_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -566,6 +528,7 @@ cc_library( "//tensorflow:internal", ], deps = [ + "//tensorflow/c:conversion_macros", "//tensorflow/core:framework", ], ) @@ -600,10 +563,10 @@ cc_library( "//tensorflow:internal", ], deps = [ + ":abstract_op_attrs", "//tensorflow/c:conversion_macros", "//tensorflow/c:tf_status", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/common_runtime/eager:attr_builder", ], ) @@ -655,6 +618,19 @@ cc_header_only_library( ], ) +cc_header_only_library( + name = 
"tfe_cancellationmanager_internal_hdrs_only", + extra_deps = [ + "@com_google_absl//absl/strings", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tfe_cancellation_manager_internal", + ], +) + tf_cuda_library( name = "c_api_test_util", testonly = 1, @@ -685,10 +661,9 @@ tf_cuda_cc_test( "c_api_test.cc", ], tags = [ - "noguitar", # TODO(b/155445984): flaky - #"guitar", - "notap", # TODO(b/156981931): flaky - "multi_gpu", + "no_cuda_asan", # TODO(b/181771536) + "guitar", + # "multi_gpu", b/180748118 ], deps = [ ":c_api", @@ -934,7 +909,6 @@ tf_cuda_cc_test( ":c_api_experimental", ":c_api_test_util", "//tensorflow/c:c_test_util", - "//tensorflow/cc/profiler", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -960,7 +934,6 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", - "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -999,7 +972,6 @@ tf_cc_test( ":custom_device_testutil", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", - "//tensorflow/cc/profiler", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -1072,8 +1044,6 @@ filegroup( "gradient_checker.cc", "gradient_checker.h", "gradients.cc", # Uses RTTI. - "gradients_util.cc", - "gradients_util.h", "tracing_utils.h", "tracing_utils.cc", "*test*", diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h index d31b1e13611ff9..07a78f97bd5a9f 100644 --- a/tensorflow/c/eager/abstract_context.h +++ b/tensorflow/c/eager/abstract_context.h @@ -32,7 +32,7 @@ namespace tensorflow { // environment, a traced representation etc. 
class AbstractContext { protected: - enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape }; + enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler }; explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} virtual ~AbstractContext() {} diff --git a/tensorflow/c/eager/abstract_op_attrs.h b/tensorflow/c/eager/abstract_op_attrs.h new file mode 100644 index 00000000000000..6c3af10e169f66 --- /dev/null +++ b/tensorflow/c/eager/abstract_op_attrs.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OP_ATTRS_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_OP_ATTRS_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Attributes of an op. +class AbstractOpAttrs { + protected: + enum AbstractOpAttrsKind { kEager, kTfrt }; + explicit AbstractOpAttrs(AbstractOpAttrsKind kind) : kind_(kind) {} + + public: + // Returns which subclass is this instance of. + AbstractOpAttrsKind getKind() const { return kind_; } + virtual ~AbstractOpAttrs() = default; + + // Returns the AbstractFunction as a FunctionDef. 
+ virtual void GetNameAttrList( + tensorflow::NameAttrList* name_and_attrs) const = 0; + + virtual bool GetInt(absl::string_view, int64_t* result) const = 0; + virtual bool GetFloat(absl::string_view attr_name, float* result) const = 0; + virtual bool GetBool(absl::string_view attr_name, bool* result) const = 0; + virtual bool GetType(absl::string_view attr_name, DataType* result) const = 0; + + private: + const AbstractOpAttrsKind kind_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_OP_ATTRS_H_ diff --git a/tensorflow/c/eager/abstract_operation.h b/tensorflow/c/eager/abstract_operation.h index 4c630528f5ddca..997c8e0e441d42 100644 --- a/tensorflow/c/eager/abstract_operation.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -30,7 +30,14 @@ namespace tensorflow { // tracing or immediate execution mode. class AbstractOperation { protected: - enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape }; + enum AbstractOperationKind { + kGraph, + kMlir, + kEager, + kTfrt, + kTape, + kOpHandler + }; explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} virtual ~AbstractOperation() {} diff --git a/tensorflow/c/eager/abstract_tensor_handle.cc b/tensorflow/c/eager/abstract_tensor_handle.cc new file mode 100644 index 00000000000000..a30063a15f4f45 --- /dev/null +++ b/tensorflow/c/eager/abstract_tensor_handle.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/abstract_tensor_handle.h" + +namespace tensorflow { + +std::string AbstractTensorHandle::DebugString() const { + PartialTensorShape shape; + Status s = Shape(&shape); + std::string shape_string; + if (!s.ok()) { + shape_string = ""; + } else { + shape_string = shape.DebugString(); + } + return absl::StrCat("TensorHandle(shape=", shape_string, + ", dtype=", DataType_Name(DataType()), ")"); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index 37e6d1bf29cc5d..8d7e2114d04a39 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -17,21 +17,30 @@ limitations under the License. #include +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { // Abstract interface to a Tensor handle in either tracing or immediate // execution mode. class AbstractTensorHandle : public core::RefCounted { protected: - enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt }; + enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice }; explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} virtual ~AbstractTensorHandle() {} public: // Returns tensor dtype. virtual tensorflow::DataType DataType() const = 0; + // Returns tensor shape. If tensor has unknown rank, shape remains untouched. + virtual tensorflow::Status Shape( + tensorflow::PartialTensorShape* shape) const = 0; + + // The default debug string includes a shape and dtype. Implementations are + // free to override it with something more informative. 
+ virtual std::string DebugString() const; AbstractTensorHandleKind getKind() const { return kind_; } diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 3418bccf050f3c..8182a15be87052 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,16 +21,11 @@ limitations under the License. #include #include -#include "tensorflow/c/eager/abstract_tensor_handle.h" - -// clang-format off -#include "tensorflow/core/platform/platform.h" -// clang-format on - #include "absl/algorithm/container.h" #include "absl/memory/memory.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/immediate_execution_operation.h" @@ -39,59 +34,44 @@ limitations under the License. #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_tensor_internal.h" -#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) -#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" -#endif -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/eager/context.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/device_filters.pb.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" 
+#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/placement_utils.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#if !defined(IS_MOBILE_PLATFORM) -#include "tensorflow/core/distributed_runtime/eager/eager_client.h" -#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" -#include "tensorflow/core/distributed_runtime/remote_device.h" -#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/distributed_runtime/worker_interface.h" -#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" -#endif // !IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/casts.h" -#include "tensorflow/core/platform/env.h" -#include 
"tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/notification.h" -#include "tensorflow/core/platform/random.h" -#include "tensorflow/core/platform/refcount.h" -#include "tensorflow/core/platform/stringpiece.h" -#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/public/version.h" +// "tensorflow/core/platform/platform.h" must be included first before using +// PLATFORM_GOOGLE, IS_MOBILE_PLATFORM, etc. +#if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) +#include "tensorflow/core/tfrt/eager/c_api_tfrt.h" +#include "tensorflow/core/tfrt/eager/c_api_tfrt_distributed_impl.h" +#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE + +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" +#endif // !IS_MOBILE_PLATFORM + using tensorflow::string; namespace { @@ -100,610 +80,14 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? 
"cpu:0" : d->name(); } -#if !defined(IS_MOBILE_PLATFORM) -bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context, - const tensorflow::ServerDef& server_def) { - if (server_def.job_name() != context->HostCPU()->parsed_name().job) { - return false; - } - return server_def.default_session_config().SerializeAsString() == - context->session_options().config.SerializeAsString(); -} - -tensorflow::Status AddRemoteDevicesToMgr( - const std::vector& added_remote_workers, - tensorflow::WorkerCacheInterface* worker_cache, - tensorflow::DynamicDeviceMgr* remote_device_mgr) { - std::vector> remote_devices; - tensorflow::mutex remote_devices_mu; - int num_added_workers = added_remote_workers.size(); - tensorflow::BlockingCounter counter(num_added_workers); - std::vector statuses(num_added_workers); - for (int i = 0; i < num_added_workers; i++) { - tensorflow::NewRemoteDevices( - tensorflow::Env::Default(), worker_cache, added_remote_workers[i], - [i, &statuses, &counter, &remote_devices, &remote_devices_mu]( - const tensorflow::Status& s, - std::vector* devices) { - statuses[i] = s; - if (s.ok()) { - tensorflow::mutex_lock l(remote_devices_mu); - for (tensorflow::Device* d : *devices) { - remote_devices.emplace_back(d); - } - } - counter.DecrementCount(); - }); - } - counter.Wait(); - for (int i = 0; i < num_added_workers; i++) { - TF_RETURN_IF_ERROR(statuses[i]); - } - - TF_RETURN_IF_ERROR(remote_device_mgr->AddDevices(std::move(remote_devices))); - return tensorflow::Status::OK(); -} - -tensorflow::Status GetAllRemoteDevices( - const std::vector& remote_workers, - tensorflow::WorkerCacheInterface* worker_cache, - std::unique_ptr* device_mgr) { - auto remote_device_mgr = absl::make_unique(); - TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache, - remote_device_mgr.get())); - *device_mgr = std::move(remote_device_mgr); - return tensorflow::Status::OK(); -} - -tensorflow::Status RemoveRemoteDevicesFromMgr( - const std::vector& 
removed_remote_workers, - tensorflow::DynamicDeviceMgr* remote_device_mgr) { - const std::vector remote_devices = - (remote_device_mgr->ListDevices()); - std::vector devices_to_remove; - for (tensorflow::Device* d : remote_devices) { - for (const string& remote_worker : removed_remote_workers) { - if (tensorflow::DeviceNameUtils::IsSameAddressSpace(remote_worker, - d->name())) { - devices_to_remove.emplace_back(d); - break; - } - } - } - TF_RETURN_IF_ERROR(remote_device_mgr->RemoveDevices(devices_to_remove)); - return tensorflow::Status::OK(); -} - -tensorflow::Status ListRemoteWorkers(tensorflow::ServerInterface* server, - const string& local_worker, - std::vector* remote_workers) { - tensorflow::GrpcServer* grpc_server = - dynamic_cast(server); - if (grpc_server == nullptr) { - return tensorflow::errors::Internal( - "Currently, TFE_NewContext only supports tensorflow::GrpcServer."); - } - grpc_server->master_env()->worker_cache->ListWorkers(remote_workers); - remote_workers->erase( - std::remove(remote_workers->begin(), remote_workers->end(), local_worker), - remote_workers->end()); - return tensorflow::Status::OK(); -} - -void DifferentiateWorkerLists(const std::vector* current_list, - const std::vector* new_list, - std::vector* added, - std::vector* removed, - std::vector* existing) { - // Get STL set_difference and set_intersection with one list traversal. - // Similar to the set_difference library function, the input lists - // (`current_list` and `new_list`) must be sorted before calling the function. 
- added->resize(new_list->size()); - removed->resize(current_list->size()); - existing->resize(current_list->size()); - std::vector::const_iterator curr_it = current_list->begin(); - std::vector::const_iterator new_it = new_list->begin(); - std::vector::iterator added_it = added->begin(); - std::vector::iterator removed_it = removed->begin(); - std::vector::iterator existing_it = existing->begin(); - while (curr_it != current_list->end() && new_it != new_list->end()) { - if (*curr_it < *new_it) { - *removed_it++ = *curr_it++; - } else if (*curr_it > *new_it) { - *added_it++ = *new_it++; - } else { - *existing_it++ = *curr_it++; - new_it++; - } - } - removed_it = std::copy(curr_it, current_list->end(), removed_it); - added_it = std::copy(new_it, new_list->end(), added_it); - added->resize(added_it - added->begin()); - removed->resize(removed_it - removed->begin()); - existing->resize(existing_it - existing->begin()); -} - -tensorflow::Status GetReplacedFromExistingWorkers( - const std::vector* existing_workers, tensorflow::uint64 context_id, - tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* client_cache, - std::vector* replaced_workers) { - tensorflow::BlockingCounter counter(existing_workers->size()); - std::vector statuses(existing_workers->size()); - tensorflow::eager::KeepAliveRequest request; - request.set_context_id(context_id); - std::vector responses( - existing_workers->size()); - for (int i = 0; i < existing_workers->size(); i++) { - tensorflow::core::RefCountPtr eager_client; - statuses[i] = - client_cache->GetClient(existing_workers->at(i), &eager_client); - if (!statuses[i].ok()) { - counter.DecrementCount(); - continue; - } - eager_client->KeepAliveAsync( - &request, &responses[i], - [i, &statuses, &counter](const tensorflow::Status& s) { - statuses[i] = s; - counter.DecrementCount(); - }); - } - counter.Wait(); - for (int i = 0; i < existing_workers->size(); i++) { - // If the RPC 
fails (indicating that the requested ID doesn't exist on - // remote), or the returned view ID is not equal to the local one - // (indicating that the remote worker has a stale view of cluster), treat - // the worker as replaced. - if (!statuses[i].ok() || - responses[i].context_view_id() != context_view_id) { - replaced_workers->emplace_back(existing_workers->at(i)); - } - } - return tensorflow::Status::OK(); -} - -tensorflow::Status CreateRemoteContexts( - TFE_Context* ctx, const std::vector& remote_workers, - tensorflow::uint64 context_id, tensorflow::uint64 context_view_id, - int keep_alive_secs, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, - const bool lazy_copy_remote_function_inputs, - const tensorflow::eager::CreateContextRequest& base_request) { - int num_remote_workers = remote_workers.size(); - tensorflow::BlockingCounter counter(num_remote_workers); - std::vector statuses(num_remote_workers); - for (int i = 0; i < num_remote_workers; i++) { - const string& remote_worker = remote_workers[i]; - tensorflow::DeviceNameUtils::ParsedName parsed_name; - if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, - &parsed_name)) { - statuses[i] = tensorflow::errors::InvalidArgument( - "Unable to parse ", remote_worker, " as a device name"); - counter.DecrementCount(); - continue; - } - - tensorflow::core::RefCountPtr eager_client; - statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); - if (eager_client == nullptr) { - statuses[i] = tensorflow::errors::Internal( - "Cannot find a client for the given target:", remote_worker); - } - if (!statuses[i].ok()) { - counter.DecrementCount(); - continue; - } - - tensorflow::eager::CreateContextRequest request; - tensorflow::eager::CreateContextResponse* response = - new tensorflow::eager::CreateContextResponse(); - request.set_context_id(context_id); - request.set_context_view_id(context_view_id); - 
*request.mutable_server_def() = server_def; - request.mutable_server_def()->set_job_name(parsed_name.job); - request.mutable_server_def()->set_task_index(parsed_name.task); - request.mutable_server_def()->mutable_default_session_config()->MergeFrom( - server_def.default_session_config()); - - std::vector filtered_device_mask; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->FilterDevicesForRemoteWorkers( - remote_worker, base_request.cluster_device_attributes(), - &filtered_device_mask); - DCHECK_EQ(filtered_device_mask.size(), - base_request.cluster_device_attributes_size()); - for (int i = 0; i < filtered_device_mask.size(); i++) { - if (filtered_device_mask[i]) { - const auto& da = base_request.cluster_device_attributes(i); - *request.add_cluster_device_attributes() = da; - } - } - request.set_async(async); - request.set_keep_alive_secs(keep_alive_secs); - request.set_lazy_copy_remote_function_inputs( - lazy_copy_remote_function_inputs); - - eager_client->CreateContextAsync( - &request, response, - [i, &statuses, &counter, response](const tensorflow::Status& s) { - statuses[i] = s; - delete response; - counter.DecrementCount(); - }); - } - counter.Wait(); - tensorflow::StatusGroup sg; - for (int i = 0; i < num_remote_workers; i++) { - if (TF_PREDICT_FALSE(!statuses[i].ok())) { - sg.Update(statuses[i]); - } - } - return sg.as_summary_status(); -} - -tensorflow::Status UpdateRemoteContexts( - TFE_Context* ctx, const std::vector& remote_workers, - const std::vector& added_workers, - const std::vector& removed_workers, tensorflow::uint64 context_id, - tensorflow::uint64 context_view_id, const tensorflow::ServerDef& server_def, - tensorflow::eager::EagerClientCache* remote_eager_workers, - const tensorflow::eager::CreateContextRequest& base_request) { - int num_remote_workers = remote_workers.size(); - tensorflow::BlockingCounter counter(num_remote_workers); - std::vector statuses(num_remote_workers); - - 
int cluster_device_count = base_request.cluster_device_attributes_size(); - std::unordered_set added_or_removed(added_workers.begin(), - added_workers.end()); - std::copy(removed_workers.begin(), removed_workers.end(), - std::inserter(added_or_removed, added_or_removed.end())); - // Whether each device is in the updated (added or removed) workers - std::vector device_added_or_removed(cluster_device_count); - for (int i = 0; i < base_request.cluster_device_attributes_size(); i++) { - const auto& da = base_request.cluster_device_attributes().at(i); - tensorflow::DeviceNameUtils::ParsedName pn; - tensorflow::DeviceNameUtils::ParseFullName(da.name(), &pn); - string task_name; - tensorflow::DeviceNameUtils::GetTaskName(pn, &task_name); - if (added_or_removed.find(task_name) != added_or_removed.end()) { - device_added_or_removed[i] = true; - } - } - - for (int i = 0; i < num_remote_workers; i++) { - const string& remote_worker = remote_workers[i]; - tensorflow::DeviceNameUtils::ParsedName parsed_name; - if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, - &parsed_name)) { - statuses[i] = tensorflow::errors::InvalidArgument( - "Unable to parse ", remote_worker, " as a device name"); - counter.DecrementCount(); - continue; - } - - tensorflow::core::RefCountPtr eager_client; - statuses[i] = remote_eager_workers->GetClient(remote_worker, &eager_client); - if (eager_client == nullptr) { - statuses[i] = tensorflow::errors::Internal( - "Cannot find a client for the given target:", remote_worker); - } - if (!statuses[i].ok()) { - counter.DecrementCount(); - continue; - } - - std::vector filtered_device_mask; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - context->FilterDevicesForRemoteWorkers( - remote_worker, base_request.cluster_device_attributes(), - &filtered_device_mask); - DCHECK_EQ(filtered_device_mask.size(), cluster_device_count); - - // If any of the devices that match the device filters are in the set 
of - // added or removed workers, we must send a complete UpdateContextRequest. - // Otherwise, only send a simple request to increment context view ID. - std::vector added_or_removed_filtered_devices(cluster_device_count); - std::transform(device_added_or_removed.begin(), - device_added_or_removed.end(), filtered_device_mask.begin(), - added_or_removed_filtered_devices.begin(), - std::logical_and()); - const bool full_update_request = - std::accumulate(added_or_removed_filtered_devices.begin(), - added_or_removed_filtered_devices.end(), false, - std::logical_or()); - - tensorflow::eager::UpdateContextRequest request; - auto* response = new tensorflow::eager::UpdateContextResponse(); - request.set_context_id(context_id); - request.set_context_view_id(context_view_id); - if (full_update_request) { - *request.mutable_server_def() = server_def; - request.mutable_server_def()->set_job_name(parsed_name.job); - request.mutable_server_def()->set_task_index(parsed_name.task); - request.mutable_server_def()->mutable_default_session_config()->MergeFrom( - server_def.default_session_config()); - for (int i = 0; i < cluster_device_count; i++) { - if (filtered_device_mask[i]) { - const auto& da = base_request.cluster_device_attributes(i); - *request.add_cluster_device_attributes() = da; - } - } - } - - eager_client->UpdateContextAsync( - &request, response, - [i, &statuses, &counter, response](const tensorflow::Status& s) { - statuses[i] = s; - delete response; - counter.DecrementCount(); - }); - } - counter.Wait(); - for (int i = 0; i < num_remote_workers; i++) { - TF_RETURN_IF_ERROR(statuses[i]); - } - return tensorflow::Status::OK(); -} - -tensorflow::Status UpdateTFE_ContextWithServerDef( - int keep_alive_secs, const tensorflow::ServerDef& server_def, - TFE_Context* ctx, bool reset_context) { - // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the - // server object (which currently CHECK-fails) and we miss the error, instead, - // we log the error, 
and then return to allow the user to see the error - // message. -#define LOG_AND_RETURN_IF_ERROR(...) \ - do { \ - const ::tensorflow::Status _status = (__VA_ARGS__); \ - if (TF_PREDICT_FALSE(!_status.ok())) { \ - LOG(ERROR) << _status.error_message(); \ - return _status; \ - } \ - } while (0); - - string worker_name = - tensorflow::strings::StrCat("/job:", server_def.job_name(), - "/replica:0/task:", server_def.task_index()); - - // List of current remote workers before updating server_def. Unused if - // resetting the server_def. - std::vector curr_remote_workers; - // List of updated remote workers. - std::vector remote_workers; - - // New server created for new server_def. Unused if updating server_def. - std::unique_ptr new_server; - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server; - if (reset_context) { - const tensorflow::DeviceMgr* device_mgr = - AreLocalDevicesCompatible(context, server_def) - ? context->local_device_mgr() - : nullptr; - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions( - server_def, {device_mgr}, &new_server)); - grpc_server = dynamic_cast(new_server.get()); - LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(new_server.get(), worker_name, &remote_workers)); - } else { - LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name, - &curr_remote_workers)); - // No need to check the cast here, since `ListRemoteWorkers` already checks - // if the server is a GRPC server or not. 
- grpc_server = dynamic_cast(context->GetServer()); - LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def)); - LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(grpc_server, worker_name, &remote_workers)); - } - - tensorflow::uint64 context_id = context->GetContextId(); - tensorflow::uint64 context_view_id = context->GetContextViewId(); - if (reset_context) { - context_id = tensorflow::EagerContext::NewContextId(); - context_view_id = 0; - // Make master eager context accessible by local eager service, which might - // receive send tensor requests from remote workers. - LOG_AND_RETURN_IF_ERROR( - grpc_server->AddMasterEagerContextToEagerService(context_id, context)); - } - - std::unique_ptr remote_eager_workers; - LOG_AND_RETURN_IF_ERROR( - grpc_server->master_env()->worker_cache->GetEagerClientCache( - &remote_eager_workers)); - - // For cluster update, use a status group to aggregate statuses from - // * adding and removing remote devices - // * creating remote contexts on newly added workers - // * updating remote contexts on existing workers - // * updating the master context - // Note that we should not return immediately on errors in the middle of these - // updates to prevent cluster from having inconsistent context views. - // - // Unused if `reset_context` is True. - tensorflow::StatusGroup sg; - - // When updating an existing context, populate the following lists with: - // * added_workers: set(remote_workers) - set(curr_remote_workers) - // * removed_workers: set(curr_remote_workers) - set(remote_workers) - // * existing_workers: set(curr_remote_workers) intersect set(remote_workers) - // * replaced_workers: workers with the same task names and potentially the - // same `hostname:port`s, but replaced by different processes - std::vector added_workers; - std::vector removed_workers; - std::vector existing_workers; - std::vector replaced_workers; - - // New remote device manager created for new server_def. Unused if updating - // server_def. 
- std::unique_ptr new_remote_device_mgr; - tensorflow::DynamicDeviceMgr* remote_device_mgr = nullptr; - if (reset_context) { - LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices( - remote_workers, grpc_server->master_env()->worker_cache, - &new_remote_device_mgr)); - remote_device_mgr = new_remote_device_mgr.get(); - } else { - context->ClearCachesAndDefaultExecutor(); - // TODO(b/143914772): Potential memory leak if rendezvous has pending - // tensors for removed / replaced workers. - - remote_device_mgr = context->GetOwnedRemoteDeviceMgr(); - if (remote_device_mgr == nullptr) { - LOG_AND_RETURN_IF_ERROR(tensorflow::errors::InvalidArgument( - "Updating context with an invalid set of remote devices.")); - } - std::sort(curr_remote_workers.begin(), curr_remote_workers.end()); - std::sort(remote_workers.begin(), remote_workers.end()); - DifferentiateWorkerLists(&curr_remote_workers, &remote_workers, - &added_workers, &removed_workers, - &existing_workers); - sg.Update(GetReplacedFromExistingWorkers( - &existing_workers, context_id, context->GetContextViewId(), server_def, - remote_eager_workers.get(), &replaced_workers)); - if (VLOG_IS_ON(1)) { - VLOG(1) << "Updating cluster with following changes"; - for (const string& w : added_workers) VLOG(1) << " Added worker " << w; - for (const string& w : removed_workers) - VLOG(1) << " Removed worker " << w; - for (const string& w : replaced_workers) - VLOG(1) << " Replaced worker " << w; - } - if (!replaced_workers.empty()) { - // Treat replaced workers as removed then added back, so that we recreate - // remote devices and contexts, and re-register functions on those workers - removed_workers.insert(removed_workers.end(), replaced_workers.begin(), - replaced_workers.end()); - added_workers.insert(added_workers.end(), replaced_workers.begin(), - replaced_workers.end()); - for (const string& w : replaced_workers) { - existing_workers.erase( - std::remove(existing_workers.begin(), existing_workers.end(), w), - 
existing_workers.end()); - } - } - sg.Update(RemoveRemoteDevicesFromMgr(removed_workers, remote_device_mgr)); - sg.Update(AddRemoteDevicesToMgr(added_workers, - grpc_server->master_env()->worker_cache, - remote_device_mgr)); - } - - std::vector cluster_device_attributes; - remote_device_mgr->ListDeviceAttributes(&cluster_device_attributes); - - std::vector local_device_attributes; - grpc_server->worker_env()->device_mgr->ListDeviceAttributes( - &local_device_attributes); - - // This request make sure that we can create Rendezvous properly between - // Local and Remote context. - tensorflow::eager::CreateContextRequest base_request; - for (const auto& da : cluster_device_attributes) { - *base_request.add_cluster_device_attributes() = da; - } - for (const auto& da : local_device_attributes) { - *base_request.add_cluster_device_attributes() = da; - } - - // Initialize remote eager workers. - if (reset_context) { - const tensorflow::Status s = CreateRemoteContexts( - ctx, remote_workers, context_id, context_view_id, keep_alive_secs, - server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request); - // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause - // the CreateRemoteContexts to fail. We currently only log instead of - // directly returning the error, since returning here will cause the server - // object to be destroyed (which currently CHECK-fails). The client will - // see additional errors if ops are subsequently sent to the failed workers. - if (TF_PREDICT_FALSE(!s.ok())) { - LOG(ERROR) << "Error when creating contexts on remote targets: " - << s.error_message() - << "\nExecuting remote ops or functions on these remote " - "targets will fail."; - } - } else { - if (sg.ok()) { - // Create remote contexts on the newly added workers only if the master - // has collected all device information from them (i.e., the - // GetAllRemoteDevices call returns succussfully). 
Note that in rare cases - // GetAllRemoteDevices can still fail even with RPCs configured to wait - // until the remote workers to become alive. If the master creates remote - // contexts on the workers whose devices are still not collected, those - // workers will be treated as existing workers subsequently, so the master - // will never get devices from them even with retrying UpdateServerDef. - sg.Update(CreateRemoteContexts( - ctx, added_workers, context_id, context_view_id + 1, keep_alive_secs, - server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request)); - } - if (!existing_workers.empty()) { - if (VLOG_IS_ON(1)) { - for (const string& w : existing_workers) { - VLOG(1) << "Updating cluster with existing worker " << w; - } - } - // The master's context_view_id will be incremented by one in the - // UpdateRemoteMaster call later. We want existing workers to also have - // the updated context_view_id, so we must set their context_view_id to - // the master's current context_view_id + 1. - sg.Update(UpdateRemoteContexts(ctx, existing_workers, added_workers, - removed_workers, context_id, - context_view_id + 1, server_def, - remote_eager_workers.get(), base_request)); - } - } - - auto session_name = tensorflow::strings::StrCat("eager_", context_id); - if (reset_context) { - tensorflow::RemoteRendezvous* r = - grpc_server->worker_env()->rendezvous_mgr->Find(context_id); - auto* device_mgr = grpc_server->worker_env()->device_mgr; - std::shared_ptr worker_session; - LOG_AND_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->CreateSession( - session_name, server_def, base_request.cluster_device_attributes(), - true)); - LOG_AND_RETURN_IF_ERROR( - grpc_server->worker_env()->session_mgr->WorkerSessionForSession( - session_name, &worker_session)); - - // Initialize remote tensor communication based on worker session. 
- LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get())); - - tensorflow::DistributedFunctionLibraryRuntime* cluster_flr = - tensorflow::eager::CreateClusterFLR(context_id, context, - worker_session.get()); - auto remote_mgr = absl::make_unique( - /*is_master=*/true, context); - - LOG_AND_RETURN_IF_ERROR(context->InitializeRemoteMaster( - std::move(new_server), grpc_server->worker_env(), worker_session, - std::move(remote_eager_workers), std::move(new_remote_device_mgr), - remote_workers, context_id, r, device_mgr, keep_alive_secs, cluster_flr, - std::move(remote_mgr))); - - // NOTE: We start the server after all other initialization, because the - // GrpcServer cannot be destroyed after it is started. - LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - } else { - sg.Update(grpc_server->worker_env()->session_mgr->UpdateSession( - session_name, server_def, base_request.cluster_device_attributes(), - /*isolate_session_state=*/true)); - sg.Update(context->UpdateRemoteMaster(context_id, - std::move(remote_eager_workers), - added_workers, removed_workers)); - LOG_AND_RETURN_IF_ERROR(sg.as_summary_status()); - } -#undef LOG_AND_RETURN_IF_ERROR - - return tensorflow::Status::OK(); +// Annotate eager runtime construction context to the given `function_def` as +// an attribute. 
+void AnnotateEagerRuntimeConstructionContext( + tensorflow::FunctionDef& function_def) { + tensorflow::AttrValue value; + SetAttrValue("kEagerRuntime", &value); + (*function_def.mutable_attr())["_construction_context"] = value; } -#endif // !IS_MOBILE_PLATFORM } // namespace @@ -731,11 +115,21 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #if defined(PLATFORM_GOOGLE) && !defined(LIBTPU_ON_GCE) - return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async)); + tfrt::tf::ContextInterface* tfrt_context = new tfrt::tf::ContextInterface( + opts->session_options.options, + static_cast( + opts->device_placement_policy), + opts->async); +#if !defined(IS_MOBILE_PLATFORM) + tfrt_context->SetDistributedManager( + tfrt::tf::CreateDistributedManagerContext( + tfrt_context->GetCoreRuntime()->GetHostContext())); +#endif // !IS_MOBILE_PLATFORM + return tensorflow::wrap(tfrt_context); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; -#endif +#endif // PLATFORM_GOOGLE && !LIBTPU_ON_GCE } std::vector> devices; status->status = tensorflow::DeviceFactory::AddDevices( @@ -747,13 +141,18 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); - - return tensorflow::wrap(new tensorflow::EagerContext( + tensorflow::EagerContext* eager_context = new tensorflow::EagerContext( opts->session_options.options, static_cast( opts->device_placement_policy), - opts->async, opts->lazy_remote_inputs_copy, device_mgr.release(), - /*device_mgr_owned*/ true, r)); + opts->async, device_mgr.release(), + /*device_mgr_owned*/ true, r); +#if !defined(IS_MOBILE_PLATFORM) + eager_context->SetDistributedManager( + std::make_unique( + eager_context)); +#endif // !IS_MOBILE_PLATFORM + return 
tensorflow::wrap(eager_context); } void TFE_DeleteContext(TFE_Context* ctx) { @@ -791,26 +190,9 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, "Invalid tensorflow.ServerDef protocol buffer"); return; } - if (server_def.has_cluster_device_filters()) { - const auto& cdf = server_def.cluster_device_filters(); - for (const auto& jdf : cdf.jobs()) { - const string remote_prefix = "/job:" + jdf.name() + "/task:"; - for (const auto& tdf : jdf.tasks()) { - const int32_t task_index = tdf.first; - std::vector device_filters(tdf.second.device_filters_size()); - for (int i = 0; i < tdf.second.device_filters_size(); i++) { - device_filters[i] = tdf.second.device_filters(i); - } - const string remote_worker = remote_prefix + std::to_string(task_index); - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = - context->SetRemoteDeviceFilters(remote_worker, device_filters); - } - } - } - status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, - ctx, /*reset_context=*/true); + status->status = + tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef( + server_def, /*reset_context=*/true, keep_alive_secs); #endif // !IS_MOBILE_PLATFORM } @@ -835,14 +217,9 @@ TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, status->status = tensorflow::errors::InvalidArgument( "Trying to update a context with invalid context id."); } - if (server_def.has_cluster_device_filters()) { - LOG(WARNING) << "Device filters can only be specified when initializing " - "the cluster. 
Any changes in device filters are ignored " - "when updating the server def."; - } - // TODO(haoyuzhang): Check server_def compatibility before the update - status->status = UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, - ctx, /*reset_context=*/false); + status->status = + tensorflow::unwrap(ctx)->GetDistributedManager()->SetOrUpdateServerDef( + server_def, /*reset_context=*/false, keep_alive_secs); #endif // !IS_MOBILE_PLATFORM } @@ -854,44 +231,11 @@ TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, "TFE_ContextSetServerDef not supported on mobile"); return false; #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - tensorflow::GrpcServer* grpc_server = - dynamic_cast(context->GetServer()); - if (grpc_server == nullptr) { - status->status = - tensorflow::errors::Internal("Failed to get tensorflow::GrpcServer."); - return false; - } - tensorflow::WorkerInterface* wi = - grpc_server->master_env()->worker_cache->GetOrCreateWorker(worker_name); - if (wi == nullptr) { - status->status = tensorflow::errors::InvalidArgument( - "Unable to find worker interface corresponding to task ", worker_name); - return false; - } - - tensorflow::GetStatusRequest request; - tensorflow::GetStatusResponse response; - tensorflow::Status remote_status; - tensorflow::Notification done; - wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true, - [&remote_status, &done](const tensorflow::Status& s) { - remote_status = s; - done.Notify(); - }); - done.WaitForNotification(); - - // We set OK status so the call does not raise any exceptions. Instead, caller - // users the return value to tell if the remote worker is alive. 
- status->status = tensorflow::Status::OK(); - - if (remote_status.ok()) { - return true; - } - LOG(INFO) << "Remote worker " << worker_name - << " is not alive: " << remote_status.error_message(); - return false; + bool is_alive; + status->status = + tensorflow::unwrap(ctx)->GetDistributedManager()->CheckRemoteAlive( + worker_name, &is_alive); + return is_alive; #endif // !IS_MOBILE_PLATFORM } @@ -1022,13 +366,21 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); - if (VariantDeviceIsCustom(handle->device())) { - const tensorflow::Tensor* t; - status->status = handle->Tensor(&t); - return t->data(); + tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = + tensorflow::unwrap(h); + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. + if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) { + return tensorflow::down_cast( + unwrapped_handle) + ->DevicePointer(); } + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. 
+ if (!tensorflow::TensorHandle::classof(unwrapped_handle)) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return nullptr; + } + tensorflow::TensorHandle* handle = + tensorflow::TensorHandleFromInterface(unwrapped_handle); if (handle->Type() != tensorflow::TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( @@ -1036,7 +388,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { handle->TypeString(), " tensor handle."); return nullptr; } - tensorflow::Device* device(absl::get(handle->device())); + tensorflow::Device* device(handle->device()); if (device != nullptr) { status->status = device->Sync(); if (!status->status.ok()) { @@ -1052,6 +404,153 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { static_cast(tensor->tensor_data().data())); } +namespace tensorflow { +namespace { +class CustomDeviceAPI : public tensorflow::CustomDevice { + public: + CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info, + string name) + : context_(context), device_(device), info_(info), name_(name) {} + + ~CustomDeviceAPI() override { device_.delete_device(info_); } + + const string& name() override { return name_; } + + tensorflow::Status CopyTensorToDevice( + ImmediateExecutionTensorHandle* handle, + ImmediateExecutionTensorHandle** result) override { + handle->Ref(); + TF_Status status; + TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( + context_, tensorflow::wrap(handle), &status, info_); + handle->Release(); + if (!status.status.ok()) return status.status; + *result = tensorflow::unwrap(result_handle); + (*result)->Ref(); + TFE_DeleteTensorHandle(result_handle); + return status.status; + } + + tensorflow::Status CopyTensorFromDevice( + ImmediateExecutionTensorHandle* handle, + const tensorflow::string& target_device_name, + ImmediateExecutionTensorHandle** result) override { + TF_Status status; + handle->Ref(); + TFE_TensorHandle* 
result_handle = device_.copy_tensor_from_device( + context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, + info_); + handle->Release(); + if (!status.status.ok()) return status.status; + *result = tensorflow::unwrap(result_handle); + (*result)->Ref(); + TFE_DeleteTensorHandle(result_handle); + return status.status; + } + + tensorflow::Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) override { + std::vector outputs(*num_retvals); + TF_Status status; + device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, + info_); + if (status.status.ok()) { + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = tensorflow::unwrap(outputs[i]); + retvals[i]->Ref(); + TFE_DeleteTensorHandle(outputs[i]); + } + } + return status.status; + } + + tensorflow::Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) override { + TF_Status status; + *result = tensorflow::unwrap(device_.pack(context_, + tensorflow::wrap(handles.data()), + handles.size(), &status, info_)); + return status.status; + } + + private: + TFE_Context* context_; + TFE_CustomDevice device_; + void* info_; + string name_; +}; + +// An adapter which wraps the shape/data produced by C custom devices and uses +// it to implement custom device methods. 
+class CAPICustomDeviceTensorHandle + : public tensorflow::CustomDeviceTensorHandle { + public: + CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context, + tensorflow::CustomDevice* device, + tensorflow::DataType dtype, void* data, + TFE_CustomDeviceTensorHandleMethods methods) + : tensorflow::CustomDeviceTensorHandle(context, device, dtype), + data_(data), + methods_(methods) {} + + ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); } + void* DevicePointer() const override { return data_; } + Status NumDims(int* num_dims) const override { + TF_Status s; + *num_dims = methods_.num_dims(data_, &s); + return s.status; + } + Status Dim(int dim_index, int64* dim) const override { + TF_Status s; + *dim = methods_.dim(data_, dim_index, &s); + return s.status; + } + + bool HasCustomSummarizer() const override { + return methods_.summarize != nullptr; + } + + Status SummarizeValue(std::string& summary) const override { + if (methods_.summarize == nullptr) { + return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary); + } + TF_Status c_status; + std::unique_ptr summary_buffer( + methods_.summarize(data_, &c_status), TF_DeleteBuffer); + if (!c_status.status.ok()) { + return c_status.status; + } + summary = std::string(reinterpret_cast(summary_buffer->data), + summary_buffer->length); + return Status::OK(); + } + + private: + void* const data_; + const TFE_CustomDeviceTensorHandleMethods methods_; +}; + +} // namespace +} // namespace tensorflow + +TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( + TFE_Context* ctx, const char* device_name, TF_DataType dtype, void* data, + TFE_CustomDeviceTensorHandleMethods methods, TF_Status* status) { + tensorflow::ImmediateExecutionContext* context = tensorflow::unwrap(ctx); + tensorflow::CustomDevice* device = nullptr; + if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name, + &device)) { + methods.deallocator(data); + status->status = + 
tensorflow::errors::InvalidArgument(device_name, " unknown device."); + return nullptr; + } + return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle( + context, device, *reinterpret_cast(&dtype), data, + methods)); +} + TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( TFE_Context* ctx, const char* device_name, TF_DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, @@ -1061,16 +560,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = context->FindDeviceFromName(device_name, &device); - tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { - if (!context->FindCustomDeviceFromName(device_name, &custom_device)) { - deallocator(data, len, deallocator_arg); - status->status = - tensorflow::errors::InvalidArgument(device_name, " unknown device."); - return nullptr; - } else { - status->status = tensorflow::Status::OK(); - } + deallocator(data, len, deallocator_arg); + status->status = + tensorflow::errors::InvalidArgument(device_name, " unknown device."); + return nullptr; } std::vector dimvec(num_dims); for (int i = 0; i < num_dims; ++i) { @@ -1086,13 +580,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::Tensor t(static_cast(dtype), tensorflow::TensorShape(dimvec), buf); buf->Unref(); - if (custom_device == nullptr) { - return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), device, device, context)); - } else { - return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), custom_device, context)); - } + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), device, device, context)); } // This function will block till the operation that produces `h` has @@ -1145,8 +634,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { } TFE_Context* TFE_OpGetContext(const 
TFE_Op* op, TF_Status* status) { - return tensorflow::wrap( - &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); + return tensorflow::wrap(tensorflow::unwrap(op)->GetContext()); } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { @@ -1380,11 +868,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - status->status = tensorflow::unwrap(op)->Execute( - absl::MakeSpan(reinterpret_cast( - tensorflow::unwrap(retvals)), - *num_retvals), - num_retvals); + tensorflow::ImmediateExecutionOperation* unwrapped_op = + tensorflow::unwrap(op); + + status->status = + unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute( + unwrapped_op, + reinterpret_cast( + retvals), + num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, @@ -1396,8 +888,13 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, return nullptr; } - auto* result = tensorflow::unwrap(ctx)->CopyTensorHandleToDevice( - tensorflow::unwrap(h), device_name, &status->status); + tensorflow::ImmediateExecutionContext* unwrapped_ctx = + tensorflow::unwrap(ctx); + + auto* result = + unwrapped_ctx->GetCustomDeviceOpHandler().CopyTensorHandleToDevice( + unwrapped_ctx, tensorflow::unwrap(h), device_name, &status->status); + if (status->status.ok()) { return tensorflow::wrap(result); } @@ -1413,12 +910,16 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } + + AnnotateEagerRuntimeConstructionContext(function_def); status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef); + AnnotateEagerRuntimeConstructionContext(function->fdef); + status->status = 
tensorflow::unwrap(ctx)->AddFunctionDefWithStackTraces( + function->fdef, function->stack_traces); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, @@ -1447,13 +948,11 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->Executor().WaitForAllPendingNodes(); + auto* context = tensorflow::unwrap(ctx); + status->status = context->AsyncWait(); if (!status->status.ok()) return; - tensorflow::mutex_lock ml(*context->MetadataMu()); - status->status = MessageToBuffer(*context->RunMetadataProto(), buf); - context->ClearRunMetadata(); + auto run_metadata = context->ExportRunMetadata(); + status->status = MessageToBuffer(*run_metadata, buf); } namespace { @@ -1478,22 +977,17 @@ void TFE_ContextEndStep(TFE_Context* ctx) { } const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op) { - return tensorflow::wrap( - &OperationFromInterface(tensorflow::unwrap(op))->Attrs()); + return tensorflow::wrap(tensorflow::unwrap(op)->GetOpAttrs()); } void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) { - tensorflow::EagerOperation* operation = - OperationFromInterface(tensorflow::unwrap(op)); - tensorflow::AttrBuilder* destination = operation->MutableAttrs(); - destination->CopyAttributes(*tensorflow::unwrap(attrs)); + tensorflow::unwrap(op)->AddAttrs(tensorflow::unwrap(attrs)); } void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf, TF_Status* status) { tensorflow::NameAttrList name_and_attrs; - tensorflow::unwrap(attrs)->FillAttrValueMap(name_and_attrs.mutable_attr()); - name_and_attrs.set_name(tensorflow::unwrap(attrs)->op_name()); + tensorflow::unwrap(attrs)->GetNameAttrList(&name_and_attrs); status->status = MessageToBuffer(name_and_attrs, buf); } @@ -1618,74 +1112,14 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* 
op, } // namespace tensorflow namespace { -class CustomDeviceAPI : public tensorflow::CustomDevice { - public: - CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info, - string name) - : context_(context), device_(device), info_(info), name_(name) {} - - ~CustomDeviceAPI() override { device_.delete_device(info_); } - - const string& name() override { return name_; } - - tensorflow::Status CopyTensorToDevice( - tensorflow::TensorHandle* handle, - tensorflow::TensorHandle** result) override { - handle->Ref(); - TF_Status status; - TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( - context_, tensorflow::wrap(handle), &status, info_); - handle->Release(); - if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface( - tensorflow::unwrap(result_handle)); - (*result)->Ref(); - TFE_DeleteTensorHandle(result_handle); - return status.status; - } - - tensorflow::Status CopyTensorFromDevice( - tensorflow::TensorHandle* handle, - const tensorflow::string& target_device_name, - tensorflow::TensorHandle** result) override { - TF_Status status; - handle->Ref(); - TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( - context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, - info_); - handle->Release(); - if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface( - tensorflow::unwrap(result_handle)); - (*result)->Ref(); - TFE_DeleteTensorHandle(result_handle); - return status.status; - } - - tensorflow::Status Execute(const tensorflow::EagerOperation* op, - tensorflow::TensorHandle** retvals, - int* num_retvals) override { - std::vector outputs(*num_retvals); - TF_Status status; - device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, - info_); - if (status.status.ok()) { - for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = tensorflow::TensorHandleFromInterface( - tensorflow::unwrap(outputs[i])); - 
retvals[i]->Ref(); - TFE_DeleteTensorHandle(outputs[i]); - } - } - return status.status; - } - - private: - TFE_Context* context_; - TFE_CustomDevice device_; - void* info_; - string name_; -}; +TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context, + TFE_TensorHandle** handles, + int num_handles, TF_Status* status, + void* device_info) { + TF_SetStatus(status, TF_UNIMPLEMENTED, + "This custom device does not support packing tensors."); + return nullptr; +} } // namespace extern "C" { @@ -1693,12 +1127,14 @@ extern "C" { void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, const char* device_name, void* device_info, TF_Status* status) { - auto custom_device = - std::make_unique(ctx, device, device_info, device_name); - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = - context->RegisterCustomDevice(device_name, std::move(custom_device)); + // Fill in default values for optional functionality. 
+ if (device.pack == nullptr) { + device.pack = &DefaultCustomDevicePack; + } + auto custom_device = std::make_unique( + ctx, device, device_info, device_name); + status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice( + device_name, std::move(custom_device)); } } // extern "C" diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 1ef536a66f6c51..a2ec468d44b2d5 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -482,41 +482,34 @@ TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( static_cast(sampler->sampler->GetCell(label1, label2))); } -void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options, - bool lazy_copy) { - options->lazy_remote_inputs_copy = lazy_copy; -} - void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) { options->use_tfrt = use_tfrt; } TFE_CancellationManager* TFE_NewCancellationManager() { - return new TFE_CancellationManager; + return tensorflow::wrap(new tensorflow::CancellationManager); } void TFE_CancellationManagerStartCancel( TFE_CancellationManager* cancellation_manager) { - cancellation_manager->cancellation_manager.StartCancel(); + tensorflow::unwrap(cancellation_manager)->StartCancel(); } bool TFE_CancellationManagerIsCancelled( TFE_CancellationManager* cancellation_manager) { - return cancellation_manager->cancellation_manager.IsCancelled(); + return tensorflow::unwrap(cancellation_manager)->IsCancelled(); } void TFE_DeleteCancellationManager( TFE_CancellationManager* cancellation_manager) { - delete cancellation_manager; + delete tensorflow::unwrap(cancellation_manager); } void TFE_OpSetCancellationManager(TFE_Op* op, TFE_CancellationManager* cancellation_manager, TF_Status* status) { - tensorflow::EagerOperation* operation = - tensorflow::OperationFromInterface(tensorflow::unwrap(op)); - operation->SetCancellationManager( - &cancellation_manager->cancellation_manager); + 
tensorflow::unwrap(op)->SetCancellationManager( + tensorflow::unwrap(cancellation_manager)); status->status = tensorflow::Status::OK(); } @@ -618,8 +611,23 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, std::vector tensor_handles; tensor_handles.reserve(*num_handles); for (int i = 0; i < *num_handles; ++i) { + tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = + tensorflow::unwrap(handles[i]); + if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) { + // One of the inputs we're trying to pack is on a custom device. We'll let + // the first custom device we see handle all of the packing. + auto* custom_device_handle = + tensorflow::down_cast( + unwrapped_handle); + tensorflow::ImmediateExecutionTensorHandle* result; + status->status = custom_device_handle->device()->Pack( + absl::Span( + tensorflow::unwrap(handles), *num_handles), + &result); + return tensorflow::wrap(result); + } tensor_handles.push_back( - tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i]))); + tensorflow::TensorHandleFromInterface(unwrapped_handle)); } tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); @@ -654,3 +662,23 @@ int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) { } return tensorflow::unwrap(h)->DeviceId(&status->status); } + +void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + const std::vector& op_names = + tensorflow::unwrap(ctx)->GetLoggedOpsTestonly(); + + std::ostringstream op_names_oss; + for (const auto& op : op_names) { + op_names_oss << op << ", "; + } + const std::string& op_names_str = op_names_oss.str(); + void* data = tensorflow::port::Malloc(op_names_str.length()); + op_names_str.copy(static_cast(data), op_names_str.length(), 0); + buf->data = data; + buf->length = op_names_str.length(); + buf->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + status->status = 
tensorflow::Status::OK(); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index d0739a5437df0f..8c97904c44dc23 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -265,10 +265,6 @@ TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); -// Sets whether to copy the remote inputs of a function lazily. -TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( - TFE_ContextOptions*, bool lazy_copy); - // Sets whether to use TFRT TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, bool use_tfrt); @@ -388,9 +384,11 @@ TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*, TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*, TF_Status*); -// Creates a new TensorHandle from memory residing in device_name. Takes -// ownership of the memory, and will call deleter to release it after TF -// no longer needs it or in case of error. +// Creates a new TensorHandle from memory residing in the physical device +// device_name. Takes ownership of the memory, and will call deleter to release +// it after TF no longer needs it or in case of error. +// +// Custom devices must use TFE_NewCustomDeviceTensorHandle instead. TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims, int num_dims, void* data, size_t len, @@ -439,16 +437,16 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, // to have a non-string representation of devices (TF_Device) extracted from // tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. 
-#define TFE_CUSTOM_DEVICE_VERSION 3 +#define TFE_CUSTOM_DEVICE_VERSION 4 -// Struct to be filled in +// Struct to be filled in. Functions are required except where indicated. typedef struct TFE_CustomDevice { int version = TFE_CUSTOM_DEVICE_VERSION; // Method to copy a tensor to the custom device. TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status, - void* device_info) = nullptr; + void* device_info); // Method to copy a tensor from the custom device to a target device. TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context, @@ -472,6 +470,16 @@ typedef struct TFE_CustomDevice { // Method to delete a device. void (*delete_device)(void* device_info); + + // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this + // custom device. + // + // Many devices will want to simply return an "unimplemented" status + // here. This is the default behavior if `pack` is null when passed to + // TFE_RegisterCustomDevice. + TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles, + int num_handles, TF_Status* s, + void* device_info) = nullptr; } TFE_CustomDevice; // Registers a custom device for use with eager execution. @@ -481,7 +489,7 @@ typedef struct TFE_CustomDevice { // "/job:localhost/replica:0/task:0/device:CUSTOM:0". // // The custom device defines copy operations for moving TensorHandles on and -// off, and an an execution operation for named operations. Often execution will +// off, and an execution operation for named operations. Often execution will // simply wrap op execution on one or more physical devices. // // device_info is an opaque caller-defined type stored with the custom device @@ -511,6 +519,48 @@ TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx, void* device_info, TF_Status* status); +// Struct to be filled in to define a custom device tensor handle. Fields are +// required except where indicated. 
+typedef struct TFE_CustomDeviceTensorHandleMethods { + int version = TFE_CUSTOM_DEVICE_VERSION; + + // Computes the rank of the tensor handle. + // + // Shapes are specified via callbacks because retrieving the shape of a tensor + // is a blocking operation for async eager; custom devices should avoid + // retrieving shapes of tensors they wrap until the custom device tensor's + // shape is explicitly requested where possible. + int (*num_dims)(void* data, TF_Status* status); + + // Computes the axis length at `dim_index`. + int64_t (*dim)(void* data, int dim_index, TF_Status* status); + + void (*deallocator)(void* data); + + // Summarizes the value of this tensor. The caller takes ownership of the + // returned buffer. If `status` is not TF_OK, instead returns a null pointer. + // + // Does not include the shape and dtype of the tensor (which is generally + // appended later), but should include any information specific to this custom + // device which would be useful for debugging. + // + // Optional. If null, defaults to resolving the TFE_TensorHandle into a + // TF_Tensor and summarizing that. + TF_Buffer* (*summarize)(void* data, TF_Status* status) = nullptr; +} TFE_CustomDeviceTensorHandle; + +// Creates a new TensorHandle from memory residing in a custom device. Takes +// ownership of the memory pointed to by `tensor_handle_data`, and calls +// `methods.deallocator` to release it after TF no longer needs it or in case of +// an error. +// +// This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but supports +// custom devices instead of physical devices and does not require blocking +// waiting for exact shapes. 
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( + TFE_Context*, const char* device_name, TF_DataType, void* data, + TFE_CustomDeviceTensorHandle methods, TF_Status* status); + TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, TF_Buffer* buf, @@ -561,6 +611,13 @@ TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status); +// Get a comma-separated list of op names executed in graph functions dispatched +// to `ctx`. This feature is currently only enabled for TFRT debug builds, for +// performance and simplicity reasons. +TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 4fe83b5116da77..c1949ae826fa03 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/cc/profiler/profiler.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 356476c218620c..450e1a66062f01 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -32,8 +32,6 @@ struct TFE_ContextOptions { bool async = false; TFE_ContextDevicePlacementPolicy device_placement_policy{ TFE_DEVICE_PLACEMENT_SILENT}; - // If true, lazily copy the remote inputs of a function to the target devices. 
- bool lazy_remote_inputs_copy = true; // If true, use TFRT backend bool use_tfrt = false; }; diff --git a/tensorflow/c/eager/c_api_remote_function_test.cc b/tensorflow/c/eager/c_api_remote_function_test.cc index a9bbd5b694f2fa..45e8302c248775 100644 --- a/tensorflow/c/eager/c_api_remote_function_test.cc +++ b/tensorflow/c/eager/c_api_remote_function_test.cc @@ -20,10 +20,11 @@ namespace { void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote, bool heavy_load_on_streaming_rpc, - bool remote_func_outputs = false) { + bool remote_func_outputs = false, + bool has_packed_input = false) { return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true, heavy_load_on_streaming_rpc, - remote_func_outputs); + remote_func_outputs, has_packed_input); } TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) { @@ -60,5 +61,14 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false, /*heavy_load_on_streaming_rpc=*/true); } +TEST(CAPI, RemoteExecuteSilentCopiesRemoteAsyncPackedInputFuncOrdering) { + // A remote input (packed) may be not ready when we start running a function. + // Test that the function execution should wait until the remote input is + // ready. 
+  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
+                                    /*heavy_load_on_streaming_rpc=*/true,
+                                    /*remote_func_outputs=*/true,
+                                    /*has_packed_input=*/true);
+}
 }  // namespace
diff --git a/tensorflow/c/eager/c_api_remote_test_util.cc b/tensorflow/c/eager/c_api_remote_test_util.cc
index 159fa442a73cff..beb1baf3fe63fd 100644
--- a/tensorflow/c/eager/c_api_remote_test_util.cc
+++ b/tensorflow/c/eager/c_api_remote_test_util.cc
@@ -68,7 +68,9 @@ string MatMulFunction(const string& matmul_device) {
 
 void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
                                    bool heavy_load_on_streaming_rpc,
-                                   bool remote_func_outputs) {
+                                   bool remote_func_outputs,
+                                   bool has_packed_input) {
+  CHECK(!has_packed_input || func);
   tensorflow::ServerDef server_def = GetServerDef(3);
 
   // This server def has the task index set to 0.
@@ -123,6 +125,15 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
       TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
 
+  TFE_TensorHandle* packed_handle = nullptr;
+  if (has_packed_input) {
+    int num_replicas = 1;
+    std::vector<TFE_TensorHandle*> packed_handles = {h1_task2};
+    packed_handle = TFE_CreatePackedTensorHandle(ctx, packed_handles.data(),
+                                                 &num_replicas, status);
+    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+  }
+
   TFE_Op* matmul = nullptr;
   if (func) {
     const string matmul_device = remote_func_outputs ? task2_name : "";
@@ -135,7 +146,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
     TFE_OpAddInput(matmul, h0_task0, status);
     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
-    TFE_OpAddInput(matmul, h1_task2, status);
+    TFE_OpAddInput(matmul, has_packed_input ? packed_handle : h1_task2, status);
     ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
   } else {
     // Handles are on task0 (local), and task2, but op is on task1.
@@ -194,6 +205,9 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, TFE_DeleteTensorHandle(h0_task0); TFE_DeleteTensorHandle(h1_task0); + if (packed_handle) { + TFE_DeleteTensorHandle(packed_handle); + } TFE_DeleteTensorHandle(h1_task2); TFE_DeleteTensorHandle(retvals[0]); for (auto* h : handles_task0) { diff --git a/tensorflow/c/eager/c_api_remote_test_util.h b/tensorflow/c/eager/c_api_remote_test_util.h index 08633689402d48..6d9edb65feaba7 100644 --- a/tensorflow/c/eager/c_api_remote_test_util.h +++ b/tensorflow/c/eager/c_api_remote_test_util.h @@ -16,11 +16,12 @@ limitations under the License. #define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ // Run a function containing a MatMul op and check its output. -// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one -// which creates a remote remote input, to simulate a scenario that the remote -// input is not ready when we start running an op or a function. +// If heavy_load_on_streaming_rpc is true, send some rpc requests before the one +// which creates a remote input, to simulate a scenario that the remote input +// is not ready when we start running an op or a function. void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, bool heavy_load_on_streaming_rpc, - bool remote_func_outputs = false); + bool remote_func_outputs = false, + bool has_packed_input = false); #endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index fd208c6770d0f2..813cfdb613a9e2 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -43,13 +43,13 @@ limitations under the License. 
#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" using tensorflow::string; namespace { -void BM_InitOp(int iters) { - tensorflow::testing::StopTiming(); +void BM_InitOp(::testing::benchmark::State& state) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); @@ -57,12 +57,10 @@ void BM_InitOp(int iters) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_DeleteOp(matmul); } - tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteContext(ctx); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -70,8 +68,8 @@ void BM_InitOp(int iters) { } BENCHMARK(BM_InitOp); -void BM_Execute(int iters, int async) { - tensorflow::testing::StopTiming(); +void BM_Execute(::testing::benchmark::State& state) { + const int async = state.range(0); tensorflow::testing::SetLabel(async ? 
"ExecuteAsync" : "Execute"); TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -85,8 +83,7 @@ void BM_Execute(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retvals[1]; int num_retvals = 1; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TFE_OpReset(matmul, "MatMul", nullptr, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpAddInput(matmul, m, status); @@ -95,14 +92,13 @@ void BM_Execute(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Execute(matmul, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + if (state.iterations() >= state.max_iterations && async) { + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + } } - if (async) { - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - } - tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); TFE_DeleteContext(ctx); @@ -111,8 +107,8 @@ void BM_Execute(int iters, int async) { } BENCHMARK(BM_Execute)->Arg(0)->Arg(1); -void BM_Execute_Identity(int iters, int async) { - tensorflow::testing::StopTiming(); +void BM_Execute_Identity(::testing::benchmark::State& state) { + const int async = state.range(0); tensorflow::testing::SetLabel(async ? 
"ExecuteIdentityAsync" : "ExecuteIdentity"); TF_Status* status = TF_NewStatus(); @@ -126,22 +122,20 @@ void BM_Execute_Identity(int iters, int async) { TFE_Op* identity = TFE_NewOp(ctx, "Identity", status); TFE_TensorHandle* retvals[1]; int num_retvals = 1; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { TFE_OpReset(identity, "Identity", nullptr, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_OpAddInput(identity, m, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Execute(identity, &retvals[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + if (state.iterations() >= state.max_iterations && async) { + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + } } - if (async) { - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - } - tensorflow::testing::StopTiming(); TFE_DeleteOp(identity); TFE_DeleteTensorHandle(m); TFE_DeleteContext(ctx); @@ -423,7 +417,7 @@ void TensorHandleSilentCopy(bool async, tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu)); auto gpu_arg = tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu)); - auto gpu_device = absl::get(gpu_arg->device()); + auto gpu_device = gpu_arg->device(); ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device)); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); @@ -650,10 +644,19 @@ void ExecuteAdd(bool async, bool forward_input, bool tfrt) { TFE_DeleteOp(add_op); TF_Tensor* t = TFE_TensorHandleResolve(retval, status); - if (forward_input || async) { - EXPECT_EQ(orig_ptr, TF_TensorData(t)); + if (async) { + if (forward_input) { + EXPECT_EQ(orig_ptr, 
TF_TensorData(t)); + } else { + // TODO(b/156981931): Flaky test. Very occasionally the following is false + // EXPECT_EQ(orig_ptr, TF_TensorData(t)); + } } else { - EXPECT_NE(orig_ptr, TF_TensorData(t)); + if (forward_input) { + EXPECT_EQ(orig_ptr, TF_TensorData(t)); + } else { + EXPECT_NE(orig_ptr, TF_TensorData(t)); + } } ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -696,8 +699,9 @@ TEST(CAPI, ExecuteAddForwardAsync) { /*tfrt*/ false); } #ifdef PLATFORM_GOOGLE -// TODO(b/153349425): Add add forwarding tests for TFRT -TEST(CAPI, ExecuteAddTfrt) { +// TODO(b/153349425): Add forwarding tests for TFRT +// TODO(b/178003466): Fix and re-enable. +TEST(CAPI, DISABLED_ExecuteAddTfrt) { ExecuteAdd( /*async=*/false, /*forward_input*/ false, @@ -769,7 +773,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); EXPECT_NE(TF_OK, TF_GetCode(status)); EXPECT_EQ(nullptr, t); - const char* msg = "Matrix size-incompatible: In[0]: [2,2], In[1]: [3,2]"; + const char* msg = "In[0] mismatch In[1] shape: 2 vs. 
3: [2,2] [3,2]"; EXPECT_TRUE(strstr(TF_Message(status), msg) != nullptr) << TF_Message(status); // Since error is not cleared, the following copy with correct device will @@ -955,6 +959,41 @@ string MatMulFunction() { return def.SerializeAsString(); } +// a + a +string AddFunction() { + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + " signature {" + " name: 'AddFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'o'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'output'" + " op: 'Add'" + " input: 'a'" + " input: 'a'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'o'" + " value: 'output:z'" + " }", + &def)); + return def.SerializeAsString(); +} + void FunctionDefAndExecute(bool async) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -1005,8 +1044,108 @@ void FunctionDefAndExecute(bool async) { TEST(CAPI, FunctionDefAndExecute) { FunctionDefAndExecute(false); } TEST(CAPI, FunctionDefAndExecuteAsync) { FunctionDefAndExecute(true); } -void BM_ExecuteFunction(int iters, int async) { - tensorflow::testing::StopTiming(); +void RunAddFunction(bool use_tfrt, bool enable_grappler) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + string function_def = AddFunction(); + TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), + status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); + TFE_TensorHandle* retval[1] = {nullptr}; + int num_retvals = 1; + TFE_Op* op = TFE_NewOp(ctx, "AddFunction", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); 
+
+  // Add a config_proto attr, to trigger grappler graph rewrites in the current
+  // eager runtime.
+  if (enable_grappler) {
+    tensorflow::ConfigProto config;
+    // Do not skip grappler optimization even for small graphs.
+    config.mutable_graph_options()
+        ->mutable_rewrite_options()
+        ->set_min_graph_nodes(-1);
+    string serialized_config;
+    ASSERT_TRUE(config.SerializeToString(&serialized_config));
+    TFE_OpSetAttrString(
+        op, "config_proto",
+        reinterpret_cast<const void*>(serialized_config.c_str()),
+        serialized_config.length());
+  }
+
+  if (use_tfrt) {
+    // Set some test-only graph compiler options.
+    TFE_OpSetAttrBool(op, "TFRT_TEST_enable_native_ops", false);
+    TFE_OpSetAttrBool(op, "TFRT_TEST_enable_grappler", enable_grappler);
+  }
+
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  TFE_OpAddInput(op, m, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_Execute(op, &retval[0], &num_retvals, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  ASSERT_EQ(1, num_retvals);
+  TFE_DeleteOp(op);
+  TFE_DeleteTensorHandle(m);
+  TF_Tensor* t = TFE_TensorHandleResolve(retval[0], status);
+  TFE_DeleteTensorHandle(retval[0]);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  float product[4] = {0};
+  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+  TF_DeleteTensor(t);
+  EXPECT_EQ(2, product[0]);
+  EXPECT_EQ(4, product[1]);
+  EXPECT_EQ(6, product[2]);
+  EXPECT_EQ(8, product[3]);
+
+  // When we turn on grappler, confirm that the tf.Add has been rewritten into a
+  // tf.Mul.
+  // This capability of checking the executed op names is currently only enabled
+  // for TFRT debug build, for performance and simplicity reasons.
+  if (use_tfrt) {
+    TF_Buffer* buf = TF_NewBuffer();
+    TFE_GetExecutedOpNames(ctx, buf, status);
+    ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+#ifndef NDEBUG
+    if (enable_grappler)
+      EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Mul"), nullptr);
+    else
+      EXPECT_NE(strstr(static_cast<const char*>(buf->data), "tf.Add"), nullptr);
+#endif
+    TF_DeleteBuffer(buf);
+  }
+
+  TFE_ContextRemoveFunction(ctx, "AddFunction", status);
+  ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+  TFE_DeleteContext(ctx);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+
+TEST(CAPI, RunAddFunctionWithGrappler) {
+  RunAddFunction(/*use_tfrt=*/false, /*enable_grappler=*/true);
+}
+
+#ifdef PLATFORM_GOOGLE
+TEST(CAPI, RunAddFunction_TFRT) {
+  RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/false);
+}
+
+TEST(CAPI, RunAddFunctionWithGrappler_TFRT) {
+  RunAddFunction(/*use_tfrt=*/true, /*enable_grappler=*/true);
+}
+#endif
+
+void BM_ExecuteFunction(::testing::benchmark::State& state) {
+  const int async = state.range(0);
   tensorflow::testing::SetLabel(async ?
"ExecuteFunctionAsync" : "ExecuteFunction"); TF_Status* status = TF_NewStatus(); @@ -1022,24 +1161,23 @@ void BM_ExecuteFunction(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* m = TestMatrixTensorHandle(ctx); - TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(matmul, m, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_TensorHandle* retval[1] = {nullptr}; int num_retvals = 1; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { + TFE_Op* matmul = TFE_NewOp(ctx, "MatMulFunction", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(matmul, m, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Execute(matmul, &retval[0], &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + if (state.iterations() >= state.max_iterations && async) { + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + } } - if (async) { - TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); - TFE_ExecutorWaitForAllPendingNodes(executor, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteExecutor(executor); - } - tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); TFE_DeleteTensorHandle(retval[0]); TFE_ContextRemoveFunction(ctx, "MatMulFunction", status); @@ -1092,8 +1230,7 @@ TEST(CAPI, Variables) { TF_DeleteStatus(status); } -void BM_ReadVariable(int iters) { - tensorflow::testing::StopTiming(); +void BM_ReadVariable(::testing::benchmark::State& state) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = 
TFE_NewContext(opts, status); @@ -1103,16 +1240,14 @@ void BM_ReadVariable(int iters) { TFE_TensorHandle* var_handle = TestVariable(ctx, 5.0); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpSetAttrType(op, "dtype", TF_FLOAT); - TFE_OpAddInput(op, var_handle, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - int num_retvals = 1; TFE_TensorHandle* h = nullptr; - tensorflow::testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { + TFE_Op* op = TFE_NewOp(ctx, "ReadVariableOp", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrType(op, "dtype", TF_FLOAT); + TFE_OpAddInput(op, var_handle, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_Execute(op, &h, &num_retvals, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); CHECK_EQ(1, num_retvals); @@ -1121,11 +1256,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; - TFE_OpAddInput(op, var_handle, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(op); } - tensorflow::testing::StopTiming(); - TFE_DeleteOp(op); TFE_DeleteTensorHandle(var_handle); TFE_DeleteContext(ctx); @@ -1134,7 +1266,8 @@ void BM_ReadVariable(int iters) { } BENCHMARK(BM_ReadVariable); -TEST(CAPI, StringAttributes) { +// TODO(b/178003466): Fix and re-enable. +TEST(CAPI, DISABLED_StringAttributes) { // Test that TFE_OpSetAttrString doesn't hold on to the value after it // returns. 
   TF_Status* status = TF_NewStatus();
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index ad0c7c6340f65b..3f6fdeb4e9298e 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -16,6 +16,9 @@ limitations under the License.
 #define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
 
 #include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/tstring.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -53,6 +56,27 @@ TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
 TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
                                               int64_t dims[], int num_dims);
 
+// Return a tensor handle with given type, values and dimensions.
+template <typename T, TF_DataType datatype>
+TFE_TensorHandle* TestTensorHandleWithDims(TFE_Context* ctx, const T* data,
+                                           const int64_t* dims, int num_dims) {
+  TF_Status* status = TF_NewStatus();
+  TF_Tensor* t = TFE_AllocateHostTensor(ctx, datatype, dims, num_dims, status);
+  memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
+  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
+// Return a scalar tensor handle with given values.
+template <typename T, TF_DataType datatype>
+TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, const T value) {
+  T data[] = {value};
+  return TestTensorHandleWithDims<T, datatype>(ctx, data, nullptr, 0);
+}
+
 // Return a tensor handle containing a 100x100 matrix of floats
 TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
 
diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc
index 2d290df19cec7d..f89d3e84cf42ec 100644
--- a/tensorflow/c/eager/c_api_unified_experimental.cc
+++ b/tensorflow/c/eager/c_api_unified_experimental.cc
@@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
 }
 
 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
-                                           TF_DataType dtype, TF_Status* s) {
+                                           TF_DataType dtype, TF_Shape shape,
+                                           TF_Status* s) {
+  DCHECK_GE(shape.num_dims, -1);
   TracingTensorHandle* t;
   TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
   if (!tracing_ctx) {
@@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
         "TF_AddFunctionParameter must be called on a TracingContext."));
     return nullptr;
   }
+  tensorflow::PartialTensorShape partial_shape;
+  if (shape.num_dims != -1) {
+    DCHECK(shape.dim_sizes != nullptr);
+    Status status = tensorflow::PartialTensorShape::MakePartialShape(
+        reinterpret_cast<int64_t*>(shape.dim_sizes), shape.num_dims,
+        &partial_shape);
+    if (!status.ok()) {
+      Set_TF_Status_from_Status(s, status);
+      return nullptr;
+    }
+  }
   Set_TF_Status_from_Status(
-      s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
+      s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
+                                   &t));
   return wrap(t);
 }
 
diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h
index d216b4e694b4e5..ee22695632fd12 100644
--- a/tensorflow/c/eager/c_api_unified_experimental.h
+++ b/tensorflow/c/eager/c_api_unified_experimental.h
@@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s); void TF_DeleteExecutionContext(TF_ExecutionContext*); +// Represents a (partially-defined) shape. +typedef struct TF_Shape { + int num_dims; // Must be >= -1; -1 represents unknown rank. + int64_t* dim_sizes; +} TF_Shape; + // Add a new parameter to a TensorFlow Function. -// TODO(aminim): what about shape? TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, - TF_DataType dtype, TF_Status* s); + TF_DataType dtype, TF_Shape shape, + TF_Status* s); // Create an operation suitable to use with the provided context. The operation // requires its type (e.g. "AddV2") to be set independently. diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 0e9d6c18157f17..b229abb0cb6e42 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" @@ -43,22 +45,50 @@ class GraphContext; class GraphOperation; class GraphTensor; +auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim; +auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank; + // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index // into the list of outputs for the operation. 
 class GraphTensor : public TracingTensorHandle {
  public:
-  explicit GraphTensor(TF_Output output)
-      : TracingTensorHandle(kGraph), output_(output) {}
+  explicit GraphTensor(TF_Output output, TF_Graph* graph)
+      : TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
 
   tensorflow::DataType DataType() const override {
     return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
   }
+
+  tensorflow::Status Shape(
+      tensorflow::PartialTensorShape* shape) const override {
+    DCHECK(shape != nullptr);
+    TF_Status status;
+    int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
+    DCHECK_GE(num_dims, -1);
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
+    if (num_dims == kUnknownRank) {
+      return Status::OK();
+    }
+
+    std::vector<int64_t> dims(num_dims, kUnknownDim);
+    TF_GraphGetTensorShape(graph_, output_,
+                           reinterpret_cast<int64_t*>(dims.data()), num_dims,
+                           &status);
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
+    TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
+
+    return Status::OK();
+  }
+
   TF_Output output_;
 
   // For LLVM style RTTI.
   static bool classof(const AbstractTensorHandle* ptr) {
     return ptr->getKind() == kGraph;
   }
+
+ private:
+  TF_Graph* graph_;  // For shape inference.
 };
 
 // GraphOperation wraps and populates a TF_OperationDescription.
@@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
     TF_DeleteStatus(s);
     *num_retvals = TF_OperationNumOutputs(operation);
     for (int i = 0; i < *num_retvals; ++i) {
-      retvals[i] = new GraphTensor({operation, i});
+      retvals[i] = new GraphTensor({operation, i}, g_);
     }
     return Status::OK();
   }
@@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
     return new GraphOperation(graph_.get());
   }
 
-  Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
+  Status AddParameter(DataType dtype, const PartialTensorShape& shape,
+                      TracingTensorHandle** output) override {
     TracingOperationPtr operation(CreateOperation());
     TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
     TF_RETURN_IF_ERROR(
         operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
     TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
+    if (!shape.unknown_rank()) {
+      TF_RETURN_IF_ERROR(operation->SetAttrShape(
+          "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
+          shape.dims()));
+    }
     int num_outputs = 1;
     std::vector<TF_Output> outputs(num_outputs);
     TF_RETURN_IF_ERROR(operation->Execute(
diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h
index 9433fe8f120836..cd0d7610c7faa8 100644
--- a/tensorflow/c/eager/c_api_unified_experimental_internal.h
+++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {
  public:
   // Add a function parameter and return the corresponding tensor.
- virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0; + virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape, + TracingTensorHandle**) = 0; // Finalize this context and make a function out of it. The context is in a // invalid state after this call and must be destroyed. diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 432ddb4b2d4984..71dcfc4dcd2fbf 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); auto* placeholder_t = - TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract operation. @@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); auto* placeholder_t = - TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract operation. 
@@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); // Create a first "Add" computing `arg0 + arg1`. @@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) { TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); // Create a first "Add" computing `arg0 + arg1`. @@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) { // Add a placeholder to the graph. 
auto placeholder_t = - TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get()); TF_AbstractTensorGetEagerTensor(placeholder_t, status.get()); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); diff --git a/tensorflow/c/eager/custom_device_testutil.cc b/tensorflow/c/eager/custom_device_testutil.cc index 014abe383688e5..f4221e765cd39b 100644 --- a/tensorflow/c/eager/custom_device_testutil.cc +++ b/tensorflow/c/eager/custom_device_testutil.cc @@ -45,23 +45,31 @@ struct LoggedTensor { ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); } }; -void LoggedTensorDeallocator(void* data, size_t len, void* arg) { +int64_t LoggedTensorDim(void* data, int dim_index, TF_Status* status) { + return TFE_TensorHandleDim(reinterpret_cast(data)->tensor, + dim_index, status); +} + +int LoggedTensorNumDims(void* data, TF_Status* status) { + return TFE_TensorHandleNumDims(reinterpret_cast(data)->tensor, + status); +} + +void LoggedTensorDeallocator(void* data) { delete reinterpret_cast(data); } TFE_TensorHandle* MakeLoggedTensorHandle( TFE_Context* context, const tensorflow::string& logging_device_name, std::unique_ptr t, TF_Status* status) { - std::vector shape(TFE_TensorHandleNumDims(t->tensor, status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - for (int i = 0; i < shape.size(); ++i) { - shape[i] = TFE_TensorHandleDim(t->tensor, i, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } auto dtype = TFE_TensorHandleDataType(t->tensor); - return TFE_NewTensorHandleFromDeviceMemory( - context, logging_device_name.c_str(), dtype, shape.data(), shape.size(), - t.release(), 1, &LoggedTensorDeallocator, nullptr, status); + TFE_CustomDeviceTensorHandleMethods handle_methods; + handle_methods.num_dims = &LoggedTensorNumDims; + handle_methods.dim = &LoggedTensorDim; + handle_methods.deallocator = &LoggedTensorDeallocator; + return TFE_NewCustomDeviceTensorHandle(context, 
logging_device_name.c_str(), + dtype, t.release(), handle_methods, + status); } TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context, @@ -133,6 +141,7 @@ void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs, TFE_DeleteOp(op); if (TF_GetCode(s) != TF_OK) return; std::vector unwrapped_outputs; + unwrapped_outputs.reserve(op_outputs.size()); for (auto* handle : op_outputs) { unwrapped_outputs.push_back(handle); } diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index 640edc7228abb4..687f171bb7a975 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -18,18 +18,8 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/experimental/gradients/math_grad.h" -#include "tensorflow/c/experimental/gradients/nn_grad.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace gradients { @@ -39,22 +29,13 @@ using namespace std; // ================== Helper functions ================= // Fills data with values [start,end) with given step size. -void Range(vector* data, int start, int end, int step = 1) { - for (int i = start; i < end; i += step) { +void Range(vector* data, int32_t start, int32_t end, + int32_t step = 1) { + for (int32_t i = start; i < end; i += step) { (*data)[i] = i; } } -// Returns AbstractTensorHandlePtr containing [0, ..., n-1]. 
-AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) { - vector vals(n); - int64_t vals_shape[] = {n}; - Range(&vals, 0, n); - AbstractTensorHandlePtr r = - GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1); - return r; -} - // Fills out_dims with the dimensions of the given tensor. void GetDims(const TF_Tensor* t, int64_t* out_dims) { int num_dims = TF_NumDims(t); @@ -66,52 +47,59 @@ void GetDims(const TF_Tensor* t, int64_t* out_dims) { // Runs model as is if output is a scalar, // else sums the output tensor before returning. Status RunAndMaybeSum(AbstractContext* ctx, Model forward, - absl::Span inputs, + absl::Span inputs, absl::Span outputs, bool use_function) { - GradientRegistry registry; std::vector model_outputs(1); // Run the model. TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs, - absl::MakeSpan(model_outputs), use_function, - registry)); - AbstractTensorHandle* model_out = model_outputs[0]; + absl::MakeSpan(model_outputs), use_function)); + AbstractTensorHandlePtr model_out(model_outputs[0]); TF_Tensor* model_out_tensor; - TF_RETURN_IF_ERROR(GetValue(model_out, &model_out_tensor)); + TF_RETURN_IF_ERROR(GetValue(model_out.get(), &model_out_tensor)); int num_dims_out = TF_NumDims(model_out_tensor); + TF_DeleteTensor(model_out_tensor); // If the output is a scalar, then return the scalar output if (num_dims_out == 0) { - outputs[0] = model_out; + outputs[0] = model_out.release(); return Status::OK(); } // Else, reduce sum the output to get a scalar // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1]. 
- AbstractTensorHandlePtr sum_dims = - GetRangeTensorHandleUtil(ctx, num_dims_out); + AbstractTensorHandlePtr sum_dims; + { + vector vals(num_dims_out); + int64_t vals_shape[] = {num_dims_out}; + Range(&vals, 0, num_dims_out); + AbstractTensorHandle* sum_dims_raw = nullptr; + TF_RETURN_IF_ERROR(TestTensorHandleWithDims( + ctx, vals.data(), vals_shape, 1, &sum_dims_raw)); + sum_dims.reset(sum_dims_raw); + } // Reduce sum the output on all dimensions. - std::vector sum_inputs(2); - sum_inputs[0] = model_out; - sum_inputs[1] = sum_dims.get(); - - TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs), - absl::MakeSpan(model_outputs), "sum_output")); - outputs[0] = model_outputs[0]; + TF_RETURN_IF_ERROR( + ops::Sum(ctx, {model_out.get(), sum_dims.get()}, outputs, "sum_output")); return Status::OK(); } // ========================= End Helper Functions============================== Status CalcNumericalGrad(AbstractContext* ctx, Model forward, - absl::Span inputs, + absl::Span inputs, int input_index, bool use_function, AbstractTensorHandle** numerical_grad) { + vector theta_inputs(inputs.size()); + for (int i{}; i < inputs.size(); ++i) { + theta_inputs[i] = inputs[i]; + } + AbstractTensorHandle* theta = - inputs[input_index]; // parameter we are grad checking + theta_inputs[input_index]; // parameter we are grad checking // Convert from AbstractTensor to TF_Tensor. TF_Tensor* theta_tensor; @@ -139,61 +127,77 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward, // Numerical Grad Check for (int i = 0; i < num_elems; i++) { // Get relative epsilon value - float epsilon = - std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0 - AbstractTensorHandlePtr two_eps = - GetScalarTensorHandleUtil(ctx, 2 * epsilon); + float epsilon = theta_data[i] == 0 ? 
1e-4 : std::abs(theta_data[i] * 1e-4); + AbstractTensorHandlePtr two_eps; + { + AbstractTensorHandle* two_eps_raw = nullptr; + TF_RETURN_IF_ERROR(TestScalarTensorHandle( + ctx, 2 * epsilon, &two_eps_raw)); + two_eps.reset(two_eps_raw); + } // Initialize theta[i] + epsilon. memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor), TF_TensorByteSize(theta_tensor)); thetaPlus_data[i] += epsilon; - AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat( - ctx, thetaPlus_data.data(), theta_dims.data(), num_dims); + AbstractTensorHandlePtr thetaPlus; + { + AbstractTensorHandle* thetaPlus_raw = nullptr; + TF_RETURN_IF_ERROR(TestTensorHandleWithDims( + ctx, thetaPlus_data.data(), theta_dims.data(), num_dims, + &thetaPlus_raw)); + thetaPlus.reset(thetaPlus_raw); + } // Initialize theta[i] - epsilon. memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor), TF_TensorByteSize(theta_tensor)); thetaMinus_data[i] -= epsilon; - AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat( - ctx, thetaMinus_data.data(), theta_dims.data(), num_dims); + AbstractTensorHandlePtr thetaMinus; + { + AbstractTensorHandle* thetaMinus_raw = nullptr; + TF_RETURN_IF_ERROR(TestTensorHandleWithDims( + ctx, thetaMinus_data.data(), theta_dims.data(), num_dims, + &thetaMinus_raw)); + thetaMinus.reset(thetaMinus_raw); + } // Get f(theta + eps): - inputs[input_index] = thetaPlus.get(); - TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs, + theta_inputs[input_index] = thetaPlus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs, absl::MakeSpan(f_outputs), use_function)); - AbstractTensorHandle* fPlus = f_outputs[0]; + AbstractTensorHandlePtr fPlus(f_outputs[0]); // Get f(theta - eps): - inputs[input_index] = thetaMinus.get(); - TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs, + theta_inputs[input_index] = thetaMinus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs, absl::MakeSpan(f_outputs), use_function)); - AbstractTensorHandle* fMinus = 
f_outputs[0]; + AbstractTensorHandlePtr fMinus(f_outputs[0]); // Take Difference of both estimates: (f(theta + eps) - f(theta - eps)). - TF_RETURN_IF_ERROR( - ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top")); - AbstractTensorHandle* fDiff = f_outputs[0]; + TF_RETURN_IF_ERROR(ops::Sub(ctx, {fPlus.get(), fMinus.get()}, + absl::MakeSpan(f_outputs), "sub_top")); + AbstractTensorHandlePtr fDiff(f_outputs[0]); // Calculate using the difference quotient definition: // (f(theta + eps) - f(theta - eps)) / (2 * eps). - TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()}, - absl::MakeSpan(f_outputs), - "diff_quotient")); - AbstractTensorHandle* diff_quotient = f_outputs[0]; + TF_RETURN_IF_ERROR(ops::Div(ctx, {fDiff.get(), two_eps.get()}, + absl::MakeSpan(f_outputs), "diff_quotient")); + AbstractTensorHandlePtr diff_quotient(f_outputs[0]); TF_Tensor* grad_tensor; - TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor)); + TF_RETURN_IF_ERROR(GetValue(diff_quotient.get(), &grad_tensor)); float grad_data[1]; memcpy(&grad_data[0], TF_TensorData(grad_tensor), TF_TensorByteSize(grad_tensor)); - + TF_DeleteTensor(grad_tensor); dtheta_approx[i] = grad_data[0]; } // Populate *numerical_grad with the data from dtheta_approx. - TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat( + TF_RETURN_IF_ERROR(TestTensorHandleWithDims( ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad)); + TF_DeleteTensor(theta_tensor); return Status::OK(); } diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h index 8497f5af48e5ec..c1671480bf9bf9 100644 --- a/tensorflow/c/eager/gradient_checker.h +++ b/tensorflow/c/eager/gradient_checker.h @@ -12,23 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_ +#define TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_ + #include #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/eager/gradients_util.h" -#include "tensorflow/c/experimental/gradients/math_grad.h" -#include "tensorflow/c/experimental/gradients/nn_grad.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" +#include "tensorflow/c/eager/unified_api_testutil.h" namespace tensorflow { namespace gradients { @@ -45,9 +36,11 @@ namespace gradients { * hold the numerical gradient data at the end of the function. */ Status CalcNumericalGrad(AbstractContext* ctx, Model forward, - absl::Span inputs, + absl::Span inputs, int input_index, bool use_function, AbstractTensorHandle** numerical_grad); } // namespace gradients } // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_ diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc index 393ad2ceb98862..3fef906f58d0d6 100644 --- a/tensorflow/c/eager/gradient_checker_test.cc +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -15,20 +15,11 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/eager/gradients_util.h" -#include "tensorflow/c/eager/mnist_gradients_testutil.h" -#include "tensorflow/c/experimental/gradients/math_grad.h" -#include "tensorflow/c/experimental/gradients/nn_grad.h" -#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/tensor_float_32_utils.h" #include "tensorflow/core/platform/test.h" @@ -37,6 +28,59 @@ namespace gradients { namespace internal { namespace { +using tensorflow::TF_StatusPtr; + +void CompareNumericalAndManualGradients( + Model model, AbstractContext* ctx, + absl::Span inputs, int input_index, + float* expected_grad, int num_grad, bool use_function, + double abs_error = 1e-2) { + Status s; + AbstractTensorHandlePtr numerical_grad; + { + AbstractTensorHandle* numerical_grad_raw; + s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function, + &numerical_grad_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + numerical_grad.reset(numerical_grad_raw); + } + + TF_Tensor* numerical_tensor; + s = GetValue(numerical_grad.get(), &numerical_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto num_elem_numerical = TF_TensorElementCount(numerical_tensor); + ASSERT_EQ(num_elem_numerical, num_grad); + + float* dnumerical = new float[num_elem_numerical]{0}; + memcpy(&dnumerical[0], 
TF_TensorData(numerical_tensor), + TF_TensorByteSize(numerical_tensor)); + + for (int j = 0; j < num_grad; j++) { + ASSERT_NEAR(dnumerical[j], expected_grad[j], abs_error); + } + delete[] dnumerical; + TF_DeleteTensor(numerical_tensor); +} + +Status MatMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::MatMul(ctx, inputs, outputs, "MatMul", + /*transpose_a=*/false, + /*transpose_b=*/false); +} + +Status MulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Mul(ctx, inputs, outputs, "Mul"); +} + +// TODO(vnvo2409): Add more tests from `python/ops/gradient_checker_v2_test.py`. +// These tests should not be confused with `[*]_grad_test` which compare the +// result of `gradient_checker` and `[*]_grad`. The tests here test the +// functionality of `gradient_checker` by comparing the result with expected +// manual user-provided gradients. class GradientCheckerTest : public ::testing::TestWithParam> { protected: @@ -45,84 +89,62 @@ class GradientCheckerTest TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.error_message(); + + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx_.reset(ctx_raw); + } + + // Computing numerical gradients with TensorFloat-32 is numerically + // unstable. 
Some forward pass tests also fail with TensorFloat-32 due to + // low tolerances + enable_tensor_float_32_execution(false); } -}; -Status RegisterGradients(GradientRegistry* registry) { - TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); - TF_RETURN_IF_ERROR( - registry->Register("SparseSoftmaxCrossEntropyWithLogits", - SparseSoftmaxCrossEntropyWithLogitsRegisterer)); - return Status::OK(); -} + AbstractContextPtr ctx_; -TEST_P(GradientCheckerTest, TestGradCheckMatMul) { - // Computing numerical gradients with TensorFloat-32 is numerically unstable - enable_tensor_float_32_execution(false); + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } +}; - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; +TEST_P(GradientCheckerTest, TestMatMul) { + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + AbstractTensorHandlePtr A; { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + AbstractTensorHandle* A_raw; + Status s = TestTensorHandleWithDims(ctx_.get(), A_vals, + A_dims, 2, &A_raw); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); + A.reset(A_raw); } - - float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t A_dims[] = {2, 2}; float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; int64_t B_dims[] = {2, 2}; - int num_dims = 2; - - AbstractTensorHandlePtr A = - GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); - AbstractTensorHandlePtr B = - GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims); - - std::vector inputs; - inputs.push_back(A.get()); - inputs.push_back(B.get()); - - AbstractTensorHandle* grad_approx; - Status s = CalcNumericalGrad( - ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0, - /*use_function=*/!std::get<2>(GetParam()), &grad_approx); - ASSERT_EQ(errors::OK, 
s.code()) << s.error_message(); - - TF_Tensor* gt; - s = GetValue(grad_approx, >); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - float result_data[4] = {0}; - memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt)); - - float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; - float tolerance = 1e-2; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(expected_dA[j], result_data[j], tolerance); - } - TF_DeleteTensor(gt); -} - -TEST_P(GradientCheckerTest, TestGradCheckMul) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - AbstractContextPtr ctx; + AbstractTensorHandlePtr B; { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + AbstractTensorHandle* B_raw; + Status s = TestTensorHandleWithDims(ctx_.get(), B_vals, + B_dims, 2, &B_raw); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); + B.reset(B_raw); } + float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients( + MatMulModel, ctx_.get(), {A.get(), B.get()}, 0, expected_dA, 4, + UseFunction())); +} + +TEST_P(GradientCheckerTest, TestMul) { AbstractTensorHandlePtr x; { AbstractTensorHandle* x_raw = nullptr; - Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + Status s = + TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); x.reset(x_raw); } @@ -130,124 +152,16 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) { AbstractTensorHandlePtr y; { AbstractTensorHandle* y_raw = nullptr; - Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - y.reset(y_raw); - } - - // Will perform z = x*y. 
- // dz/dx = y - - std::vector inputs; - inputs.push_back(x.get()); - inputs.push_back(y.get()); - AbstractTensorHandle* g; - - Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs), - /*input_index=*/0, - /*use_function=*/!std::get<2>(GetParam()), &g); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* gt; - s = GetValue(g, >); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - float result_data[1] = {0}; - memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt)); - - ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2); - TF_DeleteTensor(gt); -} - -TEST_P(GradientCheckerTest, TestGradCheckSoftmax) { - bool use_function = !std::get<2>(GetParam()); - if (use_function) { - // TODO(b/168850692): Enable this. - GTEST_SKIP() << "Can't take gradient of " - "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; - } - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - /** Test to show how to use this API with analytical gradients: - * - * We have `SoftmaxLossGradModel`, which is a wrapper for the - * Softmax analytical gradient found in c/experimental/nn_grads. - * - * We will use the GradientChecker by applying finite differences - * to the forward pass wrapped in `SoftmaxModel` and verify that - * both the analytical and numerical gradients are relatively - * close. 
- * - */ - - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = scores - float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, 1.0f}; - int64_t X_dims[] = {3, 3}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); - - // y = labels - int y_vals[] = {1, 0, 1}; - int64_t y_dims[] = {3}; - num_dims = sizeof(y_dims) / sizeof(y_dims[0]); - AbstractTensorHandlePtr y = - GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - std::vector inputs; - inputs.push_back(X.get()); - inputs.push_back(y.get()); - - // Run analytical gradient and get its data. - std::vector outputs(2); - s = RunModel(SoftmaxLossGradModel, ctx.get(), absl::MakeSpan(inputs), - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* dX_tensor; - s = GetValue(outputs[0], &dX_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float danalytical[9] = {0}; // Contains data from analytical gradient. - memcpy(&danalytical[0], TF_TensorData(dX_tensor), - TF_TensorByteSize(dX_tensor)); - - // Run numerical gradient approximation using the GradientChecker API. - AbstractTensorHandle* g; // Will contain numerical approximation data. 
- s = CalcNumericalGrad(ctx.get(), SoftmaxModel, absl::MakeSpan(inputs), - /*input_index=*/0, - /*use_function=*/!std::get<2>(GetParam()), &g); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* gt; - s = GetValue(g, >); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - float dnumerical[9] = {0}; - memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt)); - - // Now compare the two implementations: - for (int j = 0; j < 9; j++) { - ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2); + y.reset(y_raw); } - // Only Unref() first output as 2nd is nullptr grad for labels - outputs[0]->Unref(); - TF_DeleteTensor(dX_tensor); - TF_DeleteTensor(gt); + float expected_dx[1] = {7.0f}; + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients( + MulModel, ctx_.get(), {x.get(), y.get()}, 0, expected_dx, 1, + UseFunction())); } #ifdef PLATFORM_GOOGLE @@ -255,13 +169,13 @@ INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, GradientCheckerTest, ::testing::Combine(::testing::Values("graphdef"), /*tfrt*/ ::testing::Values(false), - /*executing_eagerly*/ ::testing::Values(true, false))); + /*use_function*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, GradientCheckerTest, ::testing::Combine(::testing::Values("graphdef"), /*tfrt*/ ::testing::Values(false), - /*executing_eagerly*/ ::testing::Values(true, false))); + /*use_function*/ ::testing::Values(true, false))); #endif } // namespace } // namespace internal diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 58ffcf247cf836..f83c7fee9327f1 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -20,11 +20,19 @@ limitations under the License. 
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace gradients { - namespace { + +// TODO(b/172558015): Using the pointer address as the identifier for the tensor +// may lead to collisions. Introduce another way to get a unique id for this +// tensor. +int64 ToId(const AbstractTensorHandle* t) { + return static_cast(reinterpret_cast(t)); +} + Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, AbstractTensorHandle** result) { AbstractOperationPtr op(ctx->CreateOperation()); @@ -43,85 +51,28 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, } } // namespace -class IncomingGradientsImpl : public IncomingGradients { - public: - explicit IncomingGradientsImpl( - absl::Span grad_inputs, Context* ctx, - DefaultGradientFunction* default_gradients) - : grad_inputs_(grad_inputs), - ctx_(ctx), - default_gradients_(default_gradients) {} - AbstractTensorHandle* operator[](int i) const override { - return default_gradients_->get(ctx_, grad_inputs_, i); - } - size_t size() const override { return grad_inputs_.size(); } - - private: - absl::Span grad_inputs_; - Context* ctx_; - DefaultGradientFunction* default_gradients_; -}; - -AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op) - : outputs_(op.outputs) { - for (auto output : outputs_) { - output->Ref(); - } -} -AbstractTensorHandle* AllZerosDefaultGradients::get( - Context* ctx, absl::Span grad_inputs, int i) { - if (grad_inputs[i]) { - return grad_inputs[i]; - } - if (cached_default_grads_[i]) { - return cached_default_grads_[i].get(); - } - AbstractTensorHandle* result = nullptr; - Status s = ZerosLike(ctx->ctx, outputs_[i], &result); - if (!s.ok()) { - if (result) { - result->Unref(); - } - VLOG(1) << "Failed to create ZerosLike for index " << i; - return nullptr; - } - 
cached_default_grads_[i].reset(result); - return result; -} - -PassThroughDefaultGradients::PassThroughDefaultGradients( - const ForwardOperation& op) {} -AbstractTensorHandle* PassThroughDefaultGradients::get( - Context* ctx, absl::Span grad_inputs, int i) { - return grad_inputs[i]; -} - Status GradientRegistry::Register( - const string& op_name, BackwardFunctionFactory backward_function_factory) { + const string& op_name, GradientFunctionFactory gradient_function_factory) { auto iter = registry_.find(op_name); if (iter != registry_.end()) { const string error_msg = "Gradient already exists for op: " + op_name + "."; return errors::AlreadyExists(error_msg); } - registry_.insert({op_name, backward_function_factory}); + registry_.insert({op_name, gradient_function_factory}); return Status::OK(); } Status GradientRegistry::Lookup( const ForwardOperation& op, - std::unique_ptr* backward_function) const { + std::unique_ptr* gradient_function) const { auto iter = registry_.find(op.op_name); if (iter == registry_.end()) { const string error_msg = "No gradient defined for op: " + op.op_name + "."; return errors::NotFound(error_msg); } - backward_function->reset(iter->second(op)); + gradient_function->reset(iter->second(op)); return Status::OK(); } -int64 ToId(AbstractTensorHandle* t) { - return static_cast(reinterpret_cast(t)); -} - TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) { handle_->Ref(); } @@ -140,6 +91,47 @@ AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; } AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; } +class TapeVSpace + : public eager::VSpace { + public: + explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {} + ~TapeVSpace() override {} + + // Returns the number of elements in the gradient tensor. + int64 NumElements(AbstractTensorHandle* tensor) const override; + + // Consumes references to the tensors in the gradient_tensors list and returns + // a tensor with the result. 
+ AbstractTensorHandle* AggregateGradients( + gtl::ArraySlice gradient_tensors) const override; + + // Calls the passed-in backward function. + // op_type is the op's name provided in RecordOperation. + Status CallBackwardFunction( + const string& op_type, GradientFunction* gradient_function, + const std::vector& unneeded_gradients, + gtl::ArraySlice output_gradients, + absl::Span result) const override; + + // Builds a tensor filled with ones with the same shape and dtype as `t`. + Status BuildOnesLike(const TapeTensor& t, + AbstractTensorHandle** result) const override; + + // Looks up the ID of a Gradient. + int64 TensorId(AbstractTensorHandle* tensor) const override; + + // Converts a Gradient to a TapeTensor. + TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override; + + void MarkAsResult(AbstractTensorHandle* gradient) const override; + + void DeleteGradient(AbstractTensorHandle* gradient) const override; + + private: + // The context where the aggregation op `Add` is to be created. + AbstractContext* ctx_; +}; + // Returns the number of elements in the gradient tensor. int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const { // TODO(srbs): It seems like this is used only for performance optimization @@ -178,17 +170,20 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients( } // Calls the passed-in backward function. +// op_type is the op's name provided in RecordOperation. 
Status TapeVSpace::CallBackwardFunction( - BackwardFunction* backward_function, + const string& op_type, GradientFunction* gradient_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, - std::vector* result) const { - if (backward_function == nullptr) return Status::OK(); - Context ctx = {ctx_}; - IncomingGradientsImpl incoming_gradients( - output_gradients, &ctx, backward_function->GetDefaultGradientFunction()); - return backward_function->GetGradientFunction()->Compute( - &ctx, incoming_gradients, result); + absl::Span result) const { + if (gradient_function == nullptr) { + return errors::InvalidArgument( + "Provided null gradient_function for '", op_type, "'.\n", + "If the intent is to treat this op as non-differentiable consider " + "using RegisterNotDifferentiable or " + "NotDifferentiableGradientFunction."); + } + return gradient_function->Compute(ctx_, output_gradients, result); } Status TapeVSpace::BuildOnesLike(const TapeTensor& t, @@ -224,9 +219,84 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const { gradient->Unref(); } +void Tape::Watch(const AbstractTensorHandle* t) { + GradientTape::Watch(ToId(t)); +} +void Tape::RecordOperation(absl::Span inputs, + absl::Span outputs, + GradientFunction* gradient_function, + const string& op_name) { + std::vector input_ids(inputs.size()); + std::vector input_dtypes(inputs.size()); + for (int i = 0; i < inputs.size(); i++) { + input_ids[i] = ToId(inputs[i]); + input_dtypes[i] = inputs[i]->DataType(); + } + std::vector tape_tensors; + for (auto t : outputs) { + tape_tensors.push_back(TapeTensor(t)); + } + GradientTape::RecordOperation( + op_name, tape_tensors, input_ids, input_dtypes, + [gradient_function]() -> GradientFunction* { return gradient_function; }, + [](GradientFunction* ptr) { + if (ptr) { + delete ptr; + } + }); +} +bool Tape::ShouldRecord( + absl::Span tensors) const { + std::vector tensor_ids(tensors.size()); + std::vector tensor_dtypes(tensors.size()); 
+ for (int i = 0; i < tensors.size(); i++) { + tensor_ids[i] = ToId(tensors[i]); + tensor_dtypes[i] = tensors[i]->DataType(); + } + return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes); +} +void Tape::DeleteTrace(const AbstractTensorHandle* t) { + GradientTape::DeleteTrace(ToId(t)); +} + +std::vector MakeTensorIDList( + absl::Span tensors) { + std::vector ids(tensors.size()); + for (int i = 0; i < tensors.size(); i++) { + ids[i] = ToId(tensors[i]); + } + return ids; +} + +Status Tape::ComputeGradient( + AbstractContext* ctx, absl::Span targets, + absl::Span sources, + absl::Span output_gradients, + absl::Span result) { + TapeVSpace vspace(ctx); + std::vector target_tensor_ids = MakeTensorIDList(targets); + std::vector source_tensor_ids = MakeTensorIDList(sources); + tensorflow::gtl::FlatSet sources_set( + source_tensor_ids.begin(), source_tensor_ids.end()); + std::unordered_map sources_that_are_targets; + for (int i = 0; i < target_tensor_ids.size(); ++i) { + int64 target_id = target_tensor_ids[i]; + if (sources_set.find(target_id) != sources_set.end()) { + auto tensor = targets[i]; + sources_that_are_targets.insert( + std::make_pair(target_id, TapeTensor(tensor))); + } + } + + TF_RETURN_IF_ERROR(GradientTape::ComputeGradient( + vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets, + output_gradients, result, /*build_default_zeros_grads*/ false)); + return Status::OK(); +} + // Helper functions which delegate to `AbstractOperation`, update // the state of the ForwardOperation and call the tape as appropriate. -// These APIs are mainly to faciliate testing and are subject to change. +// These APIs are mainly to facilitate testing and are subject to change. 
namespace internal { Status Reset(AbstractOperation* op_, const char* op, const char* raw_device_name, ForwardOperation* forward_op_) { @@ -398,12 +468,6 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, ForwardOperation* forward_op_, Tape* tape, const GradientRegistry& registry) { TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals)); - std::vector input_ids(forward_op_->inputs.size()); - std::vector input_dtypes(forward_op_->inputs.size()); - for (int i = 0; i < forward_op_->inputs.size(); i++) { - input_ids[i] = ToId(forward_op_->inputs[i]); - input_dtypes[i] = forward_op_->inputs[i]->DataType(); - } for (int i = 0; i < *num_retvals; i++) { // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. forward_op_->outputs.push_back(retvals[i]); @@ -413,25 +477,10 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, // Consider getting rid of this and making the behavior between number types // and string consistent. forward_op_->attrs.BuildNodeDef(); - std::vector tape_tensors; - for (auto t : retvals) { - tape_tensors.push_back(TapeTensor(t)); - } - tape->RecordOperation( - op_->Name(), tape_tensors, input_ids, input_dtypes, - [registry, forward_op_]() -> BackwardFunction* { - std::unique_ptr backward_fn; - Status s = registry.Lookup(*forward_op_, &backward_fn); - if (!s.ok()) { - return nullptr; - } - return backward_fn.release(); - }, - [](BackwardFunction* ptr) { - if (ptr) { - delete ptr; - } - }); + std::unique_ptr gradient_fn; + TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn)); + tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(), + op_->Name()); return Status::OK(); } } // namespace internal diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index f7d80cbeb343cb..ea4e1ef7d4d907 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -33,10 +33,11 @@ namespace gradients { // public: // Status Compute(Context* ctx, // absl::Span 
grad_inputs, -// std::vector* grad_outputs) override { -// grad_outputs->resize(2); -// (*grad_outputs)[0] = grad_inputs[0]; -// (*grad_outputs)[1] = grad_inputs[0]; +// absl::Span grad_outputs) override { +// grad_outputs[0] = grad_inputs[0]; +// grad_outputs[1] = grad_inputs[0]; +// grad_outputs[0]->Ref(); +// grad_outputs[1]->Ref(); // return Status::OK(); // } // ~AddGradientFunction() override {} @@ -51,123 +52,41 @@ namespace gradients { // Status RegisterGradients(GradientRegistry* registry) { // return registry->Register("Add", AddRegisterer); // } -struct Context { - public: - AbstractContext* ctx; -}; - -class IncomingGradients { - public: - virtual AbstractTensorHandle* operator[](int i) const = 0; - virtual size_t size() const = 0; - virtual ~IncomingGradients() {} -}; - class GradientFunction { public: - // TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in - // `grad_inputs`. - virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - std::vector* grad_outputs) = 0; + virtual Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) = 0; virtual ~GradientFunction() {} }; // Metadata from the forward operation that is made available to the -// gradient registerer to instantiate a BackwardFunction. +// gradient registerer to instantiate a GradientFunction. struct ForwardOperation { public: string op_name; std::vector inputs; std::vector outputs; + std::vector skip_input_indices; AttrBuilder attrs; }; -// Interface for building default zeros gradients for op outputs which are -// missing incoming gradients. Custom implementations of this can be used to -// control which of the forward op's output tensors/their metadata needs to -// be kept around in memory to build the default zeros grad. -// -// Some common helper implementations are provided below. 
-class DefaultGradientFunction { - public: - virtual AbstractTensorHandle* get( - Context* ctx, absl::Span grad_inputs, - int i) = 0; - virtual ~DefaultGradientFunction() {} -}; - -// Returns zeros for any `nullptr` in `grad_inputs`. -// -// This may require keeping track of all of forward op's output -// tensors and hence may incur a higher memory footprint. Use sparingly. -// -// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor -// handle. -// -// The destructor of this class `Unref`'s any cached tensor handles so users of -// those tensor handles should `Ref` them in order to keep them alive if needed. -class AllZerosDefaultGradients : public DefaultGradientFunction { - public: - explicit AllZerosDefaultGradients(const ForwardOperation& op); - AbstractTensorHandle* get(Context* ctx, - absl::Span grad_inputs, - int i) override; - - private: - // TODO(srbs): We do not always need to keep the tensors around. In immediate - // execution mode we just need to store the shape and dtype. During tracing - // we may need to keep the tensor around if the shape is not full defined. - std::vector outputs_; - std::vector cached_default_grads_; -}; - -// Passes through `grad_inputs` as-is. The `GradientFunction` -// will be expected to deal with nullptr in `grad_inputs` if any. -class PassThroughDefaultGradients : public DefaultGradientFunction { - public: - explicit PassThroughDefaultGradients(const ForwardOperation& op); - AbstractTensorHandle* get(Context* ctx, - absl::Span grad_inputs, - int i) override; -}; - -// A `BackwardFunction` wraps a `GradientFunction` and a -// `DefaultGradientFunction`. Both are owned by this class' instance. 
-class BackwardFunction { - public: - BackwardFunction(GradientFunction* gradient_function, - DefaultGradientFunction* default_gradients) - : gradient_function_(gradient_function), - default_gradients_(default_gradients) {} - GradientFunction* GetGradientFunction() { return gradient_function_.get(); } - DefaultGradientFunction* GetDefaultGradientFunction() { - return default_gradients_.get(); - } +using GradientFunctionFactory = + std::function; - private: - std::unique_ptr gradient_function_; - std::unique_ptr default_gradients_; -}; - -using BackwardFunctionFactory = - std::function; - -// Map from op name to a `BackwardFunctionFactory`. +// Map from op name to a `GradientFunctionFactory`. class GradientRegistry { public: Status Register(const string& op, - BackwardFunctionFactory backward_function_factory); + GradientFunctionFactory gradient_function_factory); Status Lookup(const ForwardOperation& op, - std::unique_ptr* backward_function) const; + std::unique_ptr* gradient_function) const; private: - absl::flat_hash_map registry_; + absl::flat_hash_map registry_; }; -// Returns a unique id for the tensor which is used by the tape to build -// the gradient graph. See documentation of `TapeTensor` for more details. -int64 ToId(AbstractTensorHandle* t); - +// TODO(srbs): Figure out if we can avoid declaring this in the public header. // Wrapper for a tensor output of an operation executing under a tape. // // `GetID` returns a unique id for the wrapped tensor which is used to maintain @@ -203,59 +122,53 @@ class TapeTensor { AbstractTensorHandle* handle_; }; -// Vector space for actually computing gradients. Implements methods for calling -// the backward function with incoming gradients and returning the outgoing -// gradient and for performing gradient aggregation. -// See `tensorflow::eager::VSpace` for more details. 
-class TapeVSpace - : public eager::VSpace { - public: - explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {} - ~TapeVSpace() override {} - - // Returns the number of elements in the gradient tensor. - int64 NumElements(AbstractTensorHandle* tensor) const override; - - // Consumes references to the tensors in the gradient_tensors list and returns - // a tensor with the result. - AbstractTensorHandle* AggregateGradients( - gtl::ArraySlice gradient_tensors) const override; - - // Calls the passed-in backward function. - Status CallBackwardFunction( - BackwardFunction* backward_function, - const std::vector& unneeded_gradients, - gtl::ArraySlice output_gradients, - std::vector* result) const override; - - // Builds a tensor filled with ones with the same shape and dtype as `t`. - Status BuildOnesLike(const TapeTensor& t, - AbstractTensorHandle** result) const override; - - // Looks up the ID of a Gradient. - int64 TensorId(AbstractTensorHandle* tensor) const override; - - // Converts a Gradient to a TapeTensor. - TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override; - - void MarkAsResult(AbstractTensorHandle* gradient) const override; - - void DeleteGradient(AbstractTensorHandle* gradient) const override; - - private: - // The context where the aggregation op `Add` is to be created. - AbstractContext* ctx_; -}; - // A tracing/immediate-execution agnostic tape. // -// Gradient functions defined for this library support handling null incoming -// gradients. `Tape::ComputeGradient` should be called with -// `build_default_zeros_grads=false`. Calling with -// `build_default_zeros_grads=true` (the default) is equivalent but just results -// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway. -using Tape = tensorflow::eager::GradientTape; +// Gradient functions defined for this tape must support handling null incoming +// gradients. 
+class Tape : protected eager::GradientTape { + public: + using GradientTape::GradientTape; + // Returns whether the tape is persistent, i.e., whether the tape will hold + // onto its internal state after a call to `ComputeGradient`. + using GradientTape::IsPersistent; + + // Adds this tensor to the list of watched tensors. + // + // This is a no-op if the tensor is already being watched either from an + // earlier call to `GradientTape::Watch` or being an output of an op with + // watched inputs. + void Watch(const AbstractTensorHandle*); + // Records an operation with given inputs and outputs + // on the tape and marks all its outputs as watched if at + // least one input of the op is watched and has a trainable dtype. + // op_name is optional and is used for debugging only. + void RecordOperation(absl::Span inputs, + absl::Span outputs, + GradientFunction* gradient_function, + const string& op_name = ""); + // Returns whether any tensor in a list of tensors is being watched and has + // a trainable dtype. + bool ShouldRecord( + absl::Span tensors) const; + // Unwatches this tensor on the tape. Mainly used for cleanup when deleting + // eager tensors. + void DeleteTrace(const AbstractTensorHandle*); + + // Consumes the internal state of the tape (so cannot be called more than + // once unless the tape is persistent) and produces the gradient of the target + // tensors with respect to the source tensors. The output gradients are used + // if not empty and not null. The result is populated with one tensor per + // target element. 
+ Status ComputeGradient( + AbstractContext* ctx, absl::Span targets, + absl::Span sources, + absl::Span output_gradients, + absl::Span result); +}; } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 84ba0e061cc461..7692bd20234985 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -25,8 +25,10 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/unified_api_testutil.h" #include "tensorflow/c/experimental/gradients/array_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/not_differentiable.h" #include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" @@ -56,341 +58,11 @@ class CppGradients }; Status RegisterGradients(GradientRegistry* registry) { - // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to - // AddV2Registerer. 
- TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("Sqrt", SqrtRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer)); + TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics")); return Status::OK(); } -// Computes -// y = inputs[0] + inputs[1] -// return grad(y, {inputs[0], inputs[1]}) -Status AddGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = std::make_unique(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. - std::vector add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); - TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs, - absl::MakeSpan(add_outputs), - "Add")); // Compute x+y. - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto add_output : add_outputs) { - add_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - return Status::OK(); -} - -// Computes -// y = exp(inputs[0]) -// return grad(y, {inputs[0]}) -Status ExpGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = std::make_unique(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. 
- std::vector exp_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); - TF_RETURN_IF_ERROR( - ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp")); - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto exp_output : exp_outputs) { - exp_output->Unref(); - } - outputs[0] = out_grads[0]; - return Status::OK(); -} - -// Computes -// y = sqrt(inputs[0]) -// return grad(y, {inputs[0]}) -Status SqrtGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = std::make_unique(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - std::vector sqrt_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); - TF_RETURN_IF_ERROR( - ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt")); - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto sqrt_output : sqrt_outputs) { - sqrt_output->Unref(); - } - outputs[0] = out_grads[0]; - return Status::OK(); -} - -// Computes -// ignored, y = IdentityN(inputs[0], inputs[1]) -// return grad(y, {inputs[0], inputs[1]}) -// This should return [nullptr, 1]. 
-Status IdentityNGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = std::make_unique(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); - tape->Watch(ToId(inputs[1])); - - vector identity_n_outputs(2); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); - TF_RETURN_IF_ERROR(ops::IdentityN( - tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); - - std::unordered_map - source_tensors_that_are_targets; - vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto identity_n_output : identity_n_outputs) { - identity_n_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - return Status::OK(); -} - -// Computes -// y = - inputs[0] -// return grad(y, {inputs[0]}) -Status NegGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = std::make_unique(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); - - std::vector neg_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); - TF_RETURN_IF_ERROR( - ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg")); - - std::unordered_map - source_tensors_that_are_targets; - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto neg_output : neg_outputs) { - neg_output->Unref(); - } - outputs[0] = out_grads[0]; - return Status::OK(); -} - 
-// Computes -// y = inputs[0] - inputs[1] -// return grad(y, {inputs[0], inputs[1]}) -Status SubGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = std::make_unique(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. - std::vector sub_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry)); - TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs, - absl::MakeSpan(sub_outputs), - "Sub")); // Compute x-y. - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto sub_output : sub_outputs) { - sub_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - return Status::OK(); -} - -AbstractContext* BuildFunction(const char* fn_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); - return unwrap(graph_ctx); -} - -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - std::vector* params) { - tracing::TracingTensorHandle* handle = nullptr; - for (auto input : inputs) { - TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( - input->DataType(), &handle)); - params->emplace_back(handle); - } - return Status::OK(); -} - -using Model = std::function, - absl::Span, const GradientRegistry&)>; - -// Runs `model` maybe wrapped in a function. 
-Status RunModel(Model model, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, bool use_function, - const GradientRegistry& registry) { - if (use_function) { - const char* fn_name = "test_fn"; - std::unique_ptr scoped_func; - // Returning null tensors from a tf.function is not supported, so we keep - // track of indices in the model's outputs are nullptr in this set. - // The FunctionDef only outputs the non-null tensors. We later pad the - // function op outputs to have nullptrs at the `null_indices`. - absl::flat_hash_set null_indices; - { - AbstractContextPtr func_ctx(BuildFunction(fn_name)); - std::vector func_inputs; - func_inputs.reserve(inputs.size()); - TF_RETURN_IF_ERROR( - CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); - vector model_outputs; - model_outputs.resize(outputs.size()); - TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), - absl::MakeSpan(model_outputs), registry)); - for (auto func_input : func_inputs) { - func_input->Unref(); - } - AbstractFunction* func = nullptr; - OutputList output_list; - output_list.expected_num_outputs = 0; - output_list.outputs.reserve(outputs.size()); - for (int i = 0; i < model_outputs.size(); i++) { - if (model_outputs[i]) { - output_list.outputs.emplace_back(model_outputs[i]); - output_list.expected_num_outputs += 1; - } else { - null_indices.insert(i); - } - } - TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) - ->Finalize(&output_list, &func)); - scoped_func.reset(func); - for (auto output : output_list.outputs) { - output->Unref(); - } - TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); - } - - AbstractOperationPtr fn_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); - for (auto input : inputs) { - TF_RETURN_IF_ERROR(fn_op->AddInput(input)); - } - int retvals = outputs.size() - null_indices.size(); - vector fn_outputs(retvals); - TF_RETURN_IF_ERROR(fn_op->Execute( - absl::Span(fn_outputs.data(), fn_outputs.size()), - 
&retvals)); - int skipped_indices = 0; - for (int i = 0; i < outputs.size(); i++) { - if (!null_indices.contains(i)) { - outputs[i] = fn_outputs[i - skipped_indices]; - } else { - skipped_indices += 1; - } - } - TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); - return Status::OK(); - } else { - return model(ctx, inputs, outputs, registry); - } -} - -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetTfrt(opts, use_tfrt); - *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_DeleteContextOptions(opts); - return Status::OK(); -} - -Status TestScalarTensorHandle(AbstractContext* ctx, float value, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_TensorHandle* result_t = - TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); - return Status::OK(); -} - -TEST_P(CppGradients, TestAddGrad) { +TEST_P(CppGradients, TestSetAttrString) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); AbstractContextPtr ctx; @@ -402,247 +74,60 @@ TEST_P(CppGradients, TestAddGrad) { ctx.reset(ctx_raw); } - AbstractTensorHandlePtr x; + AbstractTensorHandlePtr t; { AbstractTensorHandle* x_raw = nullptr; - 
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - x.reset(x_raw); - } - - AbstractTensorHandlePtr y; - { - AbstractTensorHandle* y_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - y.reset(y_raw); + t.reset(x_raw); } - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Pseudo-code: - // - // tape.watch(x) - // tape.watch(y) - // y = x + y - // outputs = tape.gradient(y, [x, y]) - std::vector outputs(2); - s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* result_tensor; - s = getValue(outputs[0], &result_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - auto result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_EQ(*result_value, 1.0); - outputs[0]->Unref(); - TF_DeleteTensor(result_tensor); - result_tensor = nullptr; - - s = getValue(outputs[1], &result_tensor); + AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); + ForwardOperation forward_op; + Status s = Reset(check_numerics_op.get(), "CheckNumerics", + /*raw_device_name=*/nullptr, &forward_op); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_EQ(*result_value, 1.0); - outputs[1]->Unref(); - TF_DeleteTensor(result_tensor); -} - -TEST_P(CppGradients, TestExpGrad) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } 
- - AbstractTensorHandlePtr x; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + if (isa(check_numerics_op.get())) { + s = dyn_cast(check_numerics_op.get()) + ->SetOpName("check_numerics"); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - x.reset(x_raw); } - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Pseudo-code: - // - // tape.watch(x) - // y = exp(x) - // outputs = tape.gradient(y, x) - std::vector outputs(1); - s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* result_tensor; - s = getValue(outputs[0], &result_tensor); + s = AddInput(check_numerics_op.get(), t.get(), &forward_op); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - auto result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_NEAR(*result_value, 2.718, 0.001); - outputs[0]->Unref(); - TF_DeleteTensor(result_tensor); - result_tensor = nullptr; -} - -TEST_P(CppGradients, TestSqrtGrad) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - AbstractTensorHandlePtr x; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - x.reset(x_raw); - } - - GradientRegistry registry; - Status s = RegisterGradients(®istry); + string message = "This is the way!"; + s = SetAttrString(check_numerics_op.get(), "message", message.data(), + message.length(), &forward_op); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Pseudo-code: - // - // 
tape.watch(x) - // y = sqrt(x) - // outputs = tape.gradient(y, x) + int num_retvals = 1; std::vector outputs(1); - s = RunModel(SqrtGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* result_tensor; - s = getValue(outputs[0], &result_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - auto result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_NEAR(*result_value, 0.5, 0.001); - outputs[0]->Unref(); - TF_DeleteTensor(result_tensor); - result_tensor = nullptr; -} - -TEST_P(CppGradients, TestIdentityNGrad) { - // Pseudo-code: - // - // tape.watch(x1) - // tape.watch(x2) - // unused, y = IdentityN([x1, x2]) - // outputs = tape.gradient(y, [x1, x2]) - // Expected: [nullptr, 1] - // - // This test is interesting because the current implementation of GradientTape - // would return [0, 1] whereas we use build_default_zeros_grads=false here - // so we get back [nullptr, 1]. 
- std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - AbstractTensorHandlePtr x1; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - x1.reset(x_raw); - } - AbstractTensorHandlePtr x2; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - x2.reset(x_raw); - } - GradientRegistry registry; - Status s = RegisterGradients(®istry); + s = RegisterGradients(®istry); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - std::vector outputs(2); - s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); + auto tape = std::make_unique(/*persistent=*/false); + s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), + &num_retvals, &forward_op, tape.get(), registry); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - EXPECT_EQ(outputs[0], nullptr); - TF_Tensor* result_tensor; - s = getValue(outputs[1], &result_tensor); + string read_message; + s = forward_op.attrs.Get("message", &read_message); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - auto result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_EQ(*result_value, 1.0); - outputs[1]->Unref(); - TF_DeleteTensor(result_tensor); - result_tensor = nullptr; + ASSERT_EQ(read_message, message); } -TEST_P(CppGradients, TestNegGrad) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), 
&ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - AbstractTensorHandlePtr x; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - x.reset(x_raw); - } - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Pseudo-code: - // - // tape.watch(x) - // y = - x - // outputs = tape.gradient(y, x) - std::vector outputs(1); - s = RunModel(NegGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* result_tensor; - s = getValue(outputs[0], &result_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - auto result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_EQ(*result_value, -1.0); - outputs[0]->Unref(); - TF_DeleteTensor(result_tensor); - result_tensor = nullptr; +Status RecordOperationWithNullGradientFunctionModel( + AbstractContext* ctx, absl::Span inputs, + absl::Span outputs) { + Tape tape(/*persistent=*/false); + tape.Watch(inputs[0]); + std::vector neg_outputs(1); + TF_RETURN_IF_ERROR(ops::Neg(ctx, inputs, absl::MakeSpan(neg_outputs), "Neg")); + tape.RecordOperation(inputs, neg_outputs, nullptr, "Neg"); + return tape.ComputeGradient(ctx, /*targets=*/neg_outputs, + /*sources=*/inputs, + /*output_gradients=*/{}, outputs); } -TEST_P(CppGradients, TestSubGrad) { +TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); AbstractContextPtr ctx; @@ -657,100 +142,22 @@ TEST_P(CppGradients, TestSubGrad) { AbstractTensorHandlePtr x; { AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); ASSERT_EQ(errors::OK, 
s.code()) << s.error_message(); x.reset(x_raw); } - AbstractTensorHandlePtr y; - { - AbstractTensorHandle* y_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - y.reset(y_raw); - } - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Pseudo-code: - // - // tape.watch(x) - // tape.watch(y) - // y = x - y - // outputs = tape.gradient(y, [x, y]) - std::vector outputs(2); - s = RunModel(SubGradModel, ctx.get(), {x.get(), y.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* result_tensor; - s = getValue(outputs[0], &result_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - auto result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_EQ(*result_value, 1.0); - outputs[0]->Unref(); - TF_DeleteTensor(result_tensor); - result_tensor = nullptr; - - s = getValue(outputs[1], &result_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - result_value = static_cast(TF_TensorData(result_tensor)); - EXPECT_EQ(*result_value, -1.0); - outputs[1]->Unref(); - TF_DeleteTensor(result_tensor); -} - -TEST_P(CppGradients, TestSetAttrString) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - AbstractTensorHandlePtr t; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - t.reset(x_raw); - } - - AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); - ForwardOperation forward_op; - Status s = Reset(check_numerics_op.get(), 
"CheckNumerics", - /*raw_device_name=*/nullptr, &forward_op); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - if (isa(check_numerics_op.get())) { - s = dyn_cast(check_numerics_op.get()) - ->SetOpName("check_numerics"); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - } - s = AddInput(check_numerics_op.get(), t.get(), &forward_op); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - string message = "This is the way!"; - s = SetAttrString(check_numerics_op.get(), "message", message.data(), - message.length(), &forward_op); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - int num_retvals = 1; std::vector outputs(1); - GradientRegistry registry; - auto tape = std::make_unique(/*persistent=*/false); - s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs), - &num_retvals, &forward_op, tape.get(), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - string read_message; - s = forward_op.attrs.Get("message", &read_message); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ASSERT_EQ(read_message, message); + Status s = RunModel(RecordOperationWithNullGradientFunctionModel, ctx.get(), + {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam())); + ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); + ASSERT_EQ( + "Provided null gradient_function for 'Neg'.\nIf the intent is to treat " + "this op as non-differentiable consider using RegisterNotDifferentiable " + "or NotDifferentiableGradientFunction.", + s.error_message()); + ASSERT_EQ(nullptr, outputs[0]); } // TODO(b/164171226): Enable this test with tfrt after AddInputList is diff --git a/tensorflow/c/eager/gradients_util.cc b/tensorflow/c/eager/gradients_util.cc deleted file mode 100644 index e53faf4a3f3fdf..00000000000000 --- a/tensorflow/c/eager/gradients_util.cc +++ /dev/null @@ -1,317 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/eager/gradients_util.h" - -#include - -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/experimental/ops/math_ops.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" - -namespace tensorflow { -namespace gradients { - -using namespace std; - -Status ScalarTensorHandleHelper(TFE_Context* ctx, float value, - TFE_TensorHandle** result) { - float data[] = {value}; - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_Tensor* t = - TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get()); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); - *result = th; - TF_DeleteTensor(t); - return StatusFromTF_Status(status.get()); -} - -Status TensorHandleWithDimsFloatHelper(TFE_Context* 
ctx, float data[], - int64_t dims[], int num_dims, - TFE_TensorHandle** result) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_Tensor* t = - TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get()); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); - *result = th; - TF_DeleteTensor(t); - return StatusFromTF_Status(status.get()); -} - -Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[], - int64_t dims[], int num_dims, - TFE_TensorHandle** result) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_Tensor* t = - TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get()); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get()); - *result = th; - TF_DeleteTensor(t); - return StatusFromTF_Status(status.get()); -} - -// Get a scalar TensorHandle with given value -Status ScalarTensorHandle(AbstractContext* ctx, float value, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager; - TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager)); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return StatusFromTF_Status(status.get()); -} - -// Get a TensorHandle with given float values and dimensions -Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* 
input_eager; - TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, - num_dims, &input_eager)); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return StatusFromTF_Status(status.get()); -} - -// Get a TensorHandle with given int values and dimensions -Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[], - int num_dims, AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager; - TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims, - num_dims, &input_eager)); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return StatusFromTF_Status(status.get()); -} - -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_TensorHandle* result_t = - TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); - return StatusFromTF_Status(status.get()); -} - -AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, - float vals[], int64_t dims[], - int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); - if (s.ok()) { - A.reset(a_raw); - } - return A; -} - -AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], - int64_t dims[], int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); - if (s.ok()) { - A.reset(a_raw); - } - return A; -} - -AbstractTensorHandlePtr 
GetScalarTensorHandleUtil(AbstractContext* ctx, - float val) { - AbstractTensorHandlePtr y; - AbstractTensorHandle* y_raw = nullptr; - Status s = ScalarTensorHandle(ctx, val, &y_raw); - if (s.ok()) { - y.reset(y_raw); - } - return y; -} - -Status UpdateWeights(AbstractContext* ctx, vector& grads, - vector& weights, - AbstractTensorHandle* learning_rate) { - /* Update weights one by one using gradient update rule: - * - * w -= lr*grad[w] - * - * NOTE: assuming learning rate is positive - */ - - int num_grads = grads.size(); - vector temp_outputs(1); - std::string update_str; - - // Negate learning rate for gradient descent - TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, - absl::MakeSpan(temp_outputs), - "neg_lr")); // Compute -lr - learning_rate = temp_outputs[0]; - - for (int i = 0; i < num_grads; i++) { - // Compute dW = -lr * grad(w[i]) - update_str = "update_mul_" + std::to_string(i); - TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]}, - absl::MakeSpan(temp_outputs), - update_str.c_str())); - - AbstractTensorHandle* dW = temp_outputs[0]; - - // Compute temp = weights[i] + dW - update_str = "update_add_" + std::to_string(i); - TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW}, - absl::MakeSpan(temp_outputs), - update_str.c_str())); - - // Update the weights - weights[i] = temp_outputs[0]; - } - - return Status::OK(); -} - -AbstractContext* BuildFunction(const char* fn_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); - return unwrap(graph_ctx); -} - -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - vector* params) { - tracing::TracingTensorHandle* handle = nullptr; - for (auto input : inputs) { - TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( - input->DataType(), &handle)); - params->emplace_back(handle); - } - return Status::OK(); -} - -Status RunModel(Model model, AbstractContext* ctx, - absl::Span inputs, - absl::Span 
outputs, bool use_function, - const GradientRegistry& registry) { - if (use_function) { - const char* fn_name = "test_fn"; - std::unique_ptr scoped_func; - // Returning null tensors from a tf.function is not supported, so we keep - // track of indices in the model's outputs are nullptr in this set. - // The FunctionDef only outputs the non-null tensors. We later pad the - // function op outputs to have nullptrs at the `null_indices`. - absl::flat_hash_set null_indices; - { - AbstractContextPtr func_ctx(BuildFunction(fn_name)); - vector func_inputs; - func_inputs.reserve(inputs.size()); - TF_RETURN_IF_ERROR( - CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); - vector model_outputs; - model_outputs.resize(outputs.size()); - TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), - absl::MakeSpan(model_outputs), registry)); - for (auto func_input : func_inputs) { - func_input->Unref(); - } - AbstractFunction* func = nullptr; - OutputList output_list; - output_list.expected_num_outputs = 0; - output_list.outputs.reserve(outputs.size()); - for (int i = 0; i < model_outputs.size(); i++) { - if (model_outputs[i]) { - output_list.outputs.emplace_back(model_outputs[i]); - output_list.expected_num_outputs += 1; - } else { - null_indices.insert(i); - } - } - TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) - ->Finalize(&output_list, &func)); - scoped_func.reset(func); - for (auto output : output_list.outputs) { - output->Unref(); - } - TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); - } - - AbstractOperationPtr fn_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); - for (auto input : inputs) { - TF_RETURN_IF_ERROR(fn_op->AddInput(input)); - } - int retvals = outputs.size() - null_indices.size(); - vector fn_outputs(retvals); - TF_RETURN_IF_ERROR(fn_op->Execute( - absl::Span(fn_outputs.data(), fn_outputs.size()), - &retvals)); - int skipped_indices = 0; - for (int i = 0; i < outputs.size(); i++) { - if 
(!null_indices.contains(i)) { - outputs[i] = fn_outputs[i - skipped_indices]; - } else { - skipped_indices += 1; - } - } - TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); - return Status::OK(); - } else { - return model(ctx, inputs, outputs, registry); - } -} - -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetTfrt(opts, use_tfrt); - *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_DeleteContextOptions(opts); - return Status::OK(); -} - -} // namespace gradients -} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/c/eager/gradients_util.h b/tensorflow/c/eager/gradients_util.h deleted file mode 100644 index cd0bbc0720d072..00000000000000 --- a/tensorflow/c/eager/gradients_util.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/experimental/ops/math_ops.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace gradients { - -// Get a scalar TensorHandle with given value -Status ScalarTensorHandle(AbstractContext* ctx, float value, - AbstractTensorHandle** tensor); - -// Get a TensorHandle with given float values and dimensions -Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor); - -// Get a TensorHandle with given int values and dimensions -Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[], - int num_dims, AbstractTensorHandle** tensor); - -// Places data from `t` into *result_tensor. -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); - -// Util function that wraps an AbstractTensorHandle* with given data and dims. -AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, - float vals[], int64_t dims[], - int num_dims); - -// Util function that wraps an AbstractTensorHandle* with given data and dims. 
-AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], - int64_t dims[], int num_dims); - -// Util function that wraps an AbstractTensorHandle* with given data. -AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, - float val); - -// Performs gradient update for each weight using given learning rate. -Status UpdateWeights(AbstractContext* ctx, - std::vector& grads, - std::vector& weights, - AbstractTensorHandle* learning_rate); - -using Model = std::function, - absl::Span, const GradientRegistry&)>; - -// Runs given model in either graph or eager mode depending on value of -// use_function. -Status RunModel(Model model, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, bool use_function, - const GradientRegistry& registry); - -// Builds context and returns inside *ctx. -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); - -} // namespace gradients -} // namespace tensorflow diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index a3e3857b34b1b2..90ada313776787 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -21,18 +21,27 @@ limitations under the License. 
#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/immediate_execution_distributed_manager.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { class EagerExecutor; +class EagerContext; +class CustomDevice; +class CustomDeviceOpHandler; +class Device; // LINT.IfChange // Note: Keep in sync with exported copy of enum in eager/c_api.h. @@ -106,11 +115,18 @@ class ImmediateExecutionContext : public AbstractContext { // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; + // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under + // the key of the function definition name (to be retrieved during function + // instantiation). + virtual Status AddFunctionDefWithStackTraces( + const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0; + // Find and return a added function by its name. virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; // Return the ParsedName of Host CPU device. virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; + virtual const string& HostCPUName() const = 0; // Configure soft device placement policy. 
virtual void SetAllowSoftPlacement(bool enable) = 0; @@ -124,14 +140,44 @@ class ImmediateExecutionContext : public AbstractContext { // Returns the device placement policy for the current thread. virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + + // Return the collected RunMetadata. This method will transfer the ownership + // to the caller. + virtual std::unique_ptr ExportRunMetadata() = 0; + // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; } //===--------------------------------------------------------------------===// - // Following are legacy features in TF Eager Runtime. - // TODO(tf-runtime): Figure out a way to deprecate following features after + // Experimental Custom Device. + //===--------------------------------------------------------------------===// + virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0; + + // Register a custom device. It will return error is the device name is + // already registered. + // TODO(tfrt-devs): Remove this method. Let caller register it directly into + // CustomDeviceOpHandler. + virtual Status RegisterCustomDevice(const string& name, + std::unique_ptr device) = 0; + + // Return FunctionLibraryDefinition. Transformations need to use it to use it + // to invoke MLIR compiler passes. + virtual FunctionLibraryDefinition* FuncLibDef() = 0; + + // When tensor transfer across functions/eager executions using send/recv ops + // are required, `reuse_rendezvous_for_functions_` can be set to true so that + // function executions and eager executions use the same rendezvous instance, + // instead of creating new instance per function calls. 
+ virtual void SetReuseRendezvousForFunctions( + bool reuse_rendezvous_for_functions) = 0; + + //===--------------------------------------------------------------------===// + // Following are features in current TF Eager Runtime. + // TODO(tfrt-devs): Figure out a way to deprecate following features after // migrated to TFRT. //===--------------------------------------------------------------------===// // Clear pending nodes in thread executors and kernel caches. @@ -149,8 +195,42 @@ class ImmediateExecutionContext : public AbstractContext { // Update the Eager Executor for current thread. virtual void SetExecutorForThread(EagerExecutor* executor) = 0; - // Configure graph collection in RunMetadata. - virtual void SetShouldStoreGraphs(bool value) = 0; + // Return a list of local tensorflow::Device*. + // TODO(tfrt-devs): We shouldn't expose legacy device in this API. + virtual std::vector ListLocalTfDevices() = 0; + + //===--------------------------------------------------------------------===// + // Following are helper functions to assist integrating TFRT with current + // TF eager runtime. + // TODO(b/172877902): These helper functions are currently used to support + // PyFuncOp on TFRT, and might be useful for ops that directly use low + // level TF APIs. Remove/replace the following functions when TFRT native + // ops are implemented. + //===--------------------------------------------------------------------===// + // Create an abstract tensor handle from tensorflow::Tensor. + virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor( + tensorflow::Tensor& t, const char* d_name) = 0; + + // Convert a TFRT TensorHandle to tensorflow::TensorHandle. + virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( + ImmediateExecutionTensorHandle* handle) = 0; + + virtual std::vector GetLoggedOpsTestonly() { return {}; } + + // Get a list of the names of functions that have been registered. 
+ virtual std::vector ListFunctionNames() = 0; + + //===--------------------------------------------------------------------===// + // Distributed runtime related functions. + //===--------------------------------------------------------------------===// +#if !defined(IS_MOBILE_PLATFORM) + // Set a distributed manager that helps set up, update, and check liveness + // of member tasks in the cluster. + virtual void SetDistributedManager( + std::unique_ptr distributed) = 0; + + virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0; +#endif // !IS_MOBILE_PLATFORM protected: explicit ImmediateExecutionContext(AbstractContextKind kind) diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h new file mode 100644 index 00000000000000..b43649a59663d7 --- /dev/null +++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_ +#define TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_ + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +class ImmediateExecutionContext; +class ServerDef; + +class ImmediateExecutionDistributedManager { + public: + virtual ~ImmediateExecutionDistributedManager() {} + + // Set up distributed execution environment on local and remote tasks. + // When `reset_context` is true, initialize new cluster context state based on + // cluster configurations provided in `server_def`; otherwise, update existing + // context state with the provided `server_def`. + // Contexts created on remote tasks will be considered stale and garbage + // collected after `keep_alive_secs` of inactivity. + virtual Status SetOrUpdateServerDef(const ServerDef& server_def, + bool reset_context, + int keep_alive_secs) = 0; + + // Set up a multi-client distributed execution environment. Must be called on + // all tasks in the cluster. + // This call internally coordinates with other tasks to initialize the eager + // context and TF server for multi-client execution. + virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0; + + // Check if the remote task is alive. + virtual Status CheckRemoteAlive(const std::string& remote_task_name, + bool* is_alive) = 0; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_immediate_execution_distributed_manager_H_ diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index 7b68ec2c9f4a0b..5c944837f53dfb 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -27,12 +27,16 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/util/abstract_stack_trace.h" +#include "tensorflow/core/util/managed_stack_trace.h" struct TFE_Op; namespace tensorflow { +class ImmediateExecutionContext; +class AbstractOpAttrs; +class CancellationManager; + // Abstract interface to an operation. class ImmediateExecutionOperation : public AbstractOperation { public: @@ -41,6 +45,15 @@ class ImmediateExecutionOperation : public AbstractOperation { // Returns the inputs of this op. virtual absl::Span GetInputs() const = 0; + virtual Status SetInput(size_t index, + ImmediateExecutionTensorHandle* input) = 0; + + virtual ImmediateExecutionContext* GetContext() const = 0; + + // Following two methods are used to support custom device. + // Return true if the inputs contain custom device tensor handle. It means + // that the argument need to be handled by a custom device. + virtual bool HasCustomDeviceInput() const = 0; virtual const tensorflow::OpDef* OpDef() const = 0; @@ -48,10 +61,16 @@ class ImmediateExecutionOperation : public AbstractOperation { virtual Status OutputLength(const char* output_name, int* length) = 0; // Set stack trace to be used for potential async error reporting. - virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; + virtual void SetStackTrace(ManagedStackTrace stack_trace) = 0; + + virtual const tensorflow::AbstractOpAttrs* GetOpAttrs() const = 0; + virtual void AddAttrs(const AbstractOpAttrs* op_attrs) = 0; + + virtual void SetCancellationManager( + CancellationManager* cancellation_manager) = 0; // Returns the stack trace set by `SetStackTrace` if exists. - virtual absl::optional GetStackTrace() = 0; + virtual absl::optional GetStackTrace() = 0; // For LLVM style RTTI. 
static bool classof(const AbstractOperation* ptr) { diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.cc b/tensorflow/c/eager/immediate_execution_tensor_handle.cc new file mode 100644 index 00000000000000..816c92c2e19cd9 --- /dev/null +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.cc @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" + +namespace tensorflow { + +std::string ImmediateExecutionTensorHandle::DebugString() const { + PartialTensorShape shape; + std::string shape_string; + if (Shape(&shape).ok()) { + shape_string = shape.DebugString(); + } else { + shape_string = ""; + } + std::string value_string; + if (!SummarizeValue(value_string).ok()) { + value_string = ""; + } + return absl::StrCat("TensorHandle(", value_string, ", shape=", shape_string, + ", dtype=", DataType_Name(DataType()), ")"); +} + +Status ImmediateExecutionTensorHandle::SummarizeValue( + std::string& summary) const { + Status status; + AbstractTensorPtr resolved( + // TODO(allenl): Resolve should be const, and the caches that get updated + // marked mutable. 
+ const_cast(this)->Resolve(&status)); + if (!status.ok()) { + return status; + } + summary = resolved->SummarizeValue(); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index bb6d471f12f12b..cca5d59b8179c0 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -54,6 +54,25 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { // Return a copy of the handle. virtual ImmediateExecutionTensorHandle* Copy() = 0; + std::string DebugString() const override; + + // Returns a Boolean hint indicating whether callers should prefer + // `SummarizeValue` to resolving this handle and formatting the tensor. + // + // For example some tensor handles may represent distributed values, in which + // case placement information is lost when resolving the handle. + // + // If false, a caller might implement pretty-printing by resolving and + // iterating over the resulting tensor. This may still be viable if resolving + // the handle loses information, but `SummarizeValue` would be more precise. + virtual bool HasCustomSummarizer() const { return false; } + + // Returns a string which summarizes the value of this TensorHandle, for + // debugging. Does not include a shape or dtype. + // + // Included in the default implementation of DebugString. + virtual Status SummarizeValue(std::string& summary) const; + // Release any underlying resources, including the interface object. // // WARNING: The destructor of this class is marked as protected to disallow diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc deleted file mode 100644 index 16cb01110fd6ec..00000000000000 --- a/tensorflow/c/eager/mnist_gradients_test.cc +++ /dev/null @@ -1,729 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
-Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/eager/gradients_util.h" -#include "tensorflow/c/eager/mnist_gradients_testutil.h" -#include "tensorflow/c/experimental/gradients/math_grad.h" -#include "tensorflow/c/experimental/gradients/nn_grad.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/tensor_float_32_utils.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace gradients { -namespace internal { -namespace { -using tensorflow::TF_StatusPtr; - -class CppGradients - : public ::testing::TestWithParam> { - protected: - void SetUp() override { - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = StatusFromTF_Status(status.get()); - CHECK_EQ(errors::OK, s.code()) << s.error_message(); - - // Computing 
numerical gradients with TensorFloat-32 is numerically - // unstable. Some forward pass tests also fail with TensorFloat-32 due to - // low tolerances - enable_tensor_float_32_execution(false); - } -}; - -Status RegisterGradients(GradientRegistry* registry) { - TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); - TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); - TF_RETURN_IF_ERROR( - registry->Register("SparseSoftmaxCrossEntropyWithLogits", - SparseSoftmaxCrossEntropyWithLogitsRegisterer)); - return Status::OK(); -} - -TEST_P(CppGradients, TestMatMulGrad) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t A_dims[] = {2, 2}; - float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; - int64_t B_dims[] = {2, 2}; - int num_dims = 2; - - AbstractTensorHandlePtr A = - GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); - AbstractTensorHandlePtr B = - GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims); - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - /* Pseudo-code: - * - * tape.watch(A) - * tape.watch(B) - * Y = AB - * outputs = tape.gradient(Y, [A, B]) - */ - - std::vector outputs(2); - s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* dA_tensor; - s = GetValue(outputs[0], &dA_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[4] 
= {0}; - memcpy(&result_data[0], TF_TensorData(dA_tensor), - TF_TensorByteSize(dA_tensor)); - - float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; - float tolerance = 1e-3; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data[j], expected_dA[j], tolerance); - } - - TF_Tensor* dB_tensor; - s = GetValue(outputs[1], &dB_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - memcpy(&result_data[0], TF_TensorData(dB_tensor), - TF_TensorByteSize(dB_tensor)); - - float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f}; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data[j], expected_dB[j], tolerance); - } - - outputs[0]->Unref(); - outputs[1]->Unref(); - TF_DeleteTensor(dA_tensor); - TF_DeleteTensor(dB_tensor); -} - -TEST_P(CppGradients, TestMNISTForward) { - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = data - float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t dims[] = {2, 2}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims); - - // W1 = first weights - float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; - AbstractTensorHandlePtr W1 = - GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); - - // W2 = second weights - float W2_vals[] = {.1f, .2f, .3f, -.5f}; - AbstractTensorHandlePtr W2 = - GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); - - // y = labels - int y_vals[] = {1, 1}; - int64_t dims_y[] = {2}; - num_dims = sizeof(dims_y) / sizeof(dims_y[0]); - AbstractTensorHandlePtr y = - GetTensorHandleUtilInt(ctx.get(), y_vals, dims, num_dims); - - GradientRegistry registry; - - // Run the Forward Pass - std::vector outputs(2); - Status s = - RunModel(MNISTForwardModel, ctx.get(), - {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), - 
/*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Verify the Results - TF_Tensor* scores_tensor; - s = GetValue(outputs[0], &scores_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[4] = {0}; - memcpy(&result_data[0], TF_TensorData(scores_tensor), - TF_TensorByteSize(scores_tensor)); - - float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f}; - float tolerance = 1e-3; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); - } - - TF_Tensor* loss_vals_tensor; - s = GetValue(outputs[1], &loss_vals_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - memcpy(&result_data[0], TF_TensorData(loss_vals_tensor), - TF_TensorByteSize(loss_vals_tensor)); - float expected_losses[2] = {9.6f, 27.2f}; - for (int j = 0; j < 2; j++) { - ASSERT_NEAR(result_data[j], expected_losses[j], tolerance); - } - - outputs[0]->Unref(); - outputs[1]->Unref(); - TF_DeleteTensor(scores_tensor); - TF_DeleteTensor(loss_vals_tensor); -} - -TEST_P(CppGradients, TestMNISTForward2) { - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = data - float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - int64_t X_dims[] = {3, 2}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); - - // W1 = first weights - float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; - int64_t dims[] = {2, 2}; - AbstractTensorHandlePtr W1 = - GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); - - // W2 = second weights - float W2_vals[] = {.1f, .2f, .3f, -.5f}; - AbstractTensorHandlePtr W2 = - GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); - - // y = labels - int y_vals[] = {1, 1, 1}; - int64_t y_dims[] = 
{3}; - num_dims = sizeof(y_dims) / sizeof(y_dims[0]); - AbstractTensorHandlePtr y = - GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); - - GradientRegistry registry; - - // Run the Forward Pass - std::vector outputs(2); - Status s = - RunModel(MNISTForwardModel, ctx.get(), - {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Verify the Results - TF_Tensor* scores_tensor; - s = GetValue(outputs[0], &scores_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[6] = {0}; - memcpy(&result_data[0], TF_TensorData(scores_tensor), - TF_TensorByteSize(scores_tensor)); - - float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f}; - float tolerance = 1e-3; - for (int j = 0; j < 6; j++) { - ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); - } - - TF_Tensor* loss_vals_tensor; - s = GetValue(outputs[1], &loss_vals_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - memcpy(&result_data[0], TF_TensorData(loss_vals_tensor), - TF_TensorByteSize(loss_vals_tensor)); - float expected_losses[3] = {9.6f, 27.2f, 44.8f}; - for (int j = 0; j < 3; j++) { - ASSERT_NEAR(result_data[j], expected_losses[j], tolerance); - } - - outputs[0]->Unref(); - outputs[1]->Unref(); - TF_DeleteTensor(scores_tensor); - TF_DeleteTensor(loss_vals_tensor); -} - -TEST_P(CppGradients, TestMatMulTranspose) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = data - float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - int64_t X_dims[] = {2, 3}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, 
X_dims, num_dims); - - // W1 = first weights - float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t dims[] = {2, 2}; - AbstractTensorHandlePtr W1 = - GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); - - GradientRegistry registry; - - // Run the MatMul Op - std::vector outputs(1); - - Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Verify the Results - TF_Tensor* scores_tensor; - s = GetValue(outputs[0], &scores_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[6] = {0}; - memcpy(&result_data[0], TF_TensorData(scores_tensor), - TF_TensorByteSize(scores_tensor)); - - float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f}; - float tolerance = 1e-3; - for (int j = 0; j < 6; j++) { - ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); - } -} - -TEST_P(CppGradients, TestReluGrad) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = data - float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; - int64_t X_dims[] = {3, 3}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - /* Pseudo-code: - * - * tape.watch(X) - * Y = Relu(X) - * outputs = tape.gradient(Y, [X]) - */ - std::vector outputs(1); - s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << 
s.error_message(); - - TF_Tensor* dX_tensor; - s = GetValue(outputs[0], &dX_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[9] = {0}; - memcpy(&result_data[0], TF_TensorData(dX_tensor), - TF_TensorByteSize(dX_tensor)); - - float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; - float tolerance = 1e-3; - for (int j = 0; j < 9; j++) { - ASSERT_NEAR(result_data[j], expected_dX[j], tolerance); - } - - outputs[0]->Unref(); - TF_DeleteTensor(dX_tensor); -} - -TEST_P(CppGradients, TestSoftmaxLossGrad) { - bool use_function = !std::get<2>(GetParam()); - if (use_function) { - // TODO(b/168850692): Enable this. - GTEST_SKIP() << "Can't take gradient of " - "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; - } - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = scores - float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; - int64_t X_dims[] = {3, 3}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); - - // y = labels - int y_vals[] = {1, 0, 1}; - int64_t y_dims[] = {3}; - num_dims = sizeof(y_dims) / sizeof(y_dims[0]); - AbstractTensorHandlePtr y = - GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - /* Pseudo-code: - * - * tape.watch(X) - * tape.watch(labels) - * loss = SoftmaxLoss(X, labels) - * outputs = tape.gradient(loss, [X, labels]) - * - * - */ - - std::vector outputs(2); - s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - - 
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* dX_tensor; - s = GetValue(outputs[0], &dX_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[9] = {0}; - memcpy(&result_data[0], TF_TensorData(dX_tensor), - TF_TensorByteSize(dX_tensor)); - - float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f, - 0.6652f, 0.8437f, -0.8858f, 0.0420f}; - float tolerance = 1e-3; - for (int j = 0; j < 9; j++) { - ASSERT_NEAR(result_data[j], expected_dX[j], tolerance); - } - - // Only Unref() first output as 2nd is nullptr grad for labels - outputs[0]->Unref(); - TF_DeleteTensor(dX_tensor); -} - -TEST_P(CppGradients, TestMNISTGrad) { - bool use_function = !std::get<2>(GetParam()); - if (use_function) { - // TODO(b/168850692): Enable this. - GTEST_SKIP() << "Can't take gradient of " - "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; - } - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = data - float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t X_dims[] = {2, 2}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); - - // W1 = first weights - float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; - int64_t dims[] = {2, 2}; - AbstractTensorHandlePtr W1 = - GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); - - // W2 = second weights - float W2_vals[] = {.1f, .2f, .3f, -.5f}; - AbstractTensorHandlePtr W2 = - GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); - - // y = labels - int y_vals[] = {1, 1}; - int64_t y_dims[] = {2}; - num_dims = sizeof(y_dims) / sizeof(y_dims[0]); - AbstractTensorHandlePtr y = - GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); - - // Register Grads - 
GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - /* Pseudo-code: - * - * - * tape.watch(W1) - * tape.watch(W2) - * mm = X*W1 - * hidden = Relu(mm) - * scores = W2*hidden - * loss = SoftmaxLoss(scores, y) - * outputs = tape.gradient(loss, [A, B]) - * - */ - - std::vector outputs(3); - s = RunModel(MNISTGradModel, ctx.get(), - {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float tolerance = 1e-3; - TF_Tensor* dW1_tensor; - s = GetValue(outputs[0], &dW1_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[4] = {0}; - memcpy(&result_data[0], TF_TensorData(dW1_tensor), - TF_TensorByteSize(dW1_tensor)); - - float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f}; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance); - } - - TF_Tensor* dW2_tensor; - s = GetValue(outputs[1], &dW2_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - memcpy(&result_data[0], TF_TensorData(dW2_tensor), - TF_TensorByteSize(dW2_tensor)); - - float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f}; // dLoss - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance); - } - - outputs[0]->Unref(); - outputs[1]->Unref(); - outputs[2]->Unref(); - TF_DeleteTensor(dW1_tensor); - TF_DeleteTensor(dW2_tensor); -} - -TEST_P(CppGradients, TestScalarMul) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - AbstractTensorHandlePtr eta; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw); - 
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - eta.reset(x_raw); - } - - float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t A_dims[] = {2, 2}; - int num_dims = 2; - - AbstractTensorHandlePtr A = - GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); - - GradientRegistry registry; - std::vector outputs(1); - Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* dA_tensor; - s = GetValue(outputs[0], &dA_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data[4] = {0}; - memcpy(&result_data[0], TF_TensorData(dA_tensor), - TF_TensorByteSize(dA_tensor)); - - float tolerance = 1e-3; - float eta_val = 1.5f; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance); - } - - outputs[0]->Unref(); - TF_DeleteTensor(dA_tensor); -} - -TEST_P(CppGradients, TestMNIST_Training) { - bool use_function = !std::get<2>(GetParam()); - if (use_function) { - // TODO(b/168850692): Enable this. - GTEST_SKIP() << "Can't take gradient of " - "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; - } - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } - - // X = data - float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; - int64_t X_dims[] = {2, 2}; - int num_dims = 2; - AbstractTensorHandlePtr X = - GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); - - // TODO(amturati): use random initializer for weights instead of - // constant values. 
- - // W1 = first weights - float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f}; - int64_t dims[] = {2, 2}; - AbstractTensorHandlePtr W1 = - GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); - - // W2 = second weights - float W2_vals[] = {.1f, .2f, .3f, -.5f}; - AbstractTensorHandlePtr W2 = - GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); - - // y = labels - int y_vals[] = {1, 1}; - int64_t y_dims[] = {2}; - num_dims = sizeof(y_dims) / sizeof(y_dims[0]); - AbstractTensorHandlePtr y = - GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); - - // Register Grads - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Prepare for training - std::vector weights; - weights.push_back(W1.get()); - weights.push_back(W2.get()); - - // Set learning rate to be 1e-1 - AbstractTensorHandle* learning_rate = nullptr; - s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Train - int num_iters = 10; - std::vector mnist_outputs(3); - std::vector grads(2); - for (int i = 0; i < num_iters; i++) { - // Run Forward Pass - s = RunModel(MNISTGradModel, ctx.get(), - {X.get(), weights[0], weights[1], y.get()}, - absl::MakeSpan(mnist_outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - // Fill grads - grads[0] = mnist_outputs[0]; - grads[1] = mnist_outputs[1]; - - // Gradient Update - s = UpdateWeights(ctx.get(), grads, weights, learning_rate); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - } - - grads[0]->Unref(); // release W1_grad - grads[1]->Unref(); // release W2_grad - mnist_outputs[2]->Unref(); // release loss -} - -#ifdef PLATFORM_GOOGLE -INSTANTIATE_TEST_SUITE_P( - UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef", "mlir"), - /*tfrt*/ ::testing::Values(false), - /*executing_eagerly*/ ::testing::Values(true, 
false))); -#else -INSTANTIATE_TEST_SUITE_P( - UnifiedCAPI, CppGradients, - ::testing::Combine(::testing::Values("graphdef", "mlir"), - /*tfrt*/ ::testing::Values(false), - /*executing_eagerly*/ ::testing::Values(true, false))); -#endif -} // namespace -} // namespace internal -} // namespace gradients -} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc deleted file mode 100644 index 6688d9d4e75fee..00000000000000 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ /dev/null @@ -1,415 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "tensorflow/c/eager/mnist_gradients_testutil.h" - -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/eager/gradients_util.h" -#include "tensorflow/c/experimental/gradients/tape/tape_context.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/experimental/ops/math_ops.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" - - -namespace tensorflow { -namespace gradients { -namespace internal { - -using std::vector; - -//===================== Test Models to run ========================= - -// Computes -// y = inputs[0] + inputs[1] -// return grad(y, {inputs[0], inputs[1]}) -Status AddGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. 
- std::vector add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR( - ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add")); - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto add_output : add_outputs) { - add_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - delete tape; - return Status::OK(); -} - -// Computes -// y = inputs[0] * inputs[1] -// return grad(y, {inputs[0], inputs[1]}) -Status MatMulGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. - vector mm_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs, - absl::MakeSpan(mm_outputs), "matmul0", - /*transpose_a=*/false, - /*transpose_b=*/false)); // Compute x*y. 
- - std::unordered_map - source_tensors_that_are_targets; - - vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto mm_output : mm_outputs) { - mm_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - delete tape; - return Status::OK(); -} - -// Model to run 2-layer net -Status MNISTForwardModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - /** - * We will trace a 2-layer fully connected network for an MNIST model: - * - * def mnist_forward(X, W1, W2, y_labels): - * mm_out_1 = tf.matmul(X,W1) - * hidden_layer = tf.nn.relu(mm_out_1) - * scores = tf.matmul(hidden_layer,W2) - * softmax = - * tf.nn.sparse_softmax_cross_entropy_with_logits(scores, - * y_labels) - * return scores, softmax - * - * Use this convention for inputs: - * - * inputs = [X, W1, W2, y_labels] - * - */ - AbstractTensorHandle* X = inputs[0]; - AbstractTensorHandle* W1 = inputs[1]; - AbstractTensorHandle* W2 = inputs[2]; - AbstractTensorHandle* y_labels = inputs[3]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(W1)); // Watch W1. - tape->Watch(ToId(W2)); // Watch W2. 
- vector temp_outputs(1); - - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, - absl::MakeSpan(temp_outputs), "matmul0", - /*transpose_a=*/false, - /*transpose_b=*/false)); // Compute X*W1 - - TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]}, - absl::MakeSpan(temp_outputs), - "relu")); // Compute Relu(X*W1) - - TF_RETURN_IF_ERROR(ops::MatMul( - tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), - "matmul1", - /*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1) - - AbstractTensorHandle* scores = temp_outputs[0]; - - temp_outputs.resize(2); - TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( - tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs), - "softmax_loss")); // Compute Softmax(Scores,labels) - - AbstractTensorHandle* loss_vals = temp_outputs[0]; - - outputs[0] = scores; - outputs[1] = loss_vals; - delete tape; - return Status::OK(); -} - -Status MatMulTransposeModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractTensorHandle* X = inputs[0]; - AbstractTensorHandle* W1 = inputs[1]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(X)); - tape->Watch(ToId(W1)); - vector temp_outputs(1); - - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, - absl::MakeSpan(temp_outputs), "matmul0", - /*transpose_a=*/true, - /*transpose_b=*/false)); // Compute X*W1 - - outputs[0] = temp_outputs[0]; - - delete tape; - return Status::OK(); -} - -Status ReluGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch X - vector relu_outputs(1); - AbstractContextPtr tape_ctx(new 
TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs, - absl::MakeSpan(relu_outputs), - "relu0")); // Relu(X) - - std::unordered_map - source_tensors_that_are_targets; - - vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - - for (auto relu_output : relu_outputs) { - relu_output->Unref(); - } - - outputs[0] = out_grads[0]; - delete tape; - return Status::OK(); -} - -Status SoftmaxLossGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch scores. - tape->Watch(ToId(inputs[1])); // Watch labels. - vector sm_outputs(2); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( - tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0")); - - std::unordered_map - source_tensors_that_are_targets; - - vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - delete tape; - return Status::OK(); -} - -Status MNISTGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractTensorHandle* X = inputs[0]; - AbstractTensorHandle* W1 = inputs[1]; - AbstractTensorHandle* W2 = inputs[2]; - AbstractTensorHandle* y_labels = inputs[3]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/true); - tape->Watch(ToId(X)); 
// Watch X. - tape->Watch(ToId(W1)); // Watch W1. - tape->Watch(ToId(W2)); // Watch W1. - vector temp_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, - absl::MakeSpan(temp_outputs), "matmul0", - /*transpose_a=*/false, - /*transpose_b=*/false)); // Compute X*W1 - - AbstractTensorHandle* mm = temp_outputs[0]; - - TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm}, - absl::MakeSpan(temp_outputs), // Relu(X*W1) - "relu0")); - - AbstractTensorHandle* hidden = temp_outputs[0]; - - TF_RETURN_IF_ERROR(ops::MatMul( - tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1", - /*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1) - - AbstractTensorHandle* scores = temp_outputs[0]; - - temp_outputs.resize(2); - TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( - tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs), - "softmaxloss")); // W2*Relu(X*W1) - - AbstractTensorHandle* loss = temp_outputs[0]; - - std::unordered_map - source_tensors_that_are_targets; - - vector out_grads; - TF_RETURN_IF_ERROR( - tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)}, - /*source_tensor_ids=*/{ToId(W1), ToId(W2)}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - - // Only release 2nd temp output as first holds loss values. 
- temp_outputs[1]->Unref(); - - outputs[0] = out_grads[0]; // dW1 - outputs[1] = out_grads[1]; // dW2 - outputs[2] = loss; - - delete tape; - return Status::OK(); -} - -Status ScalarMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractTensorHandle* eta = inputs[0]; - AbstractTensorHandle* A = inputs[1]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - vector temp_outputs(1); - - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A}, - absl::MakeSpan(temp_outputs), - "scalarMul0")); // Compute eta*A - - outputs[0] = temp_outputs[0]; - - delete tape; - return Status::OK(); -} - -Status MatMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractTensorHandle* X = inputs[0]; - AbstractTensorHandle* W1 = inputs[1]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - std::vector temp_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1}, - absl::MakeSpan(temp_outputs), "matmul0", - /*transpose_a=*/false, - /*transpose_b=*/false)); // Compute X*W1 - - outputs[0] = temp_outputs[0]; - delete tape; - return Status::OK(); -} - -Status MulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractTensorHandle* x = inputs[0]; - AbstractTensorHandle* y = inputs[1]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - std::vector temp_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y}, - absl::MakeSpan(temp_outputs), - "mul0")); // Compute x*y - - outputs[0] = temp_outputs[0]; - delete tape; - return Status::OK(); -} - -Status SoftmaxModel(AbstractContext* ctx, - absl::Span inputs, - 
absl::Span outputs, - const GradientRegistry& registry) { - AbstractTensorHandle* x = inputs[0]; - AbstractTensorHandle* labels = inputs[1]; - - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - std::vector temp_outputs(2); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( - tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss")); - - outputs[0] = temp_outputs[0]; // loss values - - delete tape; - return Status::OK(); -} - -// ============================= End Models ================================ - -} // namespace internal -} // namespace gradients -} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h deleted file mode 100644 index b173446ac9bb3e..00000000000000 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ -#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ -#include - -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/experimental/ops/array_ops.h" -#include "tensorflow/c/experimental/ops/math_ops.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/status.h" - - -namespace tensorflow { -namespace gradients { -namespace internal { - -// Computes -// y = inputs[0] + inputs[1] -// return grad(y, {inputs[0], inputs[1]}) -Status AddGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Computes -// y = inputs[0] * inputs[1] -// return grad(y, {inputs[0], inputs[1]}) -Status MatMulGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Computes 2-layer Neural Network with Softmax Loss. -Status MNISTForwardModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Computes MatMul with first matrix tranposed. 
-Status MatMulTransposeModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Test Model to verify ReluGrad functionality -Status ReluGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Test Model to verify SoftmaxGrad functionality -Status SoftmaxLossGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Test Model to verify Multi-grad functionality for MNIST -Status MNISTGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Test Model to verify scalar-tensor multiplication Op -Status ScalarMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -Status MatMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -Status MulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -Status SoftmaxModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -} // namespace internal -} // namespace gradients -} // namespace tensorflow - -#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 473ab503834701..62dd2f3bbd4480 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -59,6 +59,7 @@ cc_library( deps = [ ":parallel_device_lib", "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", "@com_google_absl//absl/strings", @@ -74,9 +75,14 @@ cc_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_internal", 
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:tfe_cancellation_manager_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", @@ -88,13 +94,17 @@ tf_cc_test( srcs = ["parallel_device_lib_test.cc"], deps = [ ":parallel_device_lib", + ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:tfe_context_internal", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime/eager:context", ], ) @@ -105,6 +115,7 @@ cc_library( hdrs = ["parallel_device_testlib.h"], deps = [ ":parallel_device", + ":parallel_device_lib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", "//tensorflow/c/eager:c_api", @@ -122,8 +133,11 @@ tf_cc_test( ":parallel_device_testlib", "//tensorflow/c:c_api", "//tensorflow/c:c_api_experimental", + "//tensorflow/c:tf_status_internal", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index 41bde23448bd35..182d18e2c1d2eb 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/eager/parallel_device/parallel_device.h" +#include #include #include "absl/strings/str_cat.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" namespace tensorflow { namespace parallel_device { @@ -177,13 +179,48 @@ absl::optional> ExecuteWithSpecialOps( return result; } -// Used as an argument to TFE_NewTensorHandleFromDeviceMemory, indicating how +// Used as an argument to TFE_NewCustomDeviceTensorHandle, indicating how // ParallelTensors wrapped in TFE_TensorHandles should be cleaned up once their // reference counts drop to zero. -void ParallelTensorDeallocator(void* data, size_t len, void* arg) { +void ParallelTensorDeallocator(void* data) { delete reinterpret_cast(data); } +// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing the +// number of dimensions of a parallel tensor. +int ParallelTensorNumDims(void* data, TF_Status* status) { + const std::vector* shape; + Status s = reinterpret_cast(data)->Shape(&shape); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return -1; + } + return shape->size(); +} + +// Used as an argument to TFE_NewCustomDeviceTensorHandle, for computing a +// dimension of a parallel tensor. 
+int64_t ParallelTensorDim(void* data, int dim_index, TF_Status* status) { + const std::vector* shape; + Status s = reinterpret_cast(data)->Shape(&shape); + if (!s.ok()) { + Set_TF_Status_from_Status(status, s); + return -1; + } + return (*shape)[dim_index]; +} + +TF_Buffer* ParallelTensorSummarize(void* data, TF_Status* status) { + ParallelTensor* parallel_tensor = reinterpret_cast(data); + std::string summary; + Status cpp_status = parallel_tensor->SummarizeValue(summary); + if (!cpp_status.ok()) { + Set_TF_Status_from_Status(status, cpp_status); + return nullptr; + } + return TF_NewBufferFromString(summary.data(), summary.size()); +} + TensorHandlePtr ParallelTensorToTensorHandle( const std::string& parallel_device_name, TFE_Context* context, std::unique_ptr t, TF_Status* status) { @@ -191,11 +228,14 @@ TensorHandlePtr ParallelTensorToTensorHandle( // for a ParallelDevice is really a ParallelTensor. When the TensorHandle is // deleted, it will call ParallelTensorDeallocator to free the struct. 
ParallelTensor* t_released = t.release(); - const std::vector& shape(t_released->shape()); - return TensorHandlePtr(TFE_NewTensorHandleFromDeviceMemory( - context, parallel_device_name.c_str(), t_released->dtype(), shape.data(), - shape.size(), t_released, 1, &ParallelTensorDeallocator, nullptr, - status)); + TFE_CustomDeviceTensorHandleMethods handle_methods; + handle_methods.num_dims = &ParallelTensorNumDims; + handle_methods.dim = &ParallelTensorDim; + handle_methods.deallocator = &ParallelTensorDeallocator; + handle_methods.summarize = &ParallelTensorSummarize; + return TensorHandlePtr(TFE_NewCustomDeviceTensorHandle( + context, parallel_device_name.c_str(), t_released->dtype(), t_released, + handle_methods, status)); } // For TFE_CustomDevice::copy_tensor_to_device in the parallel device diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 095f33ff303c6f..b3b56263f770bd 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -15,10 +15,14 @@ limitations under the License. #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" +#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace parallel_device { @@ -77,9 +81,15 @@ class DeviceThread { // Requests that the worker thread execute the specified operation. Blocks // until the previously pending operation (a StartExecute without a Join) has // finished, if any. 
+ // + // `cancellation_manager` must live until after `Join` finishes and pending + // `is_async` operations finish. In addition to allowing the caller to cancel + // the operation, its `StartCancel` method will be called if op execution + // fails on any device in order to cancel the others. void StartExecute(TFE_Context* context, const char* operation_name, std::vector inputs, - const TFE_OpAttrs* attributes, int expected_max_outputs); + const TFE_OpAttrs* attributes, int expected_max_outputs, + CancellationManager& cancellation_manager); // Block until the previous `StartExecute` operation has executed. Forwards // the status from `TFE_Execute` and returns outputs if the status is OK. std::vector Join(TF_Status* status); @@ -111,13 +121,16 @@ class DeviceThread { tensorflow::condition_variable finished_join_; // Temporary state between `StartExecute` and `Join`. - // Inputs + // + // Inputs; pointers are to objects not owned by the DeviceThread, but which + // are expected to live at least until `Join` finishes: TFE_Context* context_ TF_GUARDED_BY(execution_mutex_); const char* operation_name_ TF_GUARDED_BY(execution_mutex_); std::vector op_inputs_ TF_GUARDED_BY(execution_mutex_); const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_); int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_); - // Outputs + CancellationManager* cancellation_manager_ TF_GUARDED_BY(execution_mutex_); + // Outputs: std::vector op_outputs_ TF_GUARDED_BY(execution_mutex_); // TF_Status is an incomplete type and so can't be stack allocated. 
To avoid // unnecessary allocations each Execute call, we keep one heap-allocated @@ -164,7 +177,8 @@ void DeviceThread::StartExecute(TFE_Context* context, const char* operation_name, std::vector inputs, const TFE_OpAttrs* attributes, - int expected_max_outputs) { + int expected_max_outputs, + CancellationManager& cancellation_manager) { { tensorflow::mutex_lock l(execution_mutex_); while (execution_state_ != ExecutionState::kIdle) { @@ -177,6 +191,7 @@ void DeviceThread::StartExecute(TFE_Context* context, op_inputs_ = inputs; attributes_ = attributes; expected_max_outputs_ = expected_max_outputs; + cancellation_manager_ = &cancellation_manager; execution_state_ = ExecutionState::kReadyToExecute; } start_execute_.notify_one(); @@ -196,6 +211,7 @@ std::vector DeviceThread::Join(TF_Status* status) { // the bad `status`) start with an OK status. TF_SetStatus(status_.get(), TF_OK, ""); } + cancellation_manager_ = nullptr; execution_state_ = ExecutionState::kIdle; result = std::move(op_outputs_); } @@ -226,9 +242,13 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name, } std::vector unwrapped_results(expected_max_outputs); int real_num_outputs = expected_max_outputs; + TFE_OpSetCancellationManager(op_.get(), wrap(cancellation_manager_), status); if (TF_GetCode(status) != TF_OK) return; TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status); - if (TF_GetCode(status) != TF_OK) return; + if (TF_GetCode(status) != TF_OK) { + cancellation_manager_->StartCancel(); + return; + } unwrapped_results.resize(real_num_outputs); outputs->reserve(real_num_outputs); for (TFE_TensorHandle* unwrapped_result : unwrapped_results) { @@ -238,7 +258,8 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name, ParallelDevice::ParallelDevice(const std::vector& devices, const bool is_async) - : underlying_devices_(devices) { + : underlying_devices_(devices), + default_cancellation_manager_(absl::make_unique()) { 
device_threads_.reserve(devices.size()); for (int device_index = 0; device_index < devices.size(); ++device_index) { device_threads_.emplace_back( @@ -263,55 +284,6 @@ std::unique_ptr ParallelDevice::CopyToParallelDevice( status); } -std::unique_ptr ParallelDevice::Vector( - TFE_Context* context, TF_Status* status, - absl::Span values) const { - // TODO(allenl): We could cache DeviceIDs (keyed by context). - std::vector components; - components.reserve(underlying_devices_.size()); - - if (values.size() != num_underlying_devices()) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - "Number of values did not match number of underlying devices."); - return nullptr; - } - - for (int device_index = 0; device_index < num_underlying_devices(); - ++device_index) { - int32_t* device_value = new int32_t; - *device_value = values[device_index]; - std::unique_ptr tensor( - TF_NewTensor( - TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value, - sizeof(int32_t), - [](void* data, size_t, void* arg) { - delete reinterpret_cast(data); - }, - nullptr), - TF_DeleteTensor); - // TODO(allenl): Here and when executing regular operations, we could hold - // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing - // device names repeatedly. 
- OpPtr const_op(TFE_NewOp(context, "Const", status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), - status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32); - TFE_TensorHandle* device_handle; - int num_outputs = 1; - TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - components.emplace_back(device_handle); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - return ParallelTensor::FromTensorHandles(*this, std::move(components), - status); -} - std::unique_ptr ParallelDevice::DeviceIDs( TFE_Context* context, TF_Status* status) const { std::vector ids; @@ -319,7 +291,7 @@ std::unique_ptr ParallelDevice::DeviceIDs( for (int i = 0; i < num_underlying_devices(); ++i) { ids.push_back(i); } - return Vector(context, status, ids); + return ScalarsFromSequence(ids, context, status); } absl::optional>> @@ -328,11 +300,28 @@ ParallelDevice::Execute(TFE_Context* context, const char* operation_name, const TFE_OpAttrs* attributes, int expected_max_outputs, TF_Status* status) const { - absl::optional>> result; - // Compute per-device per-output tensors - std::vector> per_device_output_tensors; - per_device_output_tensors.reserve(underlying_devices_.size()); - int first_op_output_count = 0; + std::vector expected_output_shapes(expected_max_outputs); + StartExecute(context, inputs, operation_name, attributes, + expected_max_outputs, *default_cancellation_manager_); + auto result = Join(expected_output_shapes, status); + if (TF_GetCode(status) != TF_OK) { + std::unique_ptr await_status( + TF_NewStatus(), TF_DeleteStatus); + // Wait until all pending nodes have completed since they may have a + // reference to default_cancellation_manager_. 
We ignore the status return + // since we already have a bad status to propagate. + TFE_ContextAsyncWait(context, await_status.get()); + // Reset the cancellation manager on a bad status. Otherwise we'll cancel + // all future operations. + default_cancellation_manager_ = absl::make_unique(); + } + return result; +} + +void ParallelDevice::StartExecute( + TFE_Context* context, const std::vector& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, CancellationManager& cancellation_manager) const { for (int device_index = 0; device_index < underlying_devices_.size(); ++device_index) { DeviceThread* device_thread = device_threads_[device_index].get(); @@ -344,8 +333,19 @@ ParallelDevice::Execute(TFE_Context* context, } device_thread->StartExecute(context, operation_name, std::move(device_inputs), attributes, - expected_max_outputs); + expected_max_outputs, cancellation_manager); } +} + +absl::optional>> +ParallelDevice::Join( + const std::vector& expected_output_shapes, + TF_Status* status) const { + absl::optional>> result; + // Compute per-device per-output tensors + std::vector> per_device_output_tensors; + per_device_output_tensors.reserve(underlying_devices_.size()); + int first_op_output_count = 0; StatusPtr first_bad_status(nullptr); for (int device_index = 0; device_index < underlying_devices_.size(); ++device_index) { @@ -354,7 +354,11 @@ ParallelDevice::Execute(TFE_Context* context, // We will run every Join even if there are bad statuses in case the user // wants to recover and continue running ops on the parallel device (which // would otherwise deadlock). - if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) { + if (TF_GetCode(status) != TF_OK && + (first_bad_status == nullptr + // Prefer propagating non-cancellation related statuses to avoid + // shadowing the original failure. 
+ || TF_GetCode(first_bad_status.get()) == TF_CANCELLED)) { first_bad_status.reset(TF_NewStatus()); TF_SetStatus(first_bad_status.get(), TF_GetCode(status), TF_Message(status)); @@ -386,50 +390,126 @@ ParallelDevice::Execute(TFE_Context* context, for (int j = 0; j < underlying_devices_.size(); ++j) { components.push_back(std::move(per_device_output_tensors[j][i])); } - per_device_outputs.push_back(ParallelTensor::FromTensorHandles( - *this, std::move(components), status)); + if (expected_output_shapes[i].IsFullyDefined()) { + per_device_outputs.push_back(ParallelTensor::FromTensorHandles( + *this, std::move(components), + absl::Span(expected_output_shapes[i].dim_sizes()), + status)); + } else { + per_device_outputs.push_back(ParallelTensor::FromTensorHandles( + *this, std::move(components), status)); + } if (TF_GetCode(status) != TF_OK) return result; } result.emplace(std::move(per_device_outputs)); return result; } +std::vector ParallelDevice::SummarizeDeviceNames() const { + std::vector parsed_components( + underlying_devices_.size()); + for (int component_index = 0; component_index < underlying_devices_.size(); + ++component_index) { + if (!DeviceNameUtils::ParseFullName(underlying_devices_[component_index], + &parsed_components[component_index]) || + !DeviceNameUtils::IsSameAddressSpace( + underlying_devices_[component_index], underlying_devices_[0])) { + // Device names are from different address spaces, or we can't figure out + // whether they are, so we'll fully-qualify everything. 
+ return underlying_devices_; + } + } + std::vector local_names; + local_names.reserve(underlying_devices_.size()); + for (const DeviceNameUtils::ParsedName& parsed_component : + parsed_components) { + local_names.push_back( + absl::StrCat(parsed_component.type, ":", parsed_component.id)); + } + return local_names; +} + std::unique_ptr ParallelTensor::FromTensorHandles( const ParallelDevice& parallel_device, - std::vector components, TF_Status* status) { + std::vector components, absl::Span shape, + TF_Status* status) { TF_DataType dtype = TFE_TensorHandleDataType(components[0].get()); - std::vector shape( - TFE_TensorHandleNumDims(components[0].get(), status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - for (int i = 0; i < shape.size(); ++i) { - shape[i] = TFE_TensorHandleDim(components[0].get(), i, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - // Verify that the TensorHandle's shape and dtype match all of the component // shapes and dtypes. for (TensorHandlePtr& component : components) { - for (int i = 0; i < shape.size(); ++i) { - int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status); - if (TF_GetCode(status) != TF_OK) return nullptr; - if (tensor_dim != shape[i]) { - // TODO(allenl): Allow shapes to differ. 
- TF_SetStatus(status, TF_UNIMPLEMENTED, - "Components of a ParallelTensor must currently all have " - "the same shape"); - return nullptr; - } - if (TFE_TensorHandleDataType(component.get()) != dtype) { - TF_SetStatus(status, TF_INTERNAL, - "Components of a ParallelTensor must all have " - "the same dtype"); - return nullptr; + if (TFE_TensorHandleDataType(component.get()) != dtype) { + TF_SetStatus(status, TF_INTERNAL, + "Components of a ParallelTensor must all have " + "the same dtype"); + return nullptr; + } + } + return std::unique_ptr( + new ParallelTensor(parallel_device, std::move(components), shape, dtype)); +} + +std::unique_ptr ParallelTensor::FromTensorHandles( + const ParallelDevice& parallel_device, + std::vector components, TF_Status* status) { + TF_DataType dtype = TFE_TensorHandleDataType(components[0].get()); + // Verify that the combined TensorHandle's dtype matches all of the component + // dtypes. + for (TensorHandlePtr& component : components) { + if (TFE_TensorHandleDataType(component.get()) != dtype) { + TF_SetStatus(status, TF_INTERNAL, + "Components of a ParallelTensor must all have " + "the same dtype"); + return nullptr; + } + } + return std::unique_ptr( + new ParallelTensor(parallel_device, std::move(components), dtype)); +} + +Status ParallelTensor::Shape(const std::vector** shape) const { + if (!shape_.has_value()) { + TF_Status status; + PartialTensorShape first_shape; + TF_RETURN_IF_ERROR(unwrap(tensors_[0].get())->Shape(&first_shape)); + + // Verify that the TensorHandle's shape matches all of the component shapes. + for (const TensorHandlePtr& component : tensors_) { + PartialTensorShape component_shape; + TF_RETURN_IF_ERROR(unwrap(component.get())->Shape(&component_shape)); + if (!first_shape.IsIdenticalTo(component_shape)) { + return errors::Unimplemented(absl::StrCat( + "Computing the shape of a ParallelTensor when the components do " + "not all have the same shapes is not supported. 
One tensor had " + "shape ", + first_shape.DebugString(), " and another had shape ", + component_shape.DebugString())); } } + auto dim_sizes = first_shape.dim_sizes(); + shape_ = std::vector(dim_sizes.begin(), dim_sizes.end()); } + *shape = &*shape_; + return Status::OK(); +} - return std::unique_ptr(new ParallelTensor( - parallel_device, std::move(components), std::move(shape), dtype)); +Status ParallelTensor::SummarizeValue(std::string& summary) { + summary = "{"; + std::vector summarized_devices = device_.SummarizeDeviceNames(); + for (int component_index = 0; component_index < tensors_.size(); + ++component_index) { + // TODO(allenl): Add a C API for summarizing tensors. Currently custom + // devices limiting themselves to a C API (for ABI compatibility) would need + // to implement summarization for component tensors themselves. + ImmediateExecutionTensorHandle* component = + tensorflow::unwrap(tensors_[component_index].get()); + std::string component_summary; + TF_RETURN_IF_ERROR(component->SummarizeValue(component_summary)); + absl::StrAppend(&summary, component_index == 0 ? "" : ", ", "\"", + summarized_devices[component_index], + "\": ", component_summary); + } + summary += "}"; + return Status::OK(); } } // namespace parallel_device diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index 1bb9ce0f663955..0e2d07b9050685 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -26,6 +26,9 @@ limitations under the License. 
#include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace parallel_device { @@ -66,9 +69,10 @@ class ParallelDevice { TF_Status* status) const; // Construct a parallel tensor consisting of the scalar values from `values`. - std::unique_ptr Vector( - TFE_Context* context, TF_Status* status, - absl::Span values) const; + template + std::unique_ptr ScalarsFromSequence( + absl::Span values, TFE_Context* context, + TF_Status* status) const; // A parallel tensor with scalar integers numbering component devices. std::unique_ptr DeviceIDs(TFE_Context* context, @@ -93,6 +97,44 @@ class ParallelDevice { const char* operation_name, const TFE_OpAttrs* attributes, int expected_max_outputs, TF_Status* status) const; + // A non-blocking version of `Execute`. After each call, `Join` must be called + // before `StartExecute` is called again. Using `StartExecute` with `Join` + // allows the caller to schedule computation on multiple ParallelDevices + // without sequencing those operations (first call `StartExecute` on each + // parallel device, then call `Join` on each; even if some of the `Join`s + // return a bad status the caller must run all of the `Join`s or any future + // `StartExecute`s will deadlock). + // + // If `is_async=false` (constructor argument), `cancellation_manager` must + // live until `Join` finishes. If `is_async=true` it must live until `Join` is + // followed by `TFE_ContextAsyncWait` to clear pending operations. It will be + // used to cancel all other operations if any fails. 
+ void StartExecute(TFE_Context* context, + const std::vector& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, + CancellationManager& cancellation_manager) const; + + // Blocks until the previous `StartExecute` has run `TFE_Execute` on each + // device. If is_async=false (constructor argument) this means the ops have + // run and have results. If is_async=true it means that all of the + // device-specific executors have scheduled the op. + // + // Accepts inferred shapes for outputs (`expected_output_shapes`), which if + // fully defined will avoid querying the shapes of the underlying + // TensorHandles when ParallelTensor::Shape is called. This allows async + // computation to continue without blocking. + // + // The return status and value is the same as `Execute`. + absl::optional>> Join( + const std::vector& expected_output_shapes, + TF_Status* status) const; + + // Device strings for component devices that only include a + // worker/task/replica if any of those differ across components. Useful for + // printing debug messages. + std::vector SummarizeDeviceNames() const; + private: // A sequence of device names, indicating which devices replicated operations // are forwarded to. @@ -110,6 +152,10 @@ class ParallelDevice { // than a single list of threads so aliased nested parallel devices don't // re-use a thread. std::vector> device_threads_; + // A cancellation manager to use if the caller does not provide one. When ops + // are executed asynchronously this must outlive the queued op, so it can't be + // function-local to Execute. + mutable std::unique_ptr default_cancellation_manager_; }; // Contains a tuple of tensors, one on each of the `underlying_devices_` of the @@ -117,33 +163,108 @@ class ParallelDevice { class ParallelTensor { public: // Construct a ParallelTensor from TensorHandles placed on the component - // devices of a ParallelDevice. + // devices of a ParallelDevice. 
If called, ParallelTensor::Shape inspects + // `components` to determine a shape. static std::unique_ptr FromTensorHandles( const ParallelDevice& parallel_device, std::vector components, TF_Status* status); + // Uses the provided shape without additional checks, which avoids blocking + // when ParallelTensor::Shape is called. + static std::unique_ptr FromTensorHandles( + const ParallelDevice& parallel_device, + std::vector components, absl::Span shape, + TF_Status* status); size_t num_tensors() const { return tensors_.size(); } TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); } - // A generalization of the shapes of the underlying tensors. - const std::vector& shape() const { return shape_; } + // If the `shape` argument to `FromTensorHandles` is specified, returns that. + // + // Otherwise if all of the tensors have the same shape, returns that via the + // `shape` output argument. This blocks waiting for async tensors, may return + // a delayed bad status encountered during async execution, and will return a + // bad status unless all tensors have the same shape. + Status Shape(const std::vector** shape) const; TF_DataType dtype() const { return dtype_; } + // Sets its output argument to a summary of the values of this tensor on every + // component device. 
+ Status SummarizeValue(std::string& summary); + private: ParallelTensor(const ParallelDevice& device, std::vector tensors, - std::vector shape, const TF_DataType dtype) + absl::Span shape, const TF_DataType dtype) + : device_(device), + tensors_(std::move(tensors)), + shape_(std::vector(shape.begin(), shape.end())), + dtype_(dtype) {} + ParallelTensor(const ParallelDevice& device, + std::vector tensors, const TF_DataType dtype) : device_(device), tensors_(std::move(tensors)), - shape_(std::move(shape)), + shape_(absl::nullopt), dtype_(dtype) {} const ParallelDevice& device_; const std::vector tensors_; - const std::vector shape_; + // Parallel tensors are immutable but compute their shape lazily unless it is + // provided on construction. The optional has a value if the lazy computation + // has been completed or the shape was provided on construction. + mutable absl::optional> shape_; const TF_DataType dtype_; }; +template +std::unique_ptr ParallelDevice::ScalarsFromSequence( + absl::Span values, TFE_Context* context, + TF_Status* status) const { + std::vector components; + components.reserve(underlying_devices_.size()); + + if (values.size() != num_underlying_devices()) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + "Number of values did not match number of underlying devices."); + return nullptr; + } + TF_DataType datatype_enum( + static_cast(DataTypeToEnum().value)); + for (int device_index = 0; device_index < num_underlying_devices(); + ++device_index) { + auto device_value = absl::make_unique(); + *device_value = values[device_index]; + std::unique_ptr tensor( + TF_NewTensor( + datatype_enum, /*dims=*/nullptr, /*num_dims=*/0, + device_value.release(), sizeof(DataType), + [](void* data, size_t, void* arg) { + delete reinterpret_cast(data); + }, + nullptr), + TF_DeleteTensor); + // TODO(allenl): Here and when executing regular operations, we could hold + // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing + // device names repeatedly. 
+ std::unique_ptr const_op( + TFE_NewOp(context, "Const", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(), + status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum); + TFE_TensorHandle* device_handle; + int num_outputs = 1; + TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + components.emplace_back(device_handle); + } + return ParallelTensor::FromTensorHandles(*this, std::move(components), + status); +} + } // namespace parallel_device } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc index 35befe959cb1f8..fdc4ff16a6c1d1 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc @@ -19,11 +19,18 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace parallel_device { +using ::testing::HasSubstr; + TEST(PARALLEL_DEVICE_LIB, TestOpWithError) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -80,5 +87,240 @@ TEST(PARALLEL_DEVICE_LIB, TestOpWithError) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); } +TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::vector devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(std::move(devices)); + std::unique_ptr handle_op( + TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT); + TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + 
CancellationManager cancellation_manager; + parallel_device.StartExecute(context.get(), std::vector(), + "VarHandleOp", TFE_OpGetAttrs(handle_op.get()), + /*expected_max_outputs=*/1, + cancellation_manager); + auto outputs = parallel_device.Join( + /*expected_output_shapes=*/{PartialTensorShape({})}, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + const std::vector>& handles = *outputs; + const std::vector* shape; + Status s = handles[0]->Shape(&shape); + ASSERT_TRUE(s.ok()); + EXPECT_EQ(0, shape->size()); +} + +TEST(PARALLEL_DEVICE_LIB, TestCancelOnError) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::vector devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(devices); + const FunctionDef assert_and_collective = FunctionDefHelper::Define( + // Name + "AssertAndCollective", + // Args + {"x: float", "condition: bool"}, + // Return values + {"y: float"}, + // Attr def + {}, + // Nodes + { + {{"assert"}, + "Assert", + {"condition", "x"}, + {{"T", std::vector{DT_FLOAT}}}}, + {{"y"}, + "CollectiveReduce", + {"x"}, + {{"T", DT_FLOAT}, + {"group_size", static_cast(devices.size())}, + {"group_key", 0}, + {"instance_key", 0}, + {"merge_op", "Add"}, + {"final_op", "Id"}, + {"subdiv_offsets", std::vector()}}, + /*dep=*/{"assert"}}, + }); + TF_ASSERT_OK(ContextFromInterface(unwrap(context.get())) + 
->AddFunctionDef(assert_and_collective)); + + std::unique_ptr call_op( + TFE_NewOp(context.get(), "AssertAndCollective", status.get()), + TFE_DeleteOp); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + std::unique_ptr reduced_values = + parallel_device.ScalarsFromSequence({1.0, 2.0}, context.get(), + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + std::unique_ptr run_collective = + parallel_device.ScalarsFromSequence({true, true}, context.get(), + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + auto outputs = parallel_device.Execute( + context.get(), {reduced_values.get(), run_collective.get()}, + "AssertAndCollective", TFE_OpGetAttrs(call_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ASSERT_EQ(outputs->size(), 1); + ParallelTensor* parallel_result = (*outputs)[0].get(); + ExpectScalarEq(parallel_result->tensor(0), 3.); + ExpectScalarEq(parallel_result->tensor(1), 3.); + + run_collective = parallel_device.ScalarsFromSequence( + {true, false}, context.get(), status.get()); + parallel_device.Execute(context.get(), + {reduced_values.get(), run_collective.get()}, + "AssertAndCollective", TFE_OpGetAttrs(call_op.get()), + /*expected_max_outputs=*/1, status.get()); + EXPECT_NE(TF_GetCode(status.get()), TF_CANCELLED); + EXPECT_EQ(TF_GetCode(status.get()), TF_INVALID_ARGUMENT); + EXPECT_THAT(TF_Message(status.get()), HasSubstr("assertion failed")); + + // Note that future collectives with the same context do not work at the + // moment; once canceled, the collective executor requires the program to be + // restarted / context to be reset. 
+} + +TEST(PARALLEL_DEVICE_LIB, TestDifferentShapes) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::vector devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(std::move(devices)); + TensorHandlePtr two_vector = VectorFloatTensorHandle({3., 4.}, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TensorHandlePtr three_vector = + VectorFloatTensorHandle({5., 6., 7.}, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::vector vector_handles; + vector_handles.reserve(2); + vector_handles.push_back(std::move(two_vector)); + vector_handles.push_back(std::move(three_vector)); + std::unique_ptr unknown_length_vector = + ParallelTensor::FromTensorHandles( + parallel_device, std::move(vector_handles), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + const std::vector* shape; + Status s = unknown_length_vector->Shape(&shape); + EXPECT_FALSE(s.ok()); + + TensorHandlePtr scalar = FloatTensorHandle(2., status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + two_vector = VectorFloatTensorHandle({3., 4.}, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + std::vector mixed_handles; + mixed_handles.reserve(2); + mixed_handles.push_back(std::move(scalar)); + 
mixed_handles.push_back(std::move(two_vector)); + std::unique_ptr unknown_dims_vector = + ParallelTensor::FromTensorHandles(parallel_device, + std::move(mixed_handles), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + // Can't take the shape of a parallel tensor with varying numbers of axes, but + // running operations on them is OK. + s = unknown_length_vector->Shape(&shape); + EXPECT_FALSE(s.ok()); + std::unique_ptr size_op( + TFE_NewOp(context.get(), "Size", status.get()), TFE_DeleteOp); + auto result = parallel_device.Execute( + context.get(), {unknown_dims_vector.get()}, "Size", + TFE_OpGetAttrs(size_op.get()), 1, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + s = (*result)[0]->Shape(&shape); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + EXPECT_EQ(0, shape->size()); +} + +TEST(PARALLEL_DEVICE_LIB, TestScalarsFromSequence) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + + std::vector devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(std::move(devices)); + { + std::unique_ptr float_tensors = + parallel_device.ScalarsFromSequence({10.0, 11.0}, context.get(), + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ExpectScalarEq(float_tensors->tensor(0), 10.0); + ExpectScalarEq(float_tensors->tensor(1), 11.0); + } + + { + std::unique_ptr int_tensors = + parallel_device.ScalarsFromSequence({5, 
6}, context.get(), + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ExpectScalarEq(int_tensors->tensor(0), 5); + ExpectScalarEq(int_tensors->tensor(1), 6); + } +} + } // namespace parallel_device } // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc index 32a4b440d25963..41d6f14e06863f 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_remote_test.cc @@ -41,6 +41,9 @@ tensorflow::ServerDef GetServerDef(const std::string& job_name, int num_tasks) { return server_def; } +namespace tensorflow { +namespace parallel_device { + TEST(PARALLEL_DEVICE, TestRemoteBasic) { std::unique_ptr opts( TFE_NewContextOptions(), TFE_DeleteContextOptions); @@ -145,3 +148,5 @@ TEST(PARALLEL_DEVICE, TestAsyncCopyOff) { worker_server1.release(); worker_server2.release(); } +} // namespace parallel_device +} // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index 06a26ab2710092..dc97f89be113f4 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -21,7 +21,11 @@ limitations under the License. 
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" // NOTE(allenl): These tests currently go through TFE_Execute and so are @@ -29,6 +33,11 @@ limitations under the License. // correspond fairly well to the implementation, but testing the C++ directly is // another option. +namespace tensorflow { +namespace parallel_device { + +using ::testing::HasSubstr; + TEST(PARALLEL_DEVICE, TestBasicCPU) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -36,15 +45,14 @@ TEST(PARALLEL_DEVICE, TestBasicCPU) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr config( TF_CreateConfig( - /*xla*/ false, - /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 2), + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), TF_DeleteBuffer); TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, status.get()); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); BasicTestsForTwoDevices(context.get(), "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:1"); @@ -57,7 +65,7 @@ TEST(PARALLEL_DEVICE, TestBasicCPUAliased) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + 
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); BasicTestsForTwoDevices(context.get(), "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0"); @@ -70,18 +78,18 @@ TEST(PARALLEL_DEVICE, TestBasicTPUAliased) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Skip the test if no TPU is available. std::unique_ptr devices( TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); bool has_tpu = false; for (int device_index = 0; device_index < TF_DeviceListCount(devices.get()); ++device_index) { std::string device_type = TF_DeviceListType(devices.get(), device_index, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); if (device_type == "TPU") { has_tpu = true; break; @@ -101,15 +109,14 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr config( TF_CreateConfig( - /*xla*/ false, - /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 2), + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), TF_DeleteBuffer); TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, status.get()); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); const char* device_name = 
"/job:localhost/replica:0/task:0/device:CUSTOM:0"; const char* first_device_name = @@ -120,18 +127,18 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { second_device_name}; RegisterParallelDevice(context.get(), device_name, underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Copying on to a parallel device is OK. TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice( cpu_value.get(), context.get(), device_name, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); const char* backing_device = TFE_TensorHandleBackingDeviceName(device_value.get(), status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); ASSERT_EQ(std::string(device_name), backing_device); // Un-pack the parallel tensor to verify that the copy was successful. @@ -139,7 +146,7 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) { std::array components; ExtractPerDeviceValues(context.get(), device_value.get(), &components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // The value of the original tensor is replicated on each device. 
ExpectScalarEq(components[0].get(), 3.); @@ -167,15 +174,14 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr config( TF_CreateConfig( - /*xla*/ false, - /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 2), + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), TF_DeleteBuffer); TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, status.get()); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; std::array underlying_devices{ @@ -183,24 +189,26 @@ TEST(PARALLEL_DEVICE, TestDifferentShapes) { "/job:localhost/replica:0/task:0/device:CPU:1"}; RegisterParallelDevice(context.get(), device_name, underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Create two vectors with different lengths std::vector size_two_value{1., 2.}; std::vector size_three_value{1., 2., 3.}; TensorHandlePtr size_two( VectorFloatTensorHandle(size_two_value, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TensorHandlePtr size_three( VectorFloatTensorHandle(size_three_value, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Try to combine these values into a single parallel tensor. 
std::array components{size_two.get(), size_three.get()}; TensorHandlePtr combined_value = CreatePerDeviceValues( context.get(), components, device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED) - << TF_Message(status.get()); + // We can create the handle, but fetching the shape is an error at the moment. + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + TFE_TensorHandleNumDims(combined_value.get(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED); } TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { @@ -210,15 +218,14 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr config( TF_CreateConfig( - /*xla*/ false, - /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 3), + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/3), TF_DeleteBuffer); TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, status.get()); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Create a parallel device with two CPUs const char* first_device_name = @@ -228,7 +235,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { "/job:localhost/replica:0/task:0/device:CPU:1"}; RegisterParallelDevice(context.get(), first_device_name, first_underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Create a second parallel device with the first parallel device and one // additional CPU. 
@@ -239,32 +246,32 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { "/job:localhost/replica:0/task:0/device:CPU:2"}; RegisterParallelDevice(context.get(), second_device_name, second_underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Create a tensor on the first parallel device TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); TensorHandlePtr value_two(FloatTensorHandle(2., status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array components{value_one.get(), value_two.get()}; TensorHandlePtr first_combined_value = CreatePerDeviceValues( context.get(), components, first_device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Nest the first parallel tensor into a second TensorHandlePtr value_three(FloatTensorHandle(3., status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); components[0] = first_combined_value.get(); components[1] = value_three.get(); TensorHandlePtr second_combined_value = CreatePerDeviceValues( context.get(), components, second_device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TensorHandlePtr negative_one(FloatTensorHandle(3., status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TensorHandlePtr multiply_result(Multiply(context.get(), second_combined_value.get(), negative_one.get(), status.get())); - 
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Un-pack the parallel tensor to verify that the operation was // successful. The resulting structure should be: @@ -272,7 +279,7 @@ TEST(PARALLEL_DEVICE, TestNestedParallelDevices) { std::array second_components; ExtractPerDeviceValues(context.get(), multiply_result.get(), &second_components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); ExpectScalarEq(second_components[1].get(), 9.); @@ -311,14 +318,14 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) { "/job:localhost/replica:0/task:0/device:CPU:0"}; RegisterParallelDevice(context.get(), device_name, underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); TensorHandlePtr value_two(FloatTensorHandle(2., status.get())); { // Try to pack two TensorHandles onto a parallel device with a single // component. 
- ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array components{value_one.get(), value_two.get()}; TensorHandlePtr combined_value = CreatePerDeviceValues( @@ -332,7 +339,7 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) { std::array correct_components{value_one.get()}; TensorHandlePtr combined_value = CreatePerDeviceValues( context.get(), correct_components, device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array incorrect_components; ExtractPerDeviceValues(context.get(), combined_value.get(), @@ -346,7 +353,7 @@ TEST(PARALLEL_DEVICE, TestInvalidPacking) { std::array correct_components{value_one.get()}; TensorHandlePtr combined_value = CreatePerDeviceValues( context.get(), correct_components, device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array incorrect_components{combined_value.get()}; TensorHandlePtr recombined_value = CreatePerDeviceValues( @@ -415,15 +422,14 @@ void TestCollective(bool async) { TFE_ContextOptionsSetAsync(opts.get(), async); std::unique_ptr config( TF_CreateConfig( - /*xla*/ false, - /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 2), + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), TF_DeleteBuffer); TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, status.get()); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; std::array 
underlying_devices{ @@ -431,26 +437,26 @@ void TestCollective(bool async) { "/job:localhost/replica:0/task:0/device:CPU:1"}; RegisterParallelDevice(context.get(), device_name, underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Create a tensor on the parallel device TensorHandlePtr value_one(FloatTensorHandle(1., status.get())); TensorHandlePtr value_two(FloatTensorHandle(2., status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array components{value_one.get(), value_two.get()}; TensorHandlePtr parallel_value = CreatePerDeviceValues( context.get(), components, device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); // Run a collective sum, so each component should now be the same. 
TensorHandlePtr reduced( CollectiveSum(context.get(), parallel_value.get(), 2, status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array result_components; ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); ExpectScalarEq(result_components[0].get(), 3.); ExpectScalarEq(result_components[1].get(), 3.); } @@ -512,15 +518,14 @@ TEST(PARALLEL_DEVICE, TestFunction) { TFE_NewContextOptions(), TFE_DeleteContextOptions); std::unique_ptr config( TF_CreateConfig( - /*xla*/ false, - /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ - 2), + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), TF_DeleteBuffer); TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, status.get()); std::unique_ptr context( TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; std::array underlying_devices{ @@ -528,38 +533,38 @@ TEST(PARALLEL_DEVICE, TestFunction) { "/job:localhost/replica:0/task:0/device:CPU:1"}; RegisterParallelDevice(context.get(), device_name, underlying_devices, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); const char* function_name = "test_reduce_mul"; RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << 
TF_Message(status.get()); TensorHandlePtr value_one(FloatTensorHandle(7., status.get())); TensorHandlePtr value_two(FloatTensorHandle(9., status.get())); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::array components{value_one.get(), value_two.get()}; TensorHandlePtr parallel_value = CreatePerDeviceValues( context.get(), components, device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); std::unique_ptr op( TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TFE_OpSetDevice(op.get(), device_name, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TFE_OpAddInput(op.get(), parallel_value.get(), status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TFE_TensorHandle* raw_result_handle; int num_retvals = 1; TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); TensorHandlePtr reduced(raw_result_handle); std::array result_components; ExtractPerDeviceValues(context.get(), reduced.get(), &result_components, status.get()); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); ExpectScalarEq(result_components[0].get(), 7. * 9.); ExpectScalarEq(result_components[1].get(), 7. 
* 9.); @@ -570,3 +575,41 @@ TEST(PARALLEL_DEVICE, TestFunction) { result_components[1].get(), status.get()); ASSERT_EQ(underlying_devices[1], second_device); } + +TEST(PARALLEL_DEVICE, TestSummaryString) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*enable_xla_compilation=*/false, + /*gpu_memory_allow_growth=*/true, /*num_cpu_devices=*/2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + + const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0"; + std::array underlying_devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + RegisterParallelDevice(context.get(), device_name, underlying_devices, + status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get())); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice( + cpu_value.get(), context.get(), device_name, status.get())); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ImmediateExecutionTensorHandle* unwrapped_handle = + tensorflow::unwrap(device_value.get()); + std::string summarized; + TF_ASSERT_OK(unwrapped_handle->SummarizeValue(summarized)); + EXPECT_THAT(summarized, HasSubstr("\"CPU:0\": 3")); +} + +} // namespace parallel_device +} // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc index 67bc596b180f01..b8ab7fce3263b4 100644 --- 
a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc @@ -28,6 +28,8 @@ limitations under the License. // correspond fairly well to the implementation, but testing the C++ directly is // another option. +namespace tensorflow { +namespace parallel_device { Variable* Variable::Create(TFE_Context* context, TF_DataType type, const int64_t* dims, const int num_dims, @@ -280,3 +282,6 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, ASSERT_EQ(underlying_devices[1], second_device); } } + +} // namespace parallel_device +} // namespace tensorflow diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h index 3f917224187bb3..ecc96dd66ee366 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.h @@ -16,29 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ -#include "tensorflow/c/eager/parallel_device/parallel_device.h" - #include #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" #include "tensorflow/core/platform/test.h" - -// Functor for making unique_ptr to TFE_TensorHandle slightly more -// ergonomic. Using decltype(TFE_DeleteTensorHandle) in the unique_ptr's second -// template argument requires passing a function pointer to -// TFE_DeleteTensorHandle when constructing the unique_ptr. 
-class TensorHandleDeleter { - public: - void operator()(TFE_TensorHandle* to_delete) { - TFE_DeleteTensorHandle(to_delete); - } -}; - -using TensorHandlePtr = std::unique_ptr; +namespace tensorflow { +namespace parallel_device { // A helper for performing common operations on variables. A much more // restricted stand-in for tf.Variable in Python. @@ -151,11 +140,13 @@ template void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - std::unique_ptr value_zero( + std::unique_ptr actual_value( TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); - ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_TensorType(actual_value.get()), + static_cast(DataTypeToEnum().value)); EXPECT_EQ(expected_value, - *static_cast(TF_TensorData(value_zero.get()))); + *static_cast(TF_TensorData(actual_value.get()))); } template @@ -171,4 +162,7 @@ void RegisterParallelDevice( TFE_RegisterCustomDevice(context, device, device_name, device_info, status); } +} // namespace parallel_device +} // namespace tensorflow + #endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index efab4dfbeb2ebf..f096f609f94982 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -93,11 +93,14 @@ class VSpace { gtl::ArraySlice gradient_tensors) const = 0; // Calls the passed-in backward function. + // + // `unneeded_gradients` contains sorted list of input indices for which a + // gradient is not required. 
virtual Status CallBackwardFunction( - BackwardFunction* backward_function, + const string& op_type, BackwardFunction* backward_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, - std::vector* result) const = 0; + absl::Span result) const = 0; // Builds a tensor filled with ones with the same shape and dtype as `t`. virtual Status BuildOnesLike(const TapeTensor& t, @@ -133,11 +136,24 @@ class GradientTape { } } + // Returns whether any tensor in a list of tensors is being watched and has + // a trainable dtype. bool ShouldRecord(gtl::ArraySlice tensor_ids, - gtl::ArraySlice dtypes); + gtl::ArraySlice dtypes) const; + // Adds this tensor to the list of watched tensors. + // + // This is a no-op if the tensor is already being watched either from an + // earlier call to `GradientTape::Watch` or being an output of an op with + // watched inputs. void Watch(int64 tensor_id); + // Records an operation with inputs `input_tensor_id` and outputs + // `output_tensors` on the tape and marks all its outputs as watched if at + // least one input of the op is watched and has trainable dtype. + // + // op_type is used to decide which of the incoming gradients can be left as + // nullptr instead of building zeros when build_default_zeros_grads == true. void RecordOperation( const string& op_type, const std::vector& output_tensors, gtl::ArraySlice input_tensor_id, @@ -159,9 +175,10 @@ class GradientTape { const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, const std::unordered_map& sources_that_are_targets, - gtl::ArraySlice output_gradients, - std::vector* result, bool build_default_zeros_grads = true); + gtl::ArraySlice output_gradients, absl::Span result, + bool build_default_zeros_grads = true); + // Whether the tape is persistent. See ctor for detailed description. 
bool IsPersistent() const { return persistent_; } private: @@ -311,11 +328,10 @@ class ForwardAccumulator { // function is running; this effectively adds the backward tape to the active // set (but does not require complicated callbacks to the language bindings). Status ForwardpropFromTape( - const std::vector& output_tensors, + const string& op_type, const std::vector& output_tensors, const std::function& backward_function_getter, const std::function& backward_function_deleter, - const std::vector& in_grads, - std::vector* out_grads); + const std::vector& in_grads, absl::Span out_grads); // Maps from tensor IDs to corresponding JVPs. std::unordered_map accumulated_gradients_; @@ -368,7 +384,7 @@ inline bool IsDtypeTrainable(DataType dtype) { template bool GradientTape::ShouldRecord( gtl::ArraySlice tensor_ids, - gtl::ArraySlice dtypes) { + gtl::ArraySlice dtypes) const { CHECK_EQ(tensor_ids.size(), dtypes.size()); for (int i = 0; i < tensor_ids.size(); ++i) { if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { @@ -668,7 +684,7 @@ Status GradientTape::ComputeGradient( const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, const std::unordered_map& sources_that_are_targets, - gtl::ArraySlice output_gradients, std::vector* result, + gtl::ArraySlice output_gradients, absl::Span result, bool build_default_zeros_grads) { std::unordered_set sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); @@ -757,23 +773,17 @@ Status GradientTape::ComputeGradient( out_gradients.push_back(new_gradients); } } - std::vector in_gradients; + VLOG(1) << "Calling gradient function for '" << trace.op_type << "'"; + std::vector in_gradients(trace.input_tensor_id.size()); DCHECK(build_default_zeros_grads || zero_indices.empty()); if (any_gradient_nonzero) { for (const auto i : zero_indices) { out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); } Status s; - s = vspace.CallBackwardFunction(trace.backward_function, + s = 
vspace.CallBackwardFunction(trace.op_type, trace.backward_function, unneeded_gradients, out_gradients, - &in_gradients); - if (in_gradients.size() != trace.input_tensor_id.size()) { - return tensorflow::errors::Internal( - "Recorded operation '", trace.op_type, - "' returned too few gradients. Expected ", - trace.input_tensor_id.size(), " but received ", - in_gradients.size()); - } + absl::MakeSpan(in_gradients)); if (!persistent_) { trace.backward_function_deleter(trace.backward_function); } @@ -781,7 +791,6 @@ Status GradientTape::ComputeGradient( return s; } } else { - in_gradients.resize(trace.input_tensor_id.size()); if (!persistent_) { trace.backward_function_deleter(trace.backward_function); } @@ -791,8 +800,6 @@ Status GradientTape::ComputeGradient( } } } - VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " - << trace.input_tensor_id.size() << " sources"; for (int i = 0, end = in_gradients.size(); i < end; ++i) { const int64 id = trace.input_tensor_id[i]; if (in_gradients[i] != nullptr) { @@ -856,20 +863,25 @@ Status GradientTape::ComputeGradient( if (!state.op_tape.empty()) { return tensorflow::errors::Internal("Invalid tape state."); } - result->reserve(source_tensor_ids.size()); + if (result.size() != source_tensor_ids.size()) { + return errors::Internal("Expected result Span to be of size ", + source_tensor_ids.size(), " found ", result.size(), + " in call to Tape::ComputeGradient."); + } std::unordered_set used_gradient_ids(source_tensor_ids.size()); - for (auto is : source_tensor_ids) { - auto grad_it = gradients.find(is); + for (int i = 0; i < source_tensor_ids.size(); i++) { + int64 tensor_id = source_tensor_ids[i]; + auto grad_it = gradients.find(tensor_id); if (grad_it == gradients.end()) { - result->push_back(nullptr); + result[i] = nullptr; } else { if (grad_it->second.size() > 1) { Gradient* grad = vspace.AggregateGradients(grad_it->second); grad_it->second.clear(); grad_it->second.push_back(grad); } - 
result->push_back(grad_it->second[0]); - used_gradient_ids.insert(is); + result[i] = grad_it->second[0]; + used_gradient_ids.insert(tensor_id); } } VLOG(1) << "Final gradients size: " @@ -910,10 +922,10 @@ bool ForwardAccumulator::ShouldRecord( template Status ForwardAccumulator::ForwardpropFromTape( - const std::vector& output_tensors, + const string& op_type, const std::vector& output_tensors, const std::function& backward_function_getter, const std::function& backward_function_deleter, - const std::vector& in_grads, std::vector* out_grads) { + const std::vector& in_grads, absl::Span out_grads) { /* This function is approximately equivalent to this Python code: forwardprop_aids = tf.ones_like(output_tensors) @@ -957,7 +969,7 @@ ForwardAccumulator::ForwardpropFromTape( sources_set.insert(aid_id); tape->Watch(aid_id); } - std::vector grad; + std::vector grad(in_grads.size()); auto delete_grad = gtl::MakeCleanup([&grad, this] { for (Gradient* tensor : grad) { this->vspace_.DeleteGradient(tensor); @@ -969,16 +981,13 @@ ForwardAccumulator::ForwardpropFromTape( backward_function(backward_function_getter(), backward_function_deleter); TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction( - backward_function.get(), unneeded_gradients, forwardprop_aids, &grad)); + op_type, backward_function.get(), unneeded_gradients, forwardprop_aids, + absl::MakeSpan(grad))); } // Stop the tape from recording pop_backward_tape.release()(); - if (grad.size() != in_grads.size()) { - return tensorflow::errors::Internal("Wrong number of gradients returned."); - } - std::vector targets; std::vector used_in_grads; // We may end up with slightly fewer elements than we reserve, but grad.size() @@ -1076,9 +1085,10 @@ Status ForwardAccumulator::Accumulate( if (forward_function == nullptr) { // We have no special-cased forward gradient. Fall back to running the // backward function under a gradient tape. 
+ forward_grads.resize(output_tensors.size()); TF_RETURN_IF_ERROR(ForwardpropFromTape( - output_tensors, backward_function_getter, backward_function_deleter, - in_grads, &forward_grads)); + op_type, output_tensors, backward_function_getter, + backward_function_deleter, in_grads, absl::MakeSpan(forward_grads))); } else { TF_RETURN_IF_ERROR( (*forward_function)(in_grads, &forward_grads, use_batch_)); diff --git a/tensorflow/c/eager/tfe_cancellation_manager_internal.h b/tensorflow/c/eager/tfe_cancellation_manager_internal.h index 7d500c874e60a9..6fdecd788f7215 100644 --- a/tensorflow/c/eager/tfe_cancellation_manager_internal.h +++ b/tensorflow/c/eager/tfe_cancellation_manager_internal.h @@ -15,10 +15,17 @@ limitations under the License. #ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ #define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ +#include "tensorflow/c/conversion_macros.h" #include "tensorflow/core/framework/cancellation.h" -struct TFE_CancellationManager { - tensorflow::CancellationManager cancellation_manager; -}; +struct TFE_CancellationManager; +typedef struct TFE_CancellationManager TFE_CancellationManager; + +namespace tensorflow { +DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager, + TFE_CancellationManager); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager*, + TFE_CancellationManager*); +} // namespace tensorflow #endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ diff --git a/tensorflow/c/eager/tfe_op_attrs_internal.h b/tensorflow/c/eager/tfe_op_attrs_internal.h index 0287502dea632b..24e3692a13feaf 100644 --- a/tensorflow/c/eager/tfe_op_attrs_internal.h +++ b/tensorflow/c/eager/tfe_op_attrs_internal.h @@ -16,8 +16,8 @@ limitations under the License. 
#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_op_attrs.h" #include "tensorflow/c/tf_status.h" -#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/framework/attr_value.pb.h" // An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways @@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context; typedef struct TFE_Op TFE_Op; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AttrBuilder, TFE_OpAttrs); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOpAttrs, TFE_OpAttrs); // Set an AttrValue on the op. Doesn't handle the list types. void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, diff --git a/tensorflow/c/eager/unified_api_test.cc b/tensorflow/c/eager/unified_api_test.cc new file mode 100644 index 00000000000000..e8a285b459bc86 --- /dev/null +++ b/tensorflow/c/eager/unified_api_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +class UnifiedAPI + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } + + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } +}; + +// Checks that inputs[0] is a scalar. +Status TestScalarShape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape)); + if (shape.dims() != 0) { + return errors::InvalidArgument( + "Tensor expected to have scalar shape found rank: ", shape.dims()); + } + return Status::OK(); +} + +TEST_P(UnifiedAPI, TestTensorShapeScalar) { + if (UseFunction() && UseMlir()) { + // TODO(b/173074167): Remove this. 
+ GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + Status s = RunModel(TestScalarShape, ctx.get(), + /*inputs=*/{x.get()}, + /*outputs=*/{}, + /*use_function=*/UseFunction()); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); +} + +// Checks that inputs[0] is a matrix with shape 2x4. +Status TestTensorShape2x4(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape)); + if (shape.dims() != 2) { + return errors::InvalidArgument( + "Tensor expected to have rank 2 found rank: ", shape.dims()); + } + int64 dim_sizes[] = {2, 4}; + for (int i = 0; i < shape.dims(); i++) { + if (shape.dim_size(i) != dim_sizes[i]) { + return errors::InvalidArgument("Dim ", i, " expected to be of size ", + dim_sizes[i], + " found: ", shape.dim_size(i)); + } + } + return Status::OK(); +} + +TEST_P(UnifiedAPI, TestTensorShape2x4) { + if (UseFunction() && UseMlir()) { + // TODO(b/173074167): Remove this. 
+ GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + float data[] = {0., 0., 0., 0., 0., 0., 0., 0}; + int64_t dim_sizes[] = {2, 4}; + Status s = TestTensorHandleWithDims(ctx.get(), data, + dim_sizes, 2, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + Status s = RunModel(TestTensorShape2x4, ctx.get(), + /*inputs=*/{x.get()}, + /*outputs=*/{}, + /*use_function=*/UseFunction()); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); +} + +TEST_P(UnifiedAPI, TestUnknownShapeTracing) { + if (!UseFunction()) { + GTEST_SKIP() << "Tracing only test."; + } + if (UseMlir()) { + // TODO(b/173074167): Remove this. + GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx(BuildFunction("test_fn")); + AbstractTensorHandlePtr x; + { + tracing::TracingTensorHandle* x_raw = nullptr; + PartialTensorShape shape; + Status s = dyn_cast(ctx.get())->AddParameter( + DT_FLOAT, shape, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + PartialTensorShape shape; + Status s = x->Shape(&shape); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_TRUE(shape.unknown_rank()); +} + +TEST_P(UnifiedAPI, TestPartialShapeTracing) { + if (!UseFunction()) { + GTEST_SKIP() << "Tracing only test."; + } + if (UseMlir()) { + GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx(BuildFunction("test_fn")); + AbstractTensorHandlePtr x; + { + tracing::TracingTensorHandle* x_raw = nullptr; + PartialTensorShape shape; + int64 dim_sizes[] = {2, -1}; + Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape); + 
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + s = dyn_cast(ctx.get())->AddParameter( + DT_FLOAT, shape, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + PartialTensorShape shape; + Status s = x->Shape(&shape); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_FALSE(shape.unknown_rank()); + + ASSERT_EQ(2, shape.dim_size(0)); + ASSERT_EQ(-1, shape.dim_size(1)); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCppAPI, UnifiedAPI, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(true, false), + /*use_function*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCppAPI, UnifiedAPI, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_testutil.cc b/tensorflow/c/eager/unified_api_testutil.cc new file mode 100644 index 00000000000000..0096d241543752 --- /dev/null +++ b/tensorflow/c/eager/unified_api_testutil.cc @@ -0,0 +1,143 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/eager/unified_api_testutil.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +AbstractContext* BuildFunction(const char* fn_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); + return unwrap(graph_ctx); +} + +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + std::vector* params) { + tracing::TracingTensorHandle* handle = nullptr; + for (auto input : inputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(input->Shape(&shape)); + TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( + input->DataType(), shape, &handle)); + params->emplace_back(handle); + } + return Status::OK(); +} + +// Runs `model` maybe wrapped in a function. +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function) { + if (use_function) { + const char* fn_name = "test_fn"; + std::unique_ptr scoped_func; + // Returning null tensors from a tf.function is not supported, so we keep + // track of which indices in the model's outputs are nullptr in this set. + // The FunctionDef only outputs the non-null tensors. We later pad the + // function op outputs to have nullptrs at the `null_indices`. 
+ absl::flat_hash_set null_indices; + { + AbstractContextPtr func_ctx(BuildFunction(fn_name)); + std::vector func_inputs; + func_inputs.reserve(inputs.size()); + TF_RETURN_IF_ERROR( + CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); + std::vector model_outputs; + model_outputs.resize(outputs.size()); + TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), + absl::MakeSpan(model_outputs))); + for (auto func_input : func_inputs) { + func_input->Unref(); + } + AbstractFunction* func = nullptr; + OutputList output_list; + output_list.expected_num_outputs = 0; + output_list.outputs.reserve(outputs.size()); + for (int i = 0; i < model_outputs.size(); i++) { + if (model_outputs[i]) { + output_list.outputs.emplace_back(model_outputs[i]); + output_list.expected_num_outputs += 1; + } else { + null_indices.insert(i); + } + } + TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) + ->Finalize(&output_list, &func)); + scoped_func.reset(func); + for (auto output : output_list.outputs) { + output->Unref(); + } + TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); + } + + AbstractOperationPtr fn_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(fn_op->AddInput(input)); + } + int retvals = outputs.size() - null_indices.size(); + std::vector fn_outputs(retvals); + TF_RETURN_IF_ERROR(fn_op->Execute( + absl::Span(fn_outputs.data(), fn_outputs.size()), + &retvals)); + int skipped_indices = 0; + for (int i = 0; i < outputs.size(); i++) { + if (!null_indices.contains(i)) { + outputs[i] = fn_outputs[i - skipped_indices]; + } else { + skipped_indices += 1; + } + } + TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); + return Status::OK(); + } else { + return model(ctx, inputs, outputs); + } +} + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = 
TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_DeleteContextOptions(opts); + return Status::OK(); +} + +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return StatusFromTF_Status(status.get()); +} + +} // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_testutil.h b/tensorflow/c/eager/unified_api_testutil.h new file mode 100644 index 00000000000000..3e76f242abef88 --- /dev/null +++ b/tensorflow/c/eager/unified_api_testutil.h @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Builds and returns a `TracingContext` using the default tracing impl. +AbstractContext* BuildFunction(const char* fn_name); + +// Creates parameters (placeholders) in the tracing `ctx` using the shape and +// dtype of `inputs`. +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + std::vector* params); + +// A callable that takes tensor inputs and returns zero or more tensor outputs. +using Model = std::function, + absl::Span)>; + +// Runs `model` maybe wrapped in a function call op. This can be thought of as +// being equivalent to the following python code. +// +// if use_function: +// outputs = tf.function(model)(inputs) +// else: +// outputs = model(inputs) +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function); + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); + +// Return a tensor handle with given type, values and dimensions. 
+template +Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data, + const int64_t* dims, int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDims(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +// Return a scalar tensor handle with given value. +template +Status TestScalarTensorHandle(AbstractContext* ctx, const T value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestScalarTensorHandle(eager_ctx, value); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +// Places data from `t` into *result_tensor. +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index fbde13dea5aa03..0bdcada1f53553 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -15,7 +15,8 @@ limitations under the License. 
#include "tensorflow/c/env.h" -#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h index 63e2c86ad44f1b..ac6a9e32aff6e2 100644 --- a/tensorflow/c/env.h +++ b/tensorflow/c/env.h @@ -20,8 +20,9 @@ limitations under the License. #include #include -#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/tf_file_statistics.h" +#include "tensorflow/c/tf_status.h" // -------------------------------------------------------------------------- // C API for tensorflow::Env. diff --git a/tensorflow/c/experimental/filesystem/filesystem_interface.h b/tensorflow/c/experimental/filesystem/filesystem_interface.h index 6e05c86143951d..3ac74a5827509c 100644 --- a/tensorflow/c/experimental/filesystem/filesystem_interface.h +++ b/tensorflow/c/experimental/filesystem/filesystem_interface.h @@ -83,6 +83,26 @@ typedef struct TF_TransactionToken { TF_Filesystem* owner; } TF_TransactionToken; +typedef struct TF_Filesystem_Option_Value { + int type_tag; + int num_values; + union { + int64_t inv_val; + double real_val; + struct { + char* buf; + int buf_length; + } buffer_val; + } * values; // owned +} TF_Filesystem_Option_Value; + +typedef struct TF_Filesystem_Option { + char* name; // null terminated, owned + char* description; // null terminated, owned + int per_file; // bool actually, but bool is not a C type + TF_Filesystem_Option_Value* value; // owned +} TF_Filesystem_Option; + /// SECTION 2. 
Function tables for functionality provided by plugins /// ---------------------------------------------------------------------------- /// @@ -811,6 +831,85 @@ typedef struct TF_FilesystemOps { char* (*decode_transaction_token)(const TF_Filesystem* filesystem, const TF_TransactionToken* token); + /// Returns pointer to an array of available configuration options and their + /// current/default values in `options` and number of options in array in + /// `num_options`. Ownership of the array is transferred to caller and the + /// caller is responsible for freeing the buffers using respective file systems + /// allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `options` and `num_options` set. + /// If there is no configurable option, `num_options` should be 0. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return 0 options and `TF_OK`. + void (*get_filesystem_configuration)(const TF_Filesystem* filesystem, + TF_Filesystem_Option** options, + int* num_options, TF_Status* status); + + /// Updates filesystem configuration with options passed in `options`. It can + /// contain full set of options supported by the filesystem or just a subset + /// of them. Ownership of options and buffers therein belongs to the caller + /// and any buffers need to be allocated through filesystem allocation API. + /// Filesystems may choose to ignore configuration errors but should at least + /// display a warning or error message to warn the users. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if options are updated. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_NOT_FOUND`. + void (*set_filesystem_configuration)(const TF_Filesystem* filesystem, + const TF_Filesystem_Option** options, + int num_options, TF_Status* status); + + /// Returns the value of the filesystem option given in `key` in `option`. 
+ /// Valid values of the `key` are returned by + /// `get_file_system_configuration_keys` call. Ownership of the + /// `option` is transferred to caller. Buffers therein should be allocated and + /// freed by the relevant filesystems allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `option` is set + /// * Must set `status` to `TF_NOT_FOUND` if the key is invalid + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_NOT_FOUND`. + void (*get_filesystem_configuration_option)(const TF_Filesystem* filesystem, + const char* key, + TF_Filesystem_Option** option, + TF_Status* status); + + /// Sets the value of the filesystem option given in `key` to value in + /// `option`. Valid values of the `key` are returned by + /// `get_file_system_configuration_keys` call. Ownership of the `option` and + /// the `key` belongs to the caller. Buffers therein should be allocated and + /// freed by the filesystems allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `option` is set/updated + /// * Must set `status` to `TF_NOT_FOUND` if the key is invalid + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_NOT_FOUND`. + void (*set_filesystem_configuration_option)( + const TF_Filesystem* filesystem, const TF_Filesystem_Option* option, + TF_Status* status); + + /// Returns a list of valid configuration keys in `keys` array and number of + /// keys in `num_keys`. Ownership of the buffers in `keys` are transferred to + /// caller and needs to be freed using relevant filesystem allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` on success. If there are no configurable + /// keys, `num_keys` should be set to 0 + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_OK` and `num_keys`=0. 
+ void (*get_filesystem_configuration_keys)(const TF_Filesystem* filesystem, + char** keys, int* num_keys, + TF_Status* status); + } TF_FilesystemOps; // LINT.ThenChange(:filesystem_ops_version) diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index 9c8d3518800b6b..3fdeaf32eeba57 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -133,7 +133,7 @@ bool ModularFileSystem::FilesExist(const std::vector& files, TransactionToken* token, std::vector* status) { if (ops_->paths_exist == nullptr) - return FileSystem::FilesExist(files, status); + return FileSystem::FilesExist(files, token, status); std::vector translated_names; translated_names.reserve(files.size()); @@ -234,7 +234,7 @@ Status ModularFileSystem::DeleteRecursively(const std::string& dirname, "`undeleted_dirs` set to NULL"); if (ops_->delete_recursively == nullptr) - return FileSystem::DeleteRecursively(dirname, undeleted_files, + return FileSystem::DeleteRecursively(dirname, token, undeleted_files, undeleted_dirs); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); @@ -264,7 +264,7 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname, Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname, TransactionToken* token) { if (ops_->recursively_create_dir == nullptr) - return FileSystem::RecursivelyCreateDir(dirname); + return FileSystem::RecursivelyCreateDir(dirname, token); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); std::string translated_name = TranslateName(dirname); @@ -312,7 +312,8 @@ Status ModularFileSystem::Stat(const std::string& fname, Status ModularFileSystem::IsDirectory(const std::string& name, TransactionToken* token) { - if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name); + if (ops_->is_directory == nullptr) + return 
FileSystem::IsDirectory(name, token); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); std::string translated_name = TranslateName(name); @@ -362,7 +363,8 @@ Status ModularFileSystem::RenameFile(const std::string& src, Status ModularFileSystem::CopyFile(const std::string& src, const std::string& target, TransactionToken* token) { - if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target); + if (ops_->copy_file == nullptr) + return FileSystem::CopyFile(src, target, token); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); std::string translated_src = TranslateName(src); diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index 8cd8ad7ca8196f..b6c0a405c79589 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -81,7 +81,7 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok, return; } - size_t bucket_end = fname.find("/", scheme_end + 1); + size_t bucket_end = fname.find('/', scheme_end + 1); if (bucket_end == std::string::npos) { TF_SetStatus(status, TF_INVALID_ARGUMENT, "GCS path doesn't contain a bucket name."); diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc index 5ff28e4229af37..67eaa23fa4c1a9 100644 --- a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -38,7 +38,7 @@ void ParseHadoopPath(const std::string& fname, std::string* scheme, size_t scheme_end = fname.find("://") + 2; // We don't want `://` in scheme. 
*scheme = fname.substr(0, scheme_end - 2); - size_t nn_end = fname.find("/", scheme_end + 1); + size_t nn_end = fname.find('/', scheme_end + 1); if (nn_end == std::string::npos) { *namenode = fname.substr(scheme_end + 1); *path = ""; @@ -182,9 +182,8 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); std::string cacheKey(scheme); - hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); if (scheme == "file") { - libhdfs->hdfsBuilderSetNameNode(builder, nullptr); + namenode = ""; } else if (scheme == "viewfs") { char* defaultFS = nullptr; libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS); @@ -200,24 +199,27 @@ hdfsFS Connect(tf_hadoop_filesystem::HadoopFile* hadoop_file, // The default NameNode configuration will be used (from the XML // configuration files). See: // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 - libhdfs->hdfsBuilderSetNameNode(builder, "default"); + namenode = "default"; } else if (scheme == "har") { std::string path_har = path; SplitArchiveNameAndPath(&path_har, &namenode, status); if (TF_GetCode(status) != TF_OK) return nullptr; - libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str()); - cacheKey += namenode; } else { - libhdfs->hdfsBuilderSetNameNode( - builder, namenode.empty() ? "default" : namenode.c_str()); - cacheKey += namenode; + if (namenode.empty()) { + namenode = "default"; + } } + cacheKey += namenode; + absl::MutexLock l(&hadoop_file->connection_cache_lock); if (hadoop_file->connection_cache.find(cacheKey) == hadoop_file->connection_cache.end()) { + hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); + libhdfs->hdfsBuilderSetNameNode( + builder, namenode.empty() ? 
nullptr : namenode.c_str()); auto cacheFs = libhdfs->hdfsBuilderConnect(builder); if (cacheFs == nullptr) { - TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); + TF_SetStatusFromIOError(status, TF_ABORTED, strerror(errno)); return cacheFs; } hadoop_file->connection_cache[cacheKey] = cacheFs; diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index e8a50e322169ad..2d879cbec5f17a 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -1,6 +1,21 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +# buildifier: disable=same-origin-load +load( + "//tensorflow:tensorflow.bzl", + "if_libtpu", + "tf_cuda_cc_test", +) +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core/platform:build_config_root.bzl", + "tf_cuda_tests_tags", +) + # Library of gradient functions. 
package( licenses = ["notice"], # Apache 2.0 @@ -16,11 +31,8 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_operation", - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:gradients_internal", - "//tensorflow/core/lib/llvm_rtti", ], ) @@ -65,12 +77,28 @@ cc_library( ], ) +cc_library( + name = "not_differentiable", + srcs = ["not_differentiable.cc"], + hdrs = [ + "not_differentiable.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:gradients_internal", + ], +) + cc_library( name = "gradients", hdrs = [ "array_grad.h", "math_grad.h", "nn_grad.h", + "not_differentiable.h", ], visibility = [ "//tensorflow:internal", @@ -79,19 +107,146 @@ cc_library( ":array_grad", ":math_grad", ":nn_grad", + ":not_differentiable", + "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:gradients_internal", ], ) +tf_cuda_cc_test( + name = "custom_gradient_test", + size = "small", + srcs = [ + "custom_gradient_test.cc", + ], + args = ["--heap_check=local"], # TODO(b/174752220): Remove + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:unified_api_testutil", + "//tensorflow/c/experimental/ops", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:errors", + ], +) + filegroup( name = "pywrap_required_hdrs", srcs = [ "array_grad.h", "math_grad.h", "nn_grad.h", + "not_differentiable.h", ], visibility = [ "//tensorflow/core:__pkg__", 
"//tensorflow/python:__pkg__", ], ) + +cc_library( + name = "grad_test_helper", + testonly = True, + srcs = ["grad_test_helper.cc"], + hdrs = ["grad_test_helper.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:gradient_checker", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:unified_api_testutil", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cuda_cc_test( + name = "nn_grad_test", + size = "small", + srcs = [ + "nn_grad_test.cc", + ], + args = ["--heap_check=local"], # TODO(b/174752220): Remove + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, + deps = [ + ":grad_test_helper", + ":nn_grad", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:unified_api_testutil", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core/platform:tensor_float_32_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_libtpu( + if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], + if_true = [], + ), +) + +tf_cuda_cc_test( + name = "math_grad_test", + size = "small", + srcs = [ + "math_grad_test.cc", + ], + args = ["--heap_check=local"], # TODO(b/174752220): Remove + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, + deps = [ + ":grad_test_helper", + ":math_grad", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:unified_api_testutil", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/core/platform:tensor_float_32_utils", + 
"//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_libtpu( + if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], + if_true = [], + ), +) + +tf_cuda_cc_test( + name = "array_grad_test", + size = "small", + srcs = [ + "array_grad_test.cc", + ], + args = ["--heap_check=local"], # TODO(b/174752220): Remove + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, + deps = [ + ":grad_test_helper", + ":array_grad", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/core/platform:tensor_float_32_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:unified_api_testutil", + ] + if_libtpu( + if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"], + if_true = [], + ), +) diff --git a/tensorflow/c/experimental/gradients/array_grad.cc b/tensorflow/c/experimental/gradients/array_grad.cc index 069209a4b6bd1d..5e6c3a49bea81d 100644 --- a/tensorflow/c/experimental/gradients/array_grad.cc +++ b/tensorflow/c/experimental/gradients/array_grad.cc @@ -14,23 +14,24 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/c/experimental/gradients/array_grad.h" +#include "tensorflow/c/eager/abstract_context.h" + namespace tensorflow { namespace gradients { namespace { -using std::vector; class IdentityNGradientFunction : public GradientFunction { public: - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { - grad_outputs->resize(grad_inputs.size(), nullptr); - for (int i = 0; i < grad_inputs.size(); i++) { - auto grad_input = grad_inputs[i]; + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + for (int i = 0; i < grad_outputs.size(); i++) { + auto grad_input = grad_outputs[i]; // TODO(srbs): Should we add a copy contructor to AbstractTensorHandle // that takes care of this similar to `Tensor`? if (grad_input) { grad_input->Ref(); } - (*grad_outputs)[i] = grad_input; + grad_inputs[i] = grad_input; } return Status::OK(); } @@ -38,10 +39,8 @@ class IdentityNGradientFunction : public GradientFunction { }; } // namespace -BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) { - auto gradient_function = new IdentityNGradientFunction; - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* IdentityNRegisterer(const ForwardOperation& op) { + return new IdentityNGradientFunction; } } // namespace gradients diff --git a/tensorflow/c/experimental/gradients/array_grad.h b/tensorflow/c/experimental/gradients/array_grad.h index edeeb5fcb4a6d7..3dcf98b0969f05 100644 --- a/tensorflow/c/experimental/gradients/array_grad.h +++ b/tensorflow/c/experimental/gradients/array_grad.h @@ -19,7 +19,7 @@ limitations under the License. 
namespace tensorflow { namespace gradients { -BackwardFunction* IdentityNRegisterer(const ForwardOperation& op); +GradientFunction* IdentityNRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/array_grad_test.cc b/tensorflow/c/experimental/gradients/array_grad_test.cc new file mode 100644 index 00000000000000..b3488d3bc265c5 --- /dev/null +++ b/tensorflow/c/experimental/gradients/array_grad_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/array_grad.h" + +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/experimental/gradients/grad_test_helper.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +using tensorflow::TF_StatusPtr; + +Status IdentityNModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + std::vector temp_outputs(2); + TF_RETURN_IF_ERROR( + ops::IdentityN(ctx, inputs, absl::MakeSpan(temp_outputs), "IdentityN")); + // Although, `ops::IdentityN` returns 2 tensors, the first tensor isn't needed + // for computing gradient so we could safely drop it. + outputs[0] = temp_outputs[1]; + temp_outputs[0]->Unref(); + return Status::OK(); +} + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + status_ = StatusFromTF_Status(status.get()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + { + AbstractContext* ctx_raw = nullptr; + status_ = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + immediate_execution_ctx_.reset(ctx_raw); + } + + // Computing numerical gradients with TensorFloat-32 is numerically + // unstable. 
Some forward pass tests also fail with TensorFloat-32 due to + // low tolerances + enable_tensor_float_32_execution(false); + } + + AbstractContextPtr immediate_execution_ctx_; + GradientRegistry registry_; + Status status_; + + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } +}; + +TEST_P(CppGradients, TestIdentityNGrad) { + // This test is interesting because the current implementation of GradientTape + // would return [0, 1] whereas we use build_default_zeros_grads=false here + // so we get back [nullptr, 1]. + + AbstractTensorHandlePtr x1; + { + AbstractTensorHandle* x1_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 1.0f, &x1_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x1.reset(x1_raw); + } + + AbstractTensorHandlePtr x2; + { + AbstractTensorHandle* x2_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 1.0f, &x2_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x2.reset(x2_raw); + } + + status_ = registry_.Register("IdentityN", IdentityNRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + auto IdentityNGradModel = BuildGradModel(IdentityNModel, registry_); + + std::vector outputs(2); + status_ = + RunModel(IdentityNGradModel, immediate_execution_ctx_.get(), + {x1.get(), x2.get()}, absl::MakeSpan(outputs), UseFunction()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + EXPECT_EQ(outputs[0], nullptr); + ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {1.0f}, /*dims*/ {}, + /*abs_error*/ 0)); + outputs[1]->Unref(); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( 
+ UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/custom_gradient_test.cc b/tensorflow/c/experimental/gradients/custom_gradient_test.cc new file mode 100644 index 00000000000000..a266f47266acb9 --- /dev/null +++ b/tensorflow/c/experimental/gradients/custom_gradient_test.cc @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { +using std::vector; + +class CustomGradientTest + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } +}; + +class PassThroughGradientFunction : public GradientFunction { + public: + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + CHECK_EQ(grad_outputs.size(), 1); + CHECK_EQ(grad_inputs.size(), 1); + grad_inputs[0] = grad_outputs[0]; + if (grad_inputs[0]) { + grad_inputs[0]->Ref(); + } + return Status::OK(); + } +}; + +// Computes: +// +// @tf.custom_gradient +// def f(input): +// def grad(grads): +// return grads[0] +// return tf.exp(input), grad +// outputs = [f(inputs[0])] +Status ExpWithPassThroughGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + Tape tape(/*persistent=*/false); + tape.Watch(inputs[0]); // Watch x. 
+ std::vector exp_outputs(1); + TF_RETURN_IF_ERROR(ops::Exp(ctx, inputs, absl::MakeSpan(exp_outputs), "Exp")); + std::unique_ptr gradient_function( + new PassThroughGradientFunction); + tape.RecordOperation(inputs, exp_outputs, gradient_function.release()); + TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, + /*targets*/ exp_outputs, + /*sources=*/inputs, + /*output_gradients=*/{}, + /*result=*/outputs)); + for (auto exp_output : exp_outputs) { + exp_output->Unref(); + } + return Status::OK(); +} + +TEST_P(CustomGradientTest, ExpWithPassThroughGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + // Pseudo-code: + // + // tape.watch(x) + // y = exp(x) + // outputs = tape.gradient(y, x) + std::vector outputs(1); + Status s = RunModel(ExpWithPassThroughGrad, ctx.get(), {x.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam())); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = GetValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + CustomGradientTest, CustomGradientTest, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(true, false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + 
CustomGradientTest, CustomGradientTest, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc new file mode 100644 index 00000000000000..a7b47fa20ae35a --- /dev/null +++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/grad_test_helper.h" + +#include "tensorflow/c/eager/gradient_checker.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +void CompareNumericalAndAutodiffGradients( + Model model, Model grad_model, AbstractContext* ctx, + absl::Span inputs, bool use_function, + double abs_error) { + auto num_inputs = inputs.size(); + std::vector outputs(num_inputs); + auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs), + /*use_function=*/use_function); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + for (int i = 0; i < num_inputs; ++i) { + if (!outputs[i]) continue; + + AbstractTensorHandlePtr numerical_grad; + { + AbstractTensorHandle* numerical_grad_raw; + s = CalcNumericalGrad(ctx, model, inputs, + /*input_index=*/i, use_function, + &numerical_grad_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + numerical_grad.reset(numerical_grad_raw); + } + + TF_Tensor* numerical_tensor; + s = GetValue(numerical_grad.get(), &numerical_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto num_elem_numerical = TF_TensorElementCount(numerical_tensor); + + TF_Tensor* analytical_tensor; + s = GetValue(outputs[i], &analytical_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto num_elem_analytical = TF_TensorElementCount(analytical_tensor); + + ASSERT_EQ(num_elem_numerical, num_elem_analytical); + + float* dnumerical = new float[num_elem_numerical]{0}; + memcpy(&dnumerical[0], TF_TensorData(numerical_tensor), + TF_TensorByteSize(numerical_tensor)); + float* danalytical = new float[num_elem_analytical]{0}; + memcpy(&danalytical[0], TF_TensorData(analytical_tensor), + TF_TensorByteSize(analytical_tensor)); + + for (int j = 0; j < num_elem_numerical; j++) { + 
ASSERT_NEAR(dnumerical[j], danalytical[j], abs_error); + } + TF_DeleteTensor(analytical_tensor); + TF_DeleteTensor(numerical_tensor); + delete[] danalytical; + delete[] dnumerical; + outputs[i]->Unref(); + } +} + +void CheckTensorValue(AbstractTensorHandle* t, absl::Span manuals, + absl::Span dims, double abs_error) { + TF_Tensor* analytical_tensor; + auto s = GetValue(t, &analytical_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + int64_t num_elem_analytical = 1; + auto num_dims_analytical = TF_NumDims(analytical_tensor); + ASSERT_EQ(dims.size(), num_dims_analytical); + for (int j = 0; j < num_dims_analytical; j++) { + auto dim_analytical = TF_Dim(analytical_tensor, j); + ASSERT_EQ(dims[j], dim_analytical); + num_elem_analytical *= dim_analytical; + } + + float* danalytical = new float[num_elem_analytical]{0}; + memcpy(&danalytical[0], TF_TensorData(analytical_tensor), + TF_TensorByteSize(analytical_tensor)); + + for (int64_t j = 0; j < num_elem_analytical; j++) { + if (abs_error == 0) { + ASSERT_EQ(manuals[j], danalytical[j]); + } else { + ASSERT_NEAR(manuals[j], danalytical[j], abs_error); + } + } + + TF_DeleteTensor(analytical_tensor); + delete[] danalytical; +} + +Model BuildGradModel(Model forward, GradientRegistry registry) { + return [forward_model = std::move(forward), + grad_registry = std::move(registry)]( + AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) -> Status { + Tape tape(/*persistent=*/false); + for (size_t i{}; i < inputs.size(); ++i) { + tape.Watch(inputs[i]); + } + std::vector temp_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, grad_registry)); + TF_RETURN_IF_ERROR( + forward_model(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs))); + + TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs, + /*sources=*/inputs, + /*output_gradients=*/{}, outputs)); + for (auto temp_output : temp_outputs) { + temp_output->Unref(); + } + return Status::OK(); + }; +} + +} // 
namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h new file mode 100644 index 00000000000000..84761f96405b53 --- /dev/null +++ b/tensorflow/c/experimental/gradients/grad_test_helper.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ + +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/unified_api_testutil.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +void CompareNumericalAndAutodiffGradients( + Model model, Model grad_model, AbstractContext* ctx, + absl::Span inputs, bool use_function, + double abs_error = 1e-2); + +void CheckTensorValue(AbstractTensorHandle* t, absl::Span manuals, + absl::Span dims, double abs_error = 1e-2); + +Model BuildGradModel(Model forward, GradientRegistry registry); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc index 5551642127de53..896b40c671ac30 100644 --- 
a/tensorflow/c/experimental/gradients/math_grad.cc +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -21,10 +21,14 @@ limitations under the License. #include "tensorflow/c/experimental/ops/nn_ops.h" using std::vector; +using tensorflow::ops::Add; using tensorflow::ops::Conj; +using tensorflow::ops::Div; +using tensorflow::ops::DivNoNan; using tensorflow::ops::MatMul; using tensorflow::ops::Mul; using tensorflow::ops::Neg; +using tensorflow::ops::OnesLike; using tensorflow::ops::SqrtGrad; namespace tensorflow { @@ -33,17 +37,17 @@ namespace { class AddGradientFunction : public GradientFunction { public: - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { - grad_outputs->resize(2); + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { // TODO(b/161805092): Support broadcasting. - DCHECK(grad_inputs[0]); - (*grad_outputs)[0] = grad_inputs[0]; - (*grad_outputs)[1] = grad_inputs[0]; + DCHECK(grad_outputs[0]); + grad_inputs[0] = grad_outputs[0]; + grad_inputs[1] = grad_outputs[0]; - (*grad_outputs)[0]->Ref(); - (*grad_outputs)[1]->Ref(); + grad_inputs[0]->Ref(); + grad_inputs[1]->Ref(); return Status::OK(); } ~AddGradientFunction() override {} @@ -54,18 +58,18 @@ class ExpGradientFunction : public GradientFunction { explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) { exp->Ref(); } - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { vector conj_outputs(1); std::string name = "Conj_Exp_Grad"; - TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()}, - absl::MakeSpan(conj_outputs), name.c_str())); + TF_RETURN_IF_ERROR( + Conj(ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), name.c_str())); AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]); - grad_outputs->resize(1); name = "Mul_Exp_Grad"; - 
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]}, - absl::MakeSpan(*grad_outputs), name.c_str())); + TF_RETURN_IF_ERROR(Mul(ctx, {conj_outputs[0], grad_outputs[0]}, grad_inputs, + name.c_str())); return Status::OK(); } ~ExpGradientFunction() override {} @@ -79,12 +83,12 @@ class SqrtGradientFunction : public GradientFunction { explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) { sqrt->Ref(); } - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { std::string name = "Sqrt_Grad"; - grad_outputs->resize(1); - TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]}, - absl::MakeSpan(*grad_outputs), name.c_str())); + TF_RETURN_IF_ERROR(SqrtGrad(ctx, {sqrt_.get(), grad_outputs[0]}, + absl::MakeSpan(grad_inputs), name.c_str())); return Status::OK(); } ~SqrtGradientFunction() override {} @@ -97,10 +101,17 @@ class MatMulGradientFunction : public GradientFunction { public: explicit MatMulGradientFunction(vector f_inputs, AttrBuilder f_attrs) - : forward_inputs(f_inputs), forward_attrs(f_attrs) {} + : forward_inputs_(f_inputs), forward_attrs_(f_attrs) { + for (auto input : forward_inputs_) { + if (input) { + input->Ref(); + } + } + } - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { /* Given upstream grad U and a matmul op A*B, the gradients are: * * dA = U * B.T @@ -108,29 +119,28 @@ class MatMulGradientFunction : public GradientFunction { * * where A.T means `transpose(A)` */ - AbstractTensorHandle* upstream_grad = grad_inputs[0]; - grad_outputs->resize(2); + AbstractTensorHandle* upstream_grad = grad_outputs[0]; // Get transpose attrs bool t_a; - TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a)); + 
TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_a", &t_a)); bool t_b; - TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b)); + TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_b", &t_b)); // Conj each input vector conj_outputs(1); std::string name = "Conj_A_MatMul_Grad"; - TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]}, + TF_RETURN_IF_ERROR(Conj(ctx, {forward_inputs_[0]}, absl::MakeSpan(conj_outputs), name.c_str())); - AbstractTensorHandle* A = conj_outputs[0]; + AbstractTensorHandlePtr A(conj_outputs[0]); name = "Conj_B_MatMul_Grad"; - TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]}, + TF_RETURN_IF_ERROR(Conj(ctx, {forward_inputs_[1]}, absl::MakeSpan(conj_outputs), name.c_str())); - AbstractTensorHandle* B = conj_outputs[0]; + AbstractTensorHandlePtr B(conj_outputs[0]); // Calc Grad vector matmul_A_outputs(1); @@ -138,50 +148,50 @@ class MatMulGradientFunction : public GradientFunction { std::string name_grad_A = "MatMul_Grad_A"; std::string name_grad_B = "MatMul_Grad_B"; if (!t_a && !t_b) { - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B}, + TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, B.get()}, absl::MakeSpan(matmul_A_outputs), name_grad_A.c_str(), /*transpose_a = */ false, /*transpose_b = */ true)); - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad}, + TF_RETURN_IF_ERROR(MatMul(ctx, {A.get(), upstream_grad}, absl::MakeSpan(matmul_B_outputs), name_grad_B.c_str(), /*transpose_a = */ true, /*transpose_b = */ false)); } else if (!t_a && t_b) { - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B}, + TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, B.get()}, absl::MakeSpan(matmul_A_outputs), name_grad_A.c_str(), /*transpose_a = */ false, /*transpose_b = */ false)); - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A}, + TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, A.get()}, absl::MakeSpan(matmul_B_outputs), name_grad_B.c_str(), /*transpose_a = */ true, /*transpose_b = */ false)); } else if (t_a && !t_b) { - 
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad}, + TF_RETURN_IF_ERROR(MatMul(ctx, {B.get(), upstream_grad}, absl::MakeSpan(matmul_A_outputs), name_grad_A.c_str(), /*transpose_a = */ false, /*transpose_b = */ true)); - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad}, + TF_RETURN_IF_ERROR(MatMul(ctx, {A.get(), upstream_grad}, absl::MakeSpan(matmul_B_outputs), name_grad_B.c_str(), /*transpose_a = */ false, /*transpose_b = */ false)); } else { // t_a && t_b - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad}, + TF_RETURN_IF_ERROR(MatMul(ctx, {B.get(), upstream_grad}, absl::MakeSpan(matmul_A_outputs), name_grad_A.c_str(), /*transpose_a = */ true, /*transpose_b = */ true)); - TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A}, + TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, A.get()}, absl::MakeSpan(matmul_B_outputs), name_grad_B.c_str(), /*transpose_a = */ true, @@ -189,33 +199,40 @@ class MatMulGradientFunction : public GradientFunction { } // Gradient for A - (*grad_outputs)[0] = matmul_A_outputs[0]; + grad_inputs[0] = matmul_A_outputs[0]; // Gradient for B - (*grad_outputs)[1] = matmul_B_outputs[0]; + grad_inputs[1] = matmul_B_outputs[0]; return Status::OK(); } - ~MatMulGradientFunction() override {} + ~MatMulGradientFunction() override { + for (auto input : forward_inputs_) { + if (input) { + input->Unref(); + } + } + } private: - vector forward_inputs; - AttrBuilder forward_attrs; + // TODO(b/174778737): Only hold needed inputs. 
+ vector forward_inputs_; + AttrBuilder forward_attrs_; }; class NegGradientFunction : public GradientFunction { public: - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { /* Given upstream grad U and a Neg op Y = -X, the gradients are: * * dX = -U * */ - grad_outputs->resize(1); std::string name = "Neg_Grad"; - TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(*grad_outputs), name.c_str())); + TF_RETURN_IF_ERROR( + ops::Neg(ctx, {grad_outputs[0]}, grad_inputs, name.c_str())); return Status::OK(); } ~NegGradientFunction() override {} @@ -223,8 +240,9 @@ class NegGradientFunction : public GradientFunction { class SubGradientFunction : public GradientFunction { public: - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { /* Given upstream grad U and a Sub op A-B, the gradients are: * * dA = U @@ -232,80 +250,246 @@ class SubGradientFunction : public GradientFunction { * */ - grad_outputs->resize(2); - // Grad for A - DCHECK(grad_inputs[0]); - (*grad_outputs)[0] = grad_inputs[0]; - (*grad_outputs)[0]->Ref(); + DCHECK(grad_outputs[0]); + grad_inputs[0] = grad_outputs[0]; + grad_inputs[0]->Ref(); // Grad for B // negate the upstream grad - std::vector neg_outputs(1); std::string name = "Neg_Sub_Grad_B"; - TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]}, - absl::MakeSpan(neg_outputs), name.c_str())); - (*grad_outputs)[1] = neg_outputs[0]; + TF_RETURN_IF_ERROR(ops::Neg(ctx, {grad_outputs[0]}, + grad_inputs.subspan(1, 1), name.c_str())); return Status::OK(); } ~SubGradientFunction() override {} }; +class MulGradientFunction : public GradientFunction { + public: + explicit MulGradientFunction(vector f_inputs) + : forward_inputs_(f_inputs) { 
+ for (auto input : forward_inputs_) { + if (input) { + input->Ref(); + } + } + } + + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + /* Given upstream grad U and a mul op A*B, the gradients are: + * + * dA = U * B + * dB = A * U + * + */ + + AbstractTensorHandle* upstream_grad = grad_outputs[0]; + + // Gradient for A + std::string name = "Mul_Grad_A"; + TF_RETURN_IF_ERROR(Mul(ctx, {upstream_grad, forward_inputs_[1]}, + grad_inputs.subspan(0, 1), name.c_str())); + + // Gradient for B + name = "Mul_Grad_B"; + TF_RETURN_IF_ERROR(Mul(ctx, {forward_inputs_[0], upstream_grad}, + grad_inputs.subspan(1, 1), name.c_str())); + return Status::OK(); + } + ~MulGradientFunction() override { + for (auto input : forward_inputs_) { + if (input) { + input->Unref(); + } + } + } + + private: + // TODO(b/174778737): Only hold needed inputs. + vector forward_inputs_; +}; + +class Log1pGradientFunction : public GradientFunction { + public: + explicit Log1pGradientFunction(vector f_inputs) + : forward_inputs_(f_inputs) { + for (auto input : forward_inputs_) { + if (input) { + input->Ref(); + } + } + } + + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + // TODO(vnvo2409): Add control dependency + /* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are: + * + * dX = U / (1 + X) + * + */ + + AbstractTensorHandle* upstream_grad = grad_outputs[0]; + AbstractTensorHandle* X = forward_inputs_[0]; + + vector temp_outputs(1); + + // Calculate conjugate of X + std::string name = "Conj_Log1p_Grad_X"; + TF_RETURN_IF_ERROR( + Conj(ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str())); + + AbstractTensorHandlePtr Conj_X(temp_outputs[0]); + + // Creates Ones + name = "OnesLike_Log1p_Grad_X"; + TF_RETURN_IF_ERROR(OnesLike(ctx, {Conj_X.get()}, + absl::MakeSpan(temp_outputs), name.c_str())); + + AbstractTensorHandlePtr Ones_X(temp_outputs[0]); + + name = "Add_Log1p_Grad_X"; + 
// Calculate 1 + Conj(X) + TF_RETURN_IF_ERROR(Add(ctx, {Ones_X.get(), Conj_X.get()}, + absl::MakeSpan(temp_outputs), name.c_str())); + + AbstractTensorHandlePtr Conj_XP1(temp_outputs[0]); + + name = "Div_Log1p_Grad_X"; + // Calculate U / (1 + Conj(X)) + TF_RETURN_IF_ERROR( + Div(ctx, {upstream_grad, Conj_XP1.get()}, grad_inputs, name.c_str())); + + return Status::OK(); + } + ~Log1pGradientFunction() override { + for (auto input : forward_inputs_) { + if (input) { + input->Unref(); + } + } + } + + private: + // TODO(b/174778737): Only hold needed inputs. + vector forward_inputs_; +}; + +class DivNoNanGradientFunction : public GradientFunction { + public: + explicit DivNoNanGradientFunction(vector f_inputs, + vector f_outputs) + : forward_inputs_(f_inputs), forward_outputs_(f_outputs) { + for (auto input : forward_inputs_) { + if (input) { + input->Ref(); + } + } + for (auto output : forward_outputs_) { + if (output) { + output->Ref(); + } + } + } + + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + // TODO(vnvo2409): Add shape broadcasting + /* Given upstream grad U and a Div op: Z = X/Y, the gradients are: + * + * dX = U / Y + * dY = -U*X / Y^2 = (X/Y) * -U / Y = -U*Z / Y + * + */ + + AbstractTensorHandle* upstream_grad = grad_outputs[0]; + AbstractTensorHandle* Y = forward_inputs_[1]; + AbstractTensorHandle* Z = forward_outputs_[0]; + + // Calculate dX = U / Y + std::string name = "Div_Grad_X"; + TF_RETURN_IF_ERROR(DivNoNan(ctx, {upstream_grad, Y}, + grad_inputs.subspan(0, 1), name.c_str())); + + vector temp_outputs(1); + // Calculate dY = -U*Z / Y + name = "Neg_Div_Grad_Y"; + TF_RETURN_IF_ERROR(Neg(ctx, {upstream_grad}, absl::MakeSpan(temp_outputs), + name.c_str())); // -U + AbstractTensorHandlePtr MinusU(temp_outputs[0]); + + name = "Mul_Div_Grad_Y"; + TF_RETURN_IF_ERROR(Mul(ctx, {MinusU.get(), Z}, absl::MakeSpan(temp_outputs), + name.c_str())); // -U*Z + AbstractTensorHandlePtr UZ(temp_outputs[0]); + + name 
= "Div_Grad_Y"; + TF_RETURN_IF_ERROR(DivNoNan(ctx, {UZ.get(), Y}, grad_inputs.subspan(1, 1), + name.c_str())); // -U*Z / Y + + return Status::OK(); + } + ~DivNoNanGradientFunction() override { + for (auto input : forward_inputs_) { + if (input) { + input->Unref(); + } + } + for (auto output : forward_outputs_) { + if (output) { + output->Unref(); + } + } + } + + private: + // TODO(b/174778737): Only hold needed inputs and outputs. + vector forward_inputs_; + vector forward_outputs_; +}; + } // namespace -BackwardFunction* AddRegisterer(const ForwardOperation& op) { - auto gradient_function = new AddGradientFunction; - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* AddRegisterer(const ForwardOperation& op) { + return new AddGradientFunction; +} + +GradientFunction* ExpRegisterer(const ForwardOperation& op) { + return new ExpGradientFunction(op.outputs[0]); +} + +GradientFunction* MatMulRegisterer(const ForwardOperation& op) { + return new MatMulGradientFunction(op.inputs, op.attrs); +} + +GradientFunction* SqrtRegisterer(const ForwardOperation& op) { + return new SqrtGradientFunction(op.outputs[0]); } -BackwardFunction* ExpRegisterer(const ForwardOperation& op) { - auto gradient_function = new ExpGradientFunction(op.outputs[0]); - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. 
- auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* NegRegisterer(const ForwardOperation& op) { + return new NegGradientFunction; } -BackwardFunction* MatMulRegisterer(const ForwardOperation& op) { - auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs); - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* SubRegisterer(const ForwardOperation& op) { + return new SubGradientFunction; } -BackwardFunction* SqrtRegisterer(const ForwardOperation& op) { - auto gradient_function = new SqrtGradientFunction(op.outputs[0]); - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* MulRegisterer(const ForwardOperation& op) { + return new MulGradientFunction(op.inputs); } -BackwardFunction* NegRegisterer(const ForwardOperation& op) { - auto gradient_function = new NegGradientFunction; - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. 
- auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* Log1pRegisterer(const ForwardOperation& op) { + return new Log1pGradientFunction(op.inputs); } -BackwardFunction* SubRegisterer(const ForwardOperation& op) { - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. - auto gradient_function = new SubGradientFunction; - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* DivNoNanRegisterer(const ForwardOperation& op) { + return new DivNoNanGradientFunction(op.inputs, op.outputs); } } // namespace gradients diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h index 756c5f8415359f..e26ee899260a4c 100644 --- a/tensorflow/c/experimental/gradients/math_grad.h +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -20,12 +20,15 @@ limitations under the License. 
namespace tensorflow { namespace gradients { -BackwardFunction* AddRegisterer(const ForwardOperation& op); -BackwardFunction* ExpRegisterer(const ForwardOperation& op); -BackwardFunction* MatMulRegisterer(const ForwardOperation& op); -BackwardFunction* SqrtRegisterer(const ForwardOperation& op); -BackwardFunction* NegRegisterer(const ForwardOperation& op); -BackwardFunction* SubRegisterer(const ForwardOperation& op); +GradientFunction* AddRegisterer(const ForwardOperation& op); +GradientFunction* ExpRegisterer(const ForwardOperation& op); +GradientFunction* MatMulRegisterer(const ForwardOperation& op); +GradientFunction* SqrtRegisterer(const ForwardOperation& op); +GradientFunction* NegRegisterer(const ForwardOperation& op); +GradientFunction* SubRegisterer(const ForwardOperation& op); +GradientFunction* MulRegisterer(const ForwardOperation& op); +GradientFunction* Log1pRegisterer(const ForwardOperation& op); +GradientFunction* DivNoNanRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad_test.cc b/tensorflow/c/experimental/gradients/math_grad_test.cc new file mode 100644 index 00000000000000..33cbd44b4dc478 --- /dev/null +++ b/tensorflow/c/experimental/gradients/math_grad_test.cc @@ -0,0 +1,448 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/math_grad.h" + +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/experimental/gradients/grad_test_helper.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +using tensorflow::TF_StatusPtr; + +Status AddModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Add(ctx, inputs, outputs, "Add"); +} + +Status ExpModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Exp(ctx, inputs, outputs, "Exp"); +} + +Status SqrtModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Sqrt(ctx, inputs, outputs, "Sqrt"); +} + +Status NegModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Neg(ctx, inputs, outputs, "Neg"); +} + +Status SubModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Sub(ctx, inputs, outputs, "Sub"); +} + +Status MulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Mul(ctx, inputs, outputs, "Mul"); +} + +Status Log1pModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Log1p(ctx, inputs, outputs, "Log1p"); +} + +Status DivNoNanModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::DivNoNan(ctx, inputs, outputs, "DivNoNan"); +} + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + 
TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + status_ = StatusFromTF_Status(status.get()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + { + AbstractContext* ctx_raw = nullptr; + status_ = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + immediate_execution_ctx_.reset(ctx_raw); + } + + // Computing numerical gradients with TensorFloat-32 is numerically + // unstable. Some forward pass tests also fail with TensorFloat-32 due to + // low tolerances + enable_tensor_float_32_execution(false); + } + + AbstractContextPtr immediate_execution_ctx_; + GradientRegistry registry_; + Status status_; + + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } +}; + +TEST_P(CppGradients, TestAddGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + y.reset(y_raw); + } + + // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to + // AddV2Registerer. 
+ status_ = registry_.Register("AddV2", AddRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + AddModel, BuildGradModel(AddModel, registry_), + immediate_execution_ctx_.get(), {x.get(), y.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestExpGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + status_ = registry_.Register("Exp", ExpRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + ExpModel, BuildGradModel(ExpModel, registry_), + immediate_execution_ctx_.get(), {x.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestMatMulGrad) { + // TODO(vnvo2409): Figure out why `gradient_checker` does not work very + // well with `MatMul` and remove `TestMatMul*` in + // `mnist_gradients_test` when done. 
+ GTEST_SKIP(); + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; + int64_t A_dims[] = {3, 3}; + AbstractTensorHandlePtr A; + { + AbstractTensorHandle* A_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + A.reset(A_raw); + } + + float B_vals[] = {9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; + int64_t B_dims[] = {3, 3}; + AbstractTensorHandlePtr B; + { + AbstractTensorHandle* B_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + B.reset(B_raw); + } + + status_ = registry_.Register("MatMul", MatMulRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + for (bool transpose_a : {false, true}) { + for (bool transpose_b : {false, true}) { + Model MatMulModel = + [transpose_a, transpose_b]( + AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) -> Status { + return ops::MatMul(ctx, inputs, outputs, "MatMul", transpose_a, + transpose_b); + }; + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + MatMulModel, BuildGradModel(MatMulModel, registry_), + immediate_execution_ctx_.get(), {A.get(), B.get()}, UseFunction())); + } + } +} + +TEST_P(CppGradients, TestMatMulGradManual) { + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f}; + int64_t A_dims[] = {3, 3}; + AbstractTensorHandlePtr A; + { + AbstractTensorHandle* A_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + A.reset(A_raw); + } + + float B_vals[] = {9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; + int64_t B_dims[] = {3, 3}; + AbstractTensorHandlePtr B; + { + AbstractTensorHandle* B_raw; + status_ = TestTensorHandleWithDims( + 
immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + B.reset(B_raw); + } + + status_ = registry_.Register("MatMul", MatMulRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + bool transpose_a_vals[] = {false, false, true, true}; + bool transpose_b_vals[] = {false, true, false, true}; + float dA_vals[4][9] = {{24, 15, 6, 24, 15, 6, 24, 15, 6}, + {18, 15, 12, 18, 15, 12, 18, 15, 12}, + {24, 24, 24, 15, 15, 15, 6, 6, 6}, + {18, 18, 18, 15, 15, 15, 12, 12, 12}}; + float dB_vals[4][9] = {{12, 12, 12, 15, 15, 15, 18, 18, 18}, + {12, 15, 18, 12, 15, 18, 12, 15, 18}, + {6, 6, 6, 15, 15, 15, 24, 24, 24}, + {6, 15, 24, 6, 15, 24, 6, 15, 24}}; + + for (int i{}; i < 4; ++i) { + bool transpose_a = transpose_a_vals[i]; + bool transpose_b = transpose_b_vals[i]; + Model MatMulModel = + [transpose_a, transpose_b]( + AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) -> Status { + return ops::MatMul(ctx, inputs, outputs, "MatMul", transpose_a, + transpose_b); + }; + Model MatMulGradModel = BuildGradModel(MatMulModel, registry_); + std::vector outputs(2); + status_ = + RunModel(MatMulGradModel, immediate_execution_ctx_.get(), + {A.get(), B.get()}, absl::MakeSpan(outputs), UseFunction()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], dA_vals[i], + /*dims*/ {3, 3}, + /*abs_error*/ 0)); + ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], dB_vals[i], + /*dims*/ {3, 3}, + /*abs_error*/ 0)); + outputs[0]->Unref(); + outputs[1]->Unref(); + } +} + +TEST_P(CppGradients, TestSqrtGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + status_ = registry_.Register("Sqrt", SqrtRegisterer); + 
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + SqrtModel, BuildGradModel(SqrtModel, registry_), + immediate_execution_ctx_.get(), {x.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestNegGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + status_ = registry_.Register("Neg", NegRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + NegModel, BuildGradModel(NegModel, registry_), + immediate_execution_ctx_.get(), {x.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestSubGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + y.reset(y_raw); + } + + status_ = registry_.Register("Sub", SubRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + SubModel, BuildGradModel(SubModel, registry_), + immediate_execution_ctx_.get(), {x.get(), y.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestMulGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + 
AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + y.reset(y_raw); + } + + status_ = registry_.Register("Mul", MulRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + MulModel, BuildGradModel(MulModel, registry_), + immediate_execution_ctx_.get(), {x.get(), y.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestLog1pGrad) { + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + status_ = registry_.Register("Log1p", Log1pRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + Log1pModel, BuildGradModel(Log1pModel, registry_), + immediate_execution_ctx_.get(), {x.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestDivNoNanGrad) { + status_ = registry_.Register("DivNoNan", DivNoNanRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + auto DivNoNanGradModel = BuildGradModel(DivNoNanModel, registry_); + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + y.reset(y_raw); + } + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + DivNoNanModel, 
DivNoNanGradModel, immediate_execution_ctx_.get(), + {x.get(), y.get()}, UseFunction())); + + // `DivNoNanGradModel` should return {`0`, `0`} when the denominator is `0`. + AbstractTensorHandlePtr z; + { + AbstractTensorHandle* z_raw = nullptr; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 0.0f, &z_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + z.reset(z_raw); + } + std::vector outputs(2); + status_ = + RunModel(DivNoNanGradModel, immediate_execution_ctx_.get(), + {x.get(), z.get()}, absl::MakeSpan(outputs), UseFunction()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {}, + /*abs_error*/ 0)); + ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {0.0f}, /*dims*/ {}, + /*abs_error*/ 0)); + outputs[0]->Unref(); + outputs[1]->Unref(); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc index 64532c8ffc0515..7434f05a74ecd0 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.cc +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/platform/errors.h" using std::vector; +using tensorflow::ops::BiasAddGrad; using tensorflow::ops::Mul; using tensorflow::ops::ReluGrad; @@ -35,29 +36,37 @@ namespace { class ReluGradientFunction : public GradientFunction { public: explicit ReluGradientFunction(vector f_outputs) - : forward_outputs(f_outputs) {} + : forward_outputs_(f_outputs) { + for (auto output : forward_outputs_) { + if (output) { + output->Ref(); + } + } + } - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { - AbstractTensorHandle* upstream_grad = grad_inputs[0]; - AbstractTensorHandle* activations = forward_outputs[0]; - grad_outputs->resize(1); - vector relugrad_outputs(1); + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + AbstractTensorHandle* upstream_grad = grad_outputs[0]; + AbstractTensorHandle* activations = forward_outputs_[0]; // Calculate Grad std::string name = "relu_grad"; - - TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations}, - absl::MakeSpan(relugrad_outputs), - name.c_str())); - (*grad_outputs)[0] = relugrad_outputs[0]; - + TF_RETURN_IF_ERROR( + ReluGrad(ctx, {upstream_grad, activations}, grad_inputs, name.c_str())); return Status::OK(); } - ~ReluGradientFunction() override {} + ~ReluGradientFunction() override { + for (auto output : forward_outputs_) { + if (output) { + output->Unref(); + } + } + } private: - vector forward_outputs; + // TODO(b/174778737): Only hold needed outputs. 
+ vector forward_outputs_; }; Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec, @@ -86,47 +95,79 @@ class SparseSoftmaxCrossEntropyWithLogitsGradientFunction public: explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction( vector f_outputs) - : forward_outputs(f_outputs) {} - - Status Compute(Context* ctx, const IncomingGradients& grad_inputs, - vector* grad_outputs) override { - grad_outputs->resize(2); + : forward_outputs_(f_outputs) {} + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { // Grad for Softmax Input - vector mul_outputs(1); TF_RETURN_IF_ERROR(BroadcastMul( - ctx->ctx, grad_inputs[0], forward_outputs[1], - absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad - (*grad_outputs)[0] = mul_outputs[0]; + ctx, grad_outputs[0], forward_outputs_[1], + grad_inputs.subspan(0, 1))); // upstream_grad * local softmax grad // Grad for labels is null - (*grad_outputs)[1] = nullptr; - + grad_inputs[1] = nullptr; return Status::OK(); } ~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {} private: - vector forward_outputs; + vector forward_outputs_; +}; + +// TODO(vnvo2409): Add python test +class BiasAddGradientFunction : public GradientFunction { + public: + explicit BiasAddGradientFunction(AttrBuilder f_attrs) + : forward_attrs_(f_attrs) {} + + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override { + /* Given upstream grad U and a BiasAdd: A + bias, the gradients are: + * + * dA = U + * dbias = reduceSum(U, dims = channel_dim) + */ + + AbstractTensorHandle* upstream_grad = grad_outputs[0]; + DCHECK(upstream_grad); + + // Recover data format from forward pass for gradient. 
+ std::string data_format; + TF_RETURN_IF_ERROR(forward_attrs_.Get("data_format", &data_format)); + + // Grad for A + grad_inputs[0] = upstream_grad; + grad_inputs[0]->Ref(); + + // Grad for bias + std::string name = "bias_add_grad"; + TF_RETURN_IF_ERROR(BiasAddGrad(ctx, {upstream_grad}, + grad_inputs.subspan(1, 1), + data_format.c_str(), name.c_str())); + + return Status::OK(); + } + ~BiasAddGradientFunction() override {} + + private: + AttrBuilder forward_attrs_; }; } // namespace -BackwardFunction* ReluRegisterer(const ForwardOperation& op) { - auto gradient_function = new ReluGradientFunction(op.outputs); - // For ops with a single output, the gradient function is not called if there - // is no incoming gradient. So we do not need to worry about creating zeros - // grads in this case. - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); +GradientFunction* ReluRegisterer(const ForwardOperation& op) { + return new ReluGradientFunction(op.outputs); } -BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( +GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op) { - auto gradient_function = - new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs); - auto default_gradients = new PassThroughDefaultGradients(op); - return new BackwardFunction(gradient_function, default_gradients); + return new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs); +} + +GradientFunction* BiasAddRegisterer(const ForwardOperation& op) { + return new BiasAddGradientFunction(op.attrs); } } // namespace gradients diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h index 034f20d732516e..2a635f540b2d82 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.h +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -19,9 +19,10 @@ limitations under the License. 
namespace tensorflow { namespace gradients { -BackwardFunction* ReluRegisterer(const ForwardOperation& op); -BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( +GradientFunction* ReluRegisterer(const ForwardOperation& op); +GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op); +GradientFunction* BiasAddRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc new file mode 100644 index 00000000000000..3f1feda8be02f4 --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -0,0 +1,226 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/nn_grad.h" + +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/experimental/gradients/grad_test_helper.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/tensor_float_32_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +using tensorflow::TF_StatusPtr; + +Status ReluModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::Relu(ctx, inputs, outputs, "Relu"); +} + +Status SparseSoftmaxCrossEntropyWithLogitsModel( + AbstractContext* ctx, absl::Span inputs, + absl::Span outputs) { + std::vector temp_outputs(2); + TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits( + ctx, inputs, absl::MakeSpan(temp_outputs), + "SparseSoftmaxCrossEntropyWithLogits")); + // `gradient_checker` only works with model that returns only 1 tensor. + // Although, `ops::SparseSoftmaxCrossEntropyWithLogits` returns 2 tensors, the + // second tensor isn't needed for computing gradient so we could safely drop + // it. 
+ outputs[0] = temp_outputs[0]; + temp_outputs[1]->Unref(); + return Status::OK(); +} + +Status BiasAddModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd"); +} + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + status_ = StatusFromTF_Status(status.get()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + { + AbstractContext* ctx_raw = nullptr; + status_ = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + immediate_execution_ctx_.reset(ctx_raw); + } + + // Computing numerical gradients with TensorFloat-32 is numerically + // unstable. Some forward pass tests also fail with TensorFloat-32 due to + // low tolerances + enable_tensor_float_32_execution(false); + } + + AbstractContextPtr immediate_execution_ctx_; + GradientRegistry registry_; + Status status_; + + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } +}; + +TEST_P(CppGradients, TestReluGrad) { + status_ = registry_.Register("Relu", ReluRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + auto ReluGradModel = BuildGradModel(ReluModel, registry_); + + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 10.0f, -1.0f}; + int64_t X_dims[] = {3, 3}; + AbstractTensorHandlePtr X; + { + AbstractTensorHandle* X_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + X.reset(X_raw); + } + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + ReluModel, ReluGradModel, immediate_execution_ctx_.get(), {X.get()}, 
+ UseFunction())); + + // Mathematically, Relu isn't differentiable at `0`. So `gradient_checker` + // does not work with it. + AbstractTensorHandlePtr Y; + { + AbstractTensorHandle* Y_raw; + status_ = TestScalarTensorHandle( + immediate_execution_ctx_.get(), 0.0f, &Y_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + Y.reset(Y_raw); + } + + std::vector outputs(1); + status_ = RunModel(ReluGradModel, immediate_execution_ctx_.get(), {Y.get()}, + absl::MakeSpan(outputs), UseFunction()); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {}, + /*abs_error*/ 0)); + outputs[0]->Unref(); +} + +TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) { + if (UseFunction()) { + // TODO(b/168850692): Enable this. + GTEST_SKIP() << "Can't take gradient of " + "SparseSoftmaxCrossEntropyWithLogits in tracing mode."; + } + + // Score + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; + int64_t X_dims[] = {3, 3}; + AbstractTensorHandlePtr X; + { + AbstractTensorHandle* X_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + X.reset(X_raw); + } + // Label + int32_t Y_vals[] = {1, 0, 1}; + int64_t Y_dims[] = {3}; + AbstractTensorHandlePtr Y; + { + AbstractTensorHandle* Y_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), Y_vals, Y_dims, 1, &Y_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + Y.reset(Y_raw); + } + + status_ = registry_.Register("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyWithLogitsRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + SparseSoftmaxCrossEntropyWithLogitsModel, + BuildGradModel(SparseSoftmaxCrossEntropyWithLogitsModel, 
registry_), + immediate_execution_ctx_.get(), {X.get(), Y.get()}, UseFunction())); +} + +TEST_P(CppGradients, TestBiasAddGrad) { + if (UseFunction() && UseMlir()) { + GTEST_SKIP() << "SetAttrString has not been implemented yet.\n"; + } + + // A + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + AbstractTensorHandlePtr A; + { + AbstractTensorHandle* A_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + A.reset(A_raw); + } + // Bias + float Bias_vals[] = {2.0f, 3.0f}; + int64_t Bias_dims[] = {2}; + AbstractTensorHandlePtr Bias; + { + AbstractTensorHandle* Bias_raw; + status_ = TestTensorHandleWithDims( + immediate_execution_ctx_.get(), Bias_vals, Bias_dims, 1, &Bias_raw); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + Bias.reset(Bias_raw); + } + + status_ = registry_.Register("BiasAdd", BiasAddRegisterer); + ASSERT_EQ(errors::OK, status_.code()) << status_.error_message(); + + ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( + BiasAddModel, BuildGradModel(BiasAddModel, registry_), + immediate_execution_ctx_.get(), {A.get(), Bias.get()}, UseFunction())); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/not_differentiable.cc b/tensorflow/c/experimental/gradients/not_differentiable.cc new file mode 100644 index 00000000000000..e8dbb7ecdae415 --- 
/dev/null +++ b/tensorflow/c/experimental/gradients/not_differentiable.cc @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/not_differentiable.h" + +namespace tensorflow { +namespace gradients { +Status NotDifferentiableGradientFunction::Compute( + AbstractContext* ctx, absl::Span grad_outputs, + absl::Span grad_inputs) { + for (int i = 0; i < grad_inputs.size(); i++) { + grad_inputs[i] = nullptr; + } + return Status::OK(); +} + +Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op) { + return registry->Register(op, [](const ForwardOperation& op) { + return new NotDifferentiableGradientFunction; + }); +} +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/not_differentiable.h b/tensorflow/c/experimental/gradients/not_differentiable.h new file mode 100644 index 00000000000000..1a864dbf6e1eb4 --- /dev/null +++ b/tensorflow/c/experimental/gradients/not_differentiable.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +// Ignores `grad_outputs` and sets all entries in grad_inputs to nullptr. +class NotDifferentiableGradientFunction : public GradientFunction { + Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override; +}; +// Shorthand for registry->Register(op, new NotDifferentiableGradientFunction) +Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_ diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD index bada49ea669919..4e02daf36a22cd 100644 --- a/tensorflow/c/experimental/gradients/tape/BUILD +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -17,8 +17,6 @@ cc_library( deps = [ ":tape_operation", "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_function", - "//tensorflow/c/eager:abstract_operation", ], ) @@ -33,7 +31,6 @@ cc_library( ], deps = [ "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_function", "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:gradients_internal", ], @@ -51,6 +48,9 @@ cc_library( deps 
= [ ":tape_context", ":tape_operation", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:gradients_internal", ], ) diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc b/tensorflow/c/experimental/gradients/tape/tape_operation.cc index 0b247d08f6cf01..4f0fc5fbdec99c 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape, parent_op_(parent_op), tape_(tape), registry_(registry) { - // TODO(srbs): Make AbstractOperation RefCounted. + // TODO(b/172003047): Consider making AbstractOperation RefCounted. // parent_op_->Ref(); } void TapeOperation::Release() { @@ -33,7 +33,7 @@ void TapeOperation::Release() { delete this; } TapeOperation::~TapeOperation() { - // TODO(srbs): Make AbstractOperation RefCounted. + // TODO(b/172003047): Consider making AbstractOperation RefCounted. // parent_op->Unref(); } Status TapeOperation::Reset(const char* op, const char* raw_device_name) { @@ -197,12 +197,6 @@ AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; } Status TapeOperation::Execute(absl::Span retvals, int* num_retvals) { TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals)); - std::vector input_ids(forward_op_.inputs.size()); - std::vector input_dtypes(forward_op_.inputs.size()); - for (int i = 0; i < forward_op_.inputs.size(); i++) { - input_ids[i] = ToId(forward_op_.inputs[i]); - input_dtypes[i] = forward_op_.inputs[i]->DataType(); - } for (int i = 0; i < *num_retvals; i++) { // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. forward_op_.outputs.push_back(retvals[i]); @@ -212,25 +206,11 @@ Status TapeOperation::Execute(absl::Span retvals, // Consider getting rid of this and making the behavior between number types // and string consistent. 
forward_op_.attrs.BuildNodeDef(); - std::vector tape_tensors; - for (auto t : retvals) { - tape_tensors.push_back(TapeTensor(t)); - } - tape_->RecordOperation( - parent_op_->Name(), tape_tensors, input_ids, input_dtypes, - [this]() -> BackwardFunction* { - std::unique_ptr backward_fn; - Status s = registry_.Lookup(forward_op_, &backward_fn); - if (!s.ok()) { - return nullptr; - } - return backward_fn.release(); - }, - [](BackwardFunction* ptr) { - if (ptr) { - delete ptr; - } - }); + // TODO(b/170307493): Populate skip_input_indices here. + std::unique_ptr backward_fn; + TF_RETURN_IF_ERROR(registry_.Lookup(forward_op_, &backward_fn)); + tape_->RecordOperation(forward_op_.inputs, forward_op_.outputs, + backward_fn.release(), parent_op_->Name()); return Status::OK(); } diff --git a/tensorflow/c/experimental/grappler/BUILD b/tensorflow/c/experimental/grappler/BUILD new file mode 100644 index 00000000000000..316fd8211059aa --- /dev/null +++ b/tensorflow/c/experimental/grappler/BUILD @@ -0,0 +1,67 @@ +# Description: +# Graph C API. 
+ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "grappler_hdrs", + hdrs = ["grappler.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status_headers", + ], +) + +cc_library( + name = "grappler", + srcs = ["grappler.cc"], + hdrs = [ + "grappler.h", + "grappler_internal.h", + ], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +tf_cc_test( + name = "grappler_test", + srcs = ["grappler_test.cc"], + deps = [ + ":grappler", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:single_machine", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + ], +) diff --git a/tensorflow/c/experimental/grappler/grappler.cc b/tensorflow/c/experimental/grappler/grappler.cc new file mode 100644 index 00000000000000..788647e1764a0a --- /dev/null +++ b/tensorflow/c/experimental/grappler/grappler.cc @@ -0,0 +1,404 @@ 
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file extends/implements core graph optimizer base classes in terms of +// the C API defined in grappler.h. A class "CSomething" represents a +// "Something" that can be manipulated via calls in the C interface and a C +// struct called "TP_Something". + +#include "tensorflow/c/experimental/grappler/grappler.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/experimental/grappler/grappler_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" + +namespace { + +#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ + do { \ + if (STRUCT_OBJ.struct_size == 0) { \ + return tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \ + "struct_size field in " #STRUCT_NAME \ + " must be set to " #SIZE_VALUE_NAME "."); \ + } \ + } while (0) + +#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \ + do { \ + if (STRUCT_OBJ.NAME == 0) { \ + return 
tensorflow::Status(tensorflow::error::FAILED_PRECONDITION, \ + "'" #NAME "' field in " #STRUCT_NAME \ + " must be set."); \ + } \ + } while (0) + +tensorflow::Status ValidateTPOptimizerRegistrationParams( + const TP_OptimizerRegistrationParams& params) { + VALIDATE_STRUCT_SIZE(TP_OptimizerRegistrationParams, params, + TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE); + VALIDATE_MEMBER(TP_OptimizerRegistrationParams, params, device_type); + return tensorflow::Status::OK(); +} + +tensorflow::Status ValidateTPOptimizer(const TP_Optimizer& optimizer) { + VALIDATE_STRUCT_SIZE(TP_Optimizer, optimizer, TP_OPTIMIZER_STRUCT_SIZE); + VALIDATE_MEMBER(TP_Optimizer, optimizer, optimize_func); + return tensorflow::Status::OK(); +} + +tensorflow::Status ValidateTPOptimizerConfigs( + const TP_OptimizerConfigs& configs) { + VALIDATE_STRUCT_SIZE(TP_OptimizerConfigs, configs, + TP_OPTIMIZER_CONFIGS_STRUCT_SIZE); + return tensorflow::Status::OK(); +} + +#undef VALIDATE_MEMBER +#undef VALIDATE_STRUCT_SIZE + +// A map containing the input graph as its key, and TF_GrapplerItem as the +// value. Users can fetch GrapplerItem for additional info to transform the +// graph. 
+absl::flat_hash_map* GrapplerItemMap() { + static absl::flat_hash_map* + grappler_items = + new absl::flat_hash_map; + return grappler_items; +} +} // namespace + +namespace tensorflow { +namespace grappler { + +Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph_def) { + OwnedTFStatus c_status(TF_NewStatus()); + OwnedTFBuffer graph_buf(TF_NewBuffer()); + OwnedTFBuffer optimized_graph_buf(TF_NewBuffer()); + TF_RETURN_IF_ERROR(MessageToBuffer(item.graph, graph_buf.get())); + + const auto it = GrapplerItemMap()->find(graph_buf.get()); + if (it == GrapplerItemMap()->end()) + GrapplerItemMap()->insert( + {graph_buf.get(), reinterpret_cast(&item)}); + + optimizer_.optimize_func(c_optimizer_, graph_buf.get(), + optimized_graph_buf.get(), c_status.get()); + TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); + TF_RETURN_IF_ERROR( + BufferToMessage(optimized_graph_buf.get(), optimized_graph_def)); + + GrapplerItemMap()->erase(graph_buf.get()); + return Status::OK(); +} + +#define CONFIG_TOGGLE(optimizer) \ + if (tp_configs.optimizer == TF_TriState_Off) \ + configs.toggle_config[#optimizer] = RewriterConfig::OFF; \ + else \ + configs.toggle_config[#optimizer] = RewriterConfig::ON; + +void CGraphOptimizerRegister( + const PluginGraphOptimizerRegistry::Creator& creator, + const TP_OptimizerConfigs tp_configs, const char* device_type) { + ConfigList configs; + // disable_model_pruning is turned off by default. + if (tp_configs.disable_model_pruning == TF_TriState_On) + configs.disable_model_pruning = true; + else + configs.disable_model_pruning = false; + // The other configs are turned on by default. 
+ CONFIG_TOGGLE(implementation_selector); + CONFIG_TOGGLE(function_optimization); + CONFIG_TOGGLE(common_subgraph_elimination); + CONFIG_TOGGLE(arithmetic_optimization); + CONFIG_TOGGLE(debug_stripper); + CONFIG_TOGGLE(constant_folding); + CONFIG_TOGGLE(shape_optimization); + CONFIG_TOGGLE(auto_mixed_precision); + CONFIG_TOGGLE(auto_mixed_precision_mkl); + CONFIG_TOGGLE(pin_to_host_optimization); + CONFIG_TOGGLE(layout_optimizer); + CONFIG_TOGGLE(remapping); + CONFIG_TOGGLE(loop_optimization); + CONFIG_TOGGLE(dependency_optimization); + CONFIG_TOGGLE(auto_parallel); + CONFIG_TOGGLE(memory_optimization); + CONFIG_TOGGLE(scoped_allocator_optimization); + PluginGraphOptimizerRegistry::RegisterPluginOptimizerOrDie( + creator, device_type, configs); +} + +#undef CONFIG_TOGGLE + +tensorflow::Status InitGraphPlugin(void* dso_handle) { + tensorflow::Env* env = tensorflow::Env::Default(); + + // Step 1: Load symbol for `TF_InitPlugin` + void* dso_symbol; + TF_RETURN_IF_ERROR( + env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol)); + + // Step 2: Call `TF_InitPlugin` + auto init_fn = reinterpret_cast(dso_symbol); + return InitGraphPlugin(init_fn); +} + +tensorflow::Status InitGraphPlugin(TFInitGraphPluginFn init_fn) { + TP_OptimizerRegistrationParams params{ + TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE}; + TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE}; + TP_OptimizerConfigs optimizer_configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE}; + params.major_version = GO_MAJOR; + params.minor_version = GO_MINOR; + params.patch_version = GO_PATCH; + params.optimizer = &optimizer; + params.optimizer_configs = &optimizer_configs; + + OwnedTFStatus c_status(TF_NewStatus()); + init_fn(¶ms, c_status.get()); + TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); + TF_RETURN_IF_ERROR(ValidateTPOptimizerRegistrationParams(params)); + TF_RETURN_IF_ERROR(ValidateTPOptimizer(optimizer)); + TF_RETURN_IF_ERROR(ValidateTPOptimizerConfigs(optimizer_configs)); + + 
CGraphOptimizerRegister( + [=]() { return new CGraphOptimizer(optimizer, params.device_type); }, + optimizer_configs, params.device_type); + + return Status::OK(); +} + +} // namespace grappler +} // namespace tensorflow + +const TF_GrapplerItem* TF_GetGrapplerItem(TF_Buffer* graph, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const auto it = GrapplerItemMap()->find(graph); + if (it != GrapplerItemMap()->end()) { + return it->second; + } else { + status->status = tensorflow::errors::NotFound("GrapplerItem is not found"); + return nullptr; + } +} + +void TF_GetNodesToPreserveListSize(const TF_GrapplerItem* item, int* num_values, + size_t* storage_size, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const std::unordered_set& nodes = + reinterpret_cast(item) + ->NodesToPreserve(); + *num_values = nodes.size(); + *storage_size = 0; + for (const std::string& str : nodes) { + *storage_size += str.size(); + } +} + +void TF_GetNodesToPreserveList(const TF_GrapplerItem* item, char** values, + size_t* lengths, int num_values, void* storage, + size_t storage_size, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const std::unordered_set& nodes = + reinterpret_cast(item) + ->NodesToPreserve(); + char* p = static_cast(storage); + + int index = 0; + for (const std::string& s : nodes) { + if (index >= num_values) break; + values[index] = p; + lengths[index] = s.size(); + if ((p + s.size()) > (static_cast(storage) + storage_size)) { + status->status = tensorflow::errors::InvalidArgument( + "Not enough storage to hold the requested list of nodes"); + return; + } + memcpy(values[index], s.data(), s.size()); + p += s.size(); + index++; + } +} + +void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, int* num_values, + size_t* storage_size, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const std::vector& nodes = + reinterpret_cast(item)->fetch; + *num_values = nodes.size(); + *storage_size = 0; + for (const std::string& str : nodes) { + 
*storage_size += str.size(); + } +} + +void TF_GetFetchNodesList(const TF_GrapplerItem* item, char** values, + size_t* lengths, int num_values, void* storage, + size_t storage_size, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const std::vector& nodes = + reinterpret_cast(item)->fetch; + + const int len = std::min(num_values, static_cast(nodes.size())); + char* p = static_cast(storage); + for (int index = 0; index < len; ++index) { + const std::string& s = nodes[index]; + values[index] = p; + lengths[index] = s.size(); + if ((p + s.size()) > (static_cast(storage) + storage_size)) { + status->status = tensorflow::errors::InvalidArgument( + "Not enough storage to hold the requested list of nodes"); + return; + } + memcpy(values[index], s.data(), s.size()); + p += s.size(); + } +} + +TF_GraphProperties* TF_NewGraphProperties(const TF_GrapplerItem* item) { + return reinterpret_cast( + new tensorflow::grappler::GraphProperties( + *reinterpret_cast(item))); +} + +void TF_DeleteGraphProperties(TF_GraphProperties* graph_properties) { + if (graph_properties == nullptr) return; + delete reinterpret_cast( + graph_properties); +} + +void TF_InferStatically(TF_GraphProperties* graph_properties, + TF_Bool assume_valid_feeds, + TF_Bool aggressive_shape_inference, + TF_Bool include_input_tensor_values, + TF_Bool include_output_tensor_values, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + tensorflow::Status s = + reinterpret_cast(graph_properties) + ->InferStatically(assume_valid_feeds, aggressive_shape_inference, + include_input_tensor_values, + include_output_tensor_values); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + } +} + +void TF_GetInputPropertiesListSize(TF_GraphProperties* graph_properties, + const char* name, int* num_values, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + *num_values = + reinterpret_cast(graph_properties) + ->GetInputProperties(name) + .size(); +} + +void 
TF_GetOutputPropertiesListSize(TF_GraphProperties* graph_properties, + const char* name, int* num_values, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + *num_values = + reinterpret_cast(graph_properties) + ->GetOutputProperties(name) + .size(); +} + +void TF_GetInputPropertiesList(TF_GraphProperties* graph_properties, + const char* name, TF_Buffer** properties, + int num_values, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const std::vector& tensor_properties = + reinterpret_cast(graph_properties) + ->GetInputProperties(name); + const int len = + std::min(num_values, static_cast(tensor_properties.size())); + for (int i = 0; i < len; ++i) { + tensorflow::Status s = + tensorflow::MessageToBuffer(tensor_properties[i], properties[i]); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return; + } + } +} + +void TF_GetOutputPropertiesList(TF_GraphProperties* graph_properties, + const char* name, TF_Buffer** properties, + int num_values, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const std::vector& tensor_properties = + reinterpret_cast(graph_properties) + ->GetOutputProperties(name); + const int len = + std::min(num_values, static_cast(tensor_properties.size())); + for (int i = 0; i < len; ++i) { + tensorflow::Status s = + tensorflow::MessageToBuffer(tensor_properties[i], properties[i]); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return; + } + } +} + +TF_FunctionLibraryDefinition* TF_NewFunctionLibraryDefinition( + TF_Buffer* graph_buf, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + tensorflow::GraphDef graph_def; + tensorflow::Status s = tensorflow::BufferToMessage(graph_buf, &graph_def); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + return reinterpret_cast( + new tensorflow::FunctionLibraryDefinition( + tensorflow::OpRegistry::Global(), graph_def.library())); +} + +void 
TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition* fn_lib) { + if (fn_lib == nullptr) return; + delete reinterpret_cast(fn_lib); +} + +void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, const char* name, + TF_Buffer* buf, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + const tensorflow::OpDef* op_def_ptr = nullptr; + tensorflow::Status s = + reinterpret_cast(fn_lib) + ->LookUpOpDef(name, &op_def_ptr); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return; + } + + s = tensorflow::MessageToBuffer(*op_def_ptr, buf); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return; + } +} diff --git a/tensorflow/c/experimental/grappler/grappler.h b/tensorflow/c/experimental/grappler/grappler.h new file mode 100644 index 00000000000000..05d48bb3e80646 --- /dev/null +++ b/tensorflow/c/experimental/grappler/grappler.h @@ -0,0 +1,286 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" + +// -------------------------------------------------------------------------- +// C API for Graph. 
The API is under active development and eventually +// should allow registering a plugin graph optimizer with TensorFlow. +// +// Conventions: +// * Struct prefix indicates whether struct fields should be filled by the +// plugin or core implementation: +// * Struct that should be filled by the plugin: `TP_OptimizerConfigs`, +// `TP_Optimizer`, `TP_OptimizerRegistrationParams` +// * Struct that should be filled by the proper: `TF_GrapplerItem`, +// `TF_GraphProperties`, `TF_FunctionLibraryDefinition` +// * We use `struct_size` for version checking. It should be set both by +// core and the plugin. +// * For example, `TF_InitGraph` function receives +// `TP_OptimizerRegistrationParams*` as input with `struct_size` +// populated by core. The plugin is responsible for setting +// `struct_size` as well, along with all other fields. +// * Refer to "TensorFlow Versioning Strategy" section at +// https://github.com/tensorflow/community/pull/257/files. +// * Note that the API is still under active development and doesn't have +// versioning guarantees yet. +// * `void* ext` is a free-form field that can be populated by +// a plugin in `TP_*` structs or potential future extension points . +// +// Example usage: +// +// /* Sample TensorFlow code below, exact implementation might differ. */ +// // Version checking uses `struct_size`. It should be set both by core +// // and the plugin. +// TP_OptimizerRegistrationParams params{ +// TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE}; +// TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE}; +// TP_OptimizerConfigs configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE}; +// params.optimizer = &optimizer; +// params.configs = &configs; +// +// /* Plugin code below */ +// void TF_InitGraph(TP_OptimizerRegistrationParams* params, +// TF_Status* status) { +// params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE; +// params->device_type = "MY_DEVICE"; +// +// // Disable certain optimizer. 
+// params->optimizer_configs->struct_size = +// TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; params->optimizer_configs->remapping = +// TF_TriState_Off; +// +// // Set functions to create a new optimizer. +// params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE; +// params->optimizer->create_func = (My_optimizer::create_func); +// } + +#define GO_MAJOR 0 +#define GO_MINOR 0 +#define GO_PATCH 1 + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_TriState is the C API typedef for tri-state. +typedef enum TF_TriState { + TF_TriState_Default = 0, + TF_TriState_Off, + TF_TriState_On, +} TF_TriState; + +// Flags indicating whether existing optimizers should be turned off. +// It's optional for plugin to set functions to return true/false. If not +// set, proper uses configuration set by user. +typedef struct TP_OptimizerConfigs { + size_t struct_size; + void* ext; // reserved for future use + TF_TriState disable_model_pruning; + TF_TriState implementation_selector; + TF_TriState function_optimization; + TF_TriState common_subgraph_elimination; + TF_TriState arithmetic_optimization; + TF_TriState debug_stripper; + TF_TriState constant_folding; + TF_TriState shape_optimization; + TF_TriState auto_mixed_precision; + TF_TriState auto_mixed_precision_mkl; + TF_TriState pin_to_host_optimization; + TF_TriState layout_optimizer; + TF_TriState remapping; + TF_TriState loop_optimization; + TF_TriState dependency_optimization; + TF_TriState auto_parallel; + TF_TriState memory_optimization; + TF_TriState scoped_allocator_optimization; +} TP_OptimizerConfigs; + +#define TP_OPTIMIZER_CONFIGS_STRUCT_SIZE \ + TF_OFFSET_OF_END(TP_OptimizerConfigs, scoped_allocator_optimization) + +// Struct for Optimizer. Plugin authors must provide an optimize function. +// Creation and deletion functions are optional. +typedef struct TP_Optimizer { + size_t struct_size; + void* ext; // reserved for future use + + // [Optional] + // Create function for optimizer. 
+ void* (*create_func)(); + + // Optimizer function for optimizer. The first param is an optimizer created + // by create_func. The second param is input graph. The third param is output + // graph. + void (*optimize_func)(void*, TF_Buffer*, TF_Buffer*, TF_Status*); + + // [Optional] + // Destroy function for optimizer. If Create function is provided, destroy + // function is must. + void (*destroy_func)(void*); +} TP_Optimizer; + +#define TP_OPTIMIZER_STRUCT_SIZE TF_OFFSET_OF_END(TP_Optimizer, destroy_func) + +typedef struct TP_OptimizerRegistrationParams { + size_t struct_size; + void* ext; // reserved for future use + + // Graph C API version. + int32_t major_version; + int32_t minor_version; + int32_t patch_version; + + // Backend device type supported by the optimizer. + const char* device_type; + TP_OptimizerConfigs* optimizer_configs; // output, set by plugin + TP_Optimizer* optimizer; // output, set by plugin +} TP_OptimizerRegistrationParams; + +#define TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(TP_OptimizerRegistrationParams, optimizer) + +// TF_InitGraph is used to do graph optimizer registration. +// Plugin should implement TF_InitGraph to register graph optimizers. +void TF_InitGraph(TP_OptimizerRegistrationParams* params, TF_Status* status); + +// TF_GrapplerItem represents a combination of a graph, one of more fetch nodes, +// and potentially a set of nodes to feed. +typedef struct TF_GrapplerItem TF_GrapplerItem; + +// Get TF_GrapplerItem from TF_Buffer. +const TF_GrapplerItem* TF_GetGrapplerItem(TF_Buffer* graph, TF_Status* status); + +// Get a set of node names that must be preserved. They can not be transformed +// or removed during the graph transformation. This includes feed and fetch +// nodes, keep_ops, init_ops. Fills in `num_values` and `storage_size`, they +// will be used in `TF_GetNodesToPreserveList`. 
+void TF_GetNodesToPreserveListSize(const TF_GrapplerItem* item, int* num_values, + size_t* storage_size, TF_Status* status); + +// Get a set of node names that must be preserved. They can not be transformed +// or removed during the graph transformation. This includes feed and fetch +// nodes, keep_ops, init_ops. Fills in `values` and `lengths`, each of which +// must point to an array of length at least `num_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. `num_values` and `storage` can be +// obtained from TF_GetNodesToPreserveSize +// +// Fails if storage_size is too small to hold the requested number of strings. +void TF_GetNodesToPreserveList(const TF_GrapplerItem* item, char** values, + size_t* lengths, int num_values, void* storage, + size_t storage_size, TF_Status* status); + +// Get a set of node names for fetch nodes. Fills in `values` and `lengths`, +// they will be used in `TF_GetFetchNodesList` +void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, int* num_values, + size_t* storage_size, TF_Status* status); + +// Get a set of node names for fetch nodes. Fills in `values` and `lengths`, +// each of which must point to an array of length at least `num_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. `num_values` and `storage` can be +// obtained from TF_GetFetchNodesSize +// +// Fails if storage_size is too small to hold the requested number of strings. +void TF_GetFetchNodesList(const TF_GrapplerItem* item, char** values, + size_t* lengths, int num_values, void* storage, + size_t storage_size, TF_Status* status); + +// Infer OpInfo::TensorProperties for graph nodes inputs/outputs. +// +// Typical use case, is to infer tensor properties from a graph, before doing +// optimization pass. 
Nodes modified during optimization pass have to be +// invalidated, to prevent further incorrect optimizations based on wrong shape +// and data type properties. +typedef struct TF_GraphProperties TF_GraphProperties; + +// Create GraphProperties. The item must outlive the properties. +TF_GraphProperties* TF_NewGraphProperties(const TF_GrapplerItem* item); + +// Delete GraphProperties. +void TF_DeleteGraphProperties(TF_GraphProperties* graph_properties); + +// Infer tensor shapes through abstract interpretation. +// If assume_valid_feeds is true, it can help infer shapes in the fanout of fed +// nodes. This may cause incorrectness in graph analyses, but is useful for +// simulation or scheduling. +// If aggressive_shape_inference is true, nodes are executed on the host to +// identify output values when possible and does other aggressive strategies. +// This may cause incorrectness in graph analyses, but is useful for simulation +// or scheduling. +// If include_input_tensor_values is true, the values of constant +// tensors will included in the input properties. +// If include_output_tensor_values is true, the values of constant tensors will +// be included in the output properties. +void TF_InferStatically(TF_GraphProperties* graph_properties, + TF_Bool assume_valid_feeds, + TF_Bool aggressive_shape_inference, + TF_Bool include_input_tensor_values, + TF_Bool include_output_tensor_values, TF_Status* s); + +// Get the size of input OpInfo::TensorProperties given node name. +void TF_GetInputPropertiesListSize(TF_GraphProperties* graph_properties, + const char* name, int* num_values, + TF_Status* status); + +// Get the size of output OpInfo::TensorProperties given node name. +void TF_GetOutputPropertiesListSize(TF_GraphProperties* graph_properties, + const char* name, int* num_values, + TF_Status* status); + +// Get a list of input OpInfo::TensorProperties given node name. +// Return the serialized list `properties`. 
+void TF_GetInputPropertiesList(TF_GraphProperties* graph_properties, + const char* name, TF_Buffer** properties, + int num_values, TF_Status* status); + +// Get a list of output OpInfo::TensorProperties given node name. +// Return the serialized list `properties`. +void TF_GetOutputPropertiesList(TF_GraphProperties* graph_properties, + const char* name, TF_Buffer** properties, + int num_values, TF_Status* status); + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. +// Typical use case, is to look up an OpDef by type name. +typedef struct TF_FunctionLibraryDefinition TF_FunctionLibraryDefinition; + +// Create NewFunctionLibraryDefinition. +TF_FunctionLibraryDefinition* TF_NewFunctionLibraryDefinition( + TF_Buffer* graph_buf, TF_Status* status); + +// Delete NewFunctionLibraryDefinition. +void TF_DeleteFunctionLibraryDefinition(TF_FunctionLibraryDefinition* fn_lib); + +// Shorthand for calling LookUp to get the OpDef from FunctionLibraryDefinition +// given op name. The returned OpDef is represented by TF_Buffer. +void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, const char* name, + TF_Buffer* buf, TF_Status* s); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_ diff --git a/tensorflow/c/experimental/grappler/grappler_internal.h b/tensorflow/c/experimental/grappler/grappler_internal.h new file mode 100644 index 00000000000000..8b1fa07c96f27a --- /dev/null +++ b/tensorflow/c/experimental/grappler/grappler_internal.h @@ -0,0 +1,106 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Classes and utilities that work with Graph C API for internal use. +// This includes functions used for optimizer registration and interfaces needed +// for testing. + +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/grappler/grappler.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Plugin initialization function that a device plugin +// must define. +typedef void (*TFInitGraphPluginFn)(TP_OptimizerRegistrationParams* const, + TF_Status* const); + +// Registers Graph optimizers. +Status InitGraphPlugin(void* dso_handle); + +// Allow registering a graph optimizer using a function (used for +// testing). 
+Status InitGraphPlugin(TFInitGraphPluginFn init_fn); + +struct GrapplerItem; +class Cluster; + +struct TFStatusDeleter { + void operator()(TF_Status* s) const { TF_DeleteStatus(s); } +}; +using OwnedTFStatus = std::unique_ptr; + +struct TFBufferDeleter { + void operator()(TF_Buffer* buf) const { TF_DeleteBuffer(buf); } +}; +using OwnedTFBuffer = std::unique_ptr; + +class CGraphOptimizer : public CustomGraphOptimizer { + public: + explicit CGraphOptimizer(TP_Optimizer optimizer, const char* device_type) + : optimizer_(optimizer), device_type_(device_type) { + if (optimizer.create_func != nullptr) { + c_optimizer_ = (*optimizer_.create_func)(); + } else { + c_optimizer_ = nullptr; + } + } + std::string name() const override { return "PluggableGraphOptimizer"; } + bool UsesFunctionLibrary() const override { return false; } + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph_def) override; + + ~CGraphOptimizer() override { + if (optimizer_.destroy_func != nullptr) { + (*optimizer_.destroy_func)(c_optimizer_); + } + } + + private: + TP_Optimizer optimizer_; + std::string device_type_; + void* c_optimizer_; +}; + +// Registration function to register a CGraphOptimizer along with plugin configs +// and device type. 
+void CGraphOptimizerRegister( + const PluginGraphOptimizerRegistry::Creator& creator, + const TP_OptimizerConfigs tp_configs, const char* device_type); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_ diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc new file mode 100644 index 00000000000000..37d203d8d719f9 --- /dev/null +++ b/tensorflow/c/experimental/grappler/grappler_test.cc @@ -0,0 +1,307 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0(the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/grappler/grappler.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/experimental/grappler/grappler_internal.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/grappler/clusters/single_machine.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +void optimize_func(void* optimizer, TF_Buffer* graph_buf, + TF_Buffer* optimized_graph_buf, TF_Status* tf_status) {} + +void PopulateDefaultParam(TP_OptimizerRegistrationParams* params) { + params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE; + params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; + params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE; + params->optimizer->create_func = nullptr; + params->optimizer->optimize_func = optimize_func; + params->optimizer->destroy_func = nullptr; +} + +TEST(Grappler, SuccessfulRegistration) { + auto plugin_init = [](TP_OptimizerRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultParam(params); + params->device_type = "Success"; + params->optimizer_configs->remapping = TF_TriState_Off; + }; + + TF_ASSERT_OK(InitGraphPlugin(plugin_init)); + ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers( + std::set{"Success"}) + .size(), + 1); + ConfigList config = PluginGraphOptimizerRegistry::GetPluginConfigs( + true, std::set{"Success"}); + ASSERT_EQ(config.toggle_config["remapping"], RewriterConfig::OFF); +} + +TEST(Grappler, MultiplePluginRegistration) { + auto plugin_init_0 = 
[](TP_OptimizerRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultParam(params); + params->device_type = "Device0"; + }; + auto plugin_init_1 = [](TP_OptimizerRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultParam(params); + params->device_type = "Device1"; + }; + + TF_ASSERT_OK(InitGraphPlugin(plugin_init_0)); + TF_ASSERT_OK(InitGraphPlugin(plugin_init_1)); + ASSERT_EQ(PluginGraphOptimizerRegistry::CreateOptimizers( + std::set{"Device0", "Device1"}) + .size(), + 2); +} + +TEST(Grappler, DeviceTypeNotSet) { + auto plugin_init = [](TP_OptimizerRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultParam(params); + params->device_type = nullptr; + }; + + tensorflow::Status status = InitGraphPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ( + status.error_message(), + "'device_type' field in TP_OptimizerRegistrationParams must be set."); +} + +TEST(Grappler, OptimizeFuncNotSet) { + auto plugin_init = [](TP_OptimizerRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultParam(params); + params->device_type = "FuncNotSet"; + params->optimizer->optimize_func = nullptr; + }; + + tensorflow::Status status = InitGraphPlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ(status.error_message(), + "'optimize_func' field in TP_Optimizer must be set."); +} + +TEST(TF_GrapplerItem, NodesToPreserve) { + GrapplerItem item; + item.fetch = std::vector{"Conv", "BiasAdd"}; + std::unordered_set nodes_preserved = item.NodesToPreserve(); + TF_GrapplerItem* c_item = reinterpret_cast(&item); + + int list_total_size = 0; + for (const string& s : nodes_preserved) { + list_total_size += s.size(); + } + + size_t 
storage_size = 0; + int num_values = 0; + TF_Status* status = TF_NewStatus(); + TF_GetNodesToPreserveListSize(c_item, &num_values, &storage_size, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(nodes_preserved.size(), num_values); + EXPECT_EQ(list_total_size, storage_size); + + std::unique_ptr values(new char*[nodes_preserved.size()]); + std::unique_ptr lens(new size_t[nodes_preserved.size()]); + std::unique_ptr storage(new char[storage_size]); + TF_GetNodesToPreserveList(c_item, values.get(), lens.get(), + nodes_preserved.size(), storage.get(), storage_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + for (size_t i = 0; i < nodes_preserved.size(); ++i) { + EXPECT_EQ(nodes_preserved.find(string(static_cast(values[i]), + lens[i])) != nodes_preserved.end(), + true); + } + TF_DeleteStatus(status); +} + +TEST(TF_GrapplerItem, FetchNodes) { + GrapplerItem item; + item.fetch = std::vector{"Conv", "BiasAdd"}; + TF_GrapplerItem* c_item = reinterpret_cast(&item); + + int list_total_size = 0; + for (const string& s : item.fetch) { + list_total_size += s.size(); + } + + size_t storage_size = 0; + int num_values = 0; + TF_Status* status = TF_NewStatus(); + TF_GetFetchNodesListSize(c_item, &num_values, &storage_size, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(item.fetch.size(), num_values); + EXPECT_EQ(list_total_size, storage_size); + + std::unique_ptr values(new char*[item.fetch.size()]); + std::unique_ptr lens(new size_t[item.fetch.size()]); + std::unique_ptr storage(new char[storage_size]); + TF_GetFetchNodesList(c_item, values.get(), lens.get(), item.fetch.size(), + storage.get(), storage_size, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + for (size_t i = 0; i < item.fetch.size(); ++i) { + EXPECT_EQ(item.fetch[i].size(), lens[i]) << i; + EXPECT_EQ(item.fetch[i], + string(static_cast(values[i]), lens[i])) + << i; + } + 
TF_DeleteStatus(status); +} + +TEST(TF_GraphProperties, InputProperties) { + std::unique_ptr cluster(new SingleMachine(5 * 60, 3, 0)); + TF_ASSERT_OK(cluster->Provision()); + + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TF_Status* status = TF_NewStatus(); + TF_GraphProperties* graph_properties = + TF_NewGraphProperties(reinterpret_cast(&item)); + TF_InferStatically(graph_properties, true, false, false, false, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + for (const NodeDef& node : item.graph.node()) { + if (node.op() == "AddN") { + int num_values = 0; + TF_GetInputPropertiesListSize(graph_properties, node.name().c_str(), + &num_values, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(num_values, 1); + + std::vector in_props_buf(num_values, TF_NewBuffer()); + + TF_GetInputPropertiesList(graph_properties, node.name().c_str(), + in_props_buf.data(), num_values, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::OpInfo::TensorProperties in_props; + Status s = tensorflow::BufferToMessage(in_props_buf[0], &in_props); + TF_ASSERT_OK(s); + + EXPECT_EQ(DT_FLOAT, in_props.dtype()); + EXPECT_FALSE(in_props.shape().unknown_rank()); + EXPECT_EQ(2, in_props.shape().dim_size()); + EXPECT_EQ(10, in_props.shape().dim(0).size()); + EXPECT_EQ(1, in_props.shape().dim(1).size()); + + for (int i = 0; i < in_props_buf.size(); i++) + TF_DeleteBuffer(in_props_buf[i]); + } + } + TF_DeleteGraphProperties(graph_properties); + TF_DeleteStatus(status); + TF_ASSERT_OK(cluster->Shutdown()); +} + +TEST(TF_GraphProperties, OutputProperties) { + std::unique_ptr cluster(new SingleMachine(5 * 60, 3, 0)); + TF_ASSERT_OK(cluster->Provision()); + + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + 
TF_Status* status = TF_NewStatus(); + TF_GraphProperties* graph_properties = + TF_NewGraphProperties(reinterpret_cast(&item)); + TF_InferStatically(graph_properties, true, false, false, false, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + for (const NodeDef& node : item.graph.node()) { + if (node.op() == "AddN") { + int num_values = 0; + TF_GetOutputPropertiesListSize(graph_properties, node.name().c_str(), + &num_values, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(num_values, 1); + + std::vector out_props_buf(num_values, TF_NewBuffer()); + + TF_GetOutputPropertiesList(graph_properties, node.name().c_str(), + out_props_buf.data(), num_values, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + tensorflow::OpInfo::TensorProperties out_props; + Status s = tensorflow::BufferToMessage(out_props_buf[0], &out_props); + TF_ASSERT_OK(s); + + EXPECT_EQ(DT_FLOAT, out_props.dtype()); + EXPECT_FALSE(out_props.shape().unknown_rank()); + EXPECT_EQ(2, out_props.shape().dim_size()); + EXPECT_EQ(10, out_props.shape().dim(0).size()); + EXPECT_EQ(1, out_props.shape().dim(1).size()); + + for (int i = 0; i < out_props_buf.size(); i++) + TF_DeleteBuffer(out_props_buf[i]); + } + } + TF_DeleteStatus(status); + TF_DeleteGraphProperties(graph_properties); + TF_ASSERT_OK(cluster->Shutdown()); +} + +TEST(TF_FunctionLibraryDefinition, LookUpOpDef) { + TF_Buffer* g_buf = TF_NewBuffer(); + TF_Buffer* op_buf = TF_NewBuffer(); + TF_Status* status = TF_NewStatus(); + GraphDef g_def; + Status s = MessageToBuffer(g_def, g_buf); + TF_ASSERT_OK(s); + TF_FunctionLibraryDefinition* func = + TF_NewFunctionLibraryDefinition(g_buf, status); + + TF_LookUpOpDef(func, "Add", op_buf, status); + string actual_string(reinterpret_cast(op_buf->data), + op_buf->length); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + + const OpDef* expected_op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def)); + string 
expected_serialized; + expected_op_def->SerializeToString(&expected_serialized); + EXPECT_EQ(expected_serialized, actual_string); + TF_DeleteBuffer(g_buf); + TF_DeleteBuffer(op_buf); + TF_DeleteStatus(status); + TF_DeleteFunctionLibraryDefinition(func); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/c/experimental/op_handler/BUILD b/tensorflow/c/experimental/op_handler/BUILD new file mode 100644 index 00000000000000..bdb5328180c44c --- /dev/null +++ b/tensorflow/c/experimental/op_handler/BUILD @@ -0,0 +1,43 @@ +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_test( + name = "internal_test", + srcs = ["internal_test.cc"], + deps = [ + ":internal", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:errors", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "internal", + srcs = ["internal.cc"], + hdrs = ["internal.h"], + deps = [ + ":wrapper_operation", + "//tensorflow/c:conversion_macros", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/platform:types", + ], +) + +cc_library( + name = "wrapper_operation", + srcs = ["wrapper_operation.cc"], + hdrs = ["wrapper_operation.h"], + deps = ["//tensorflow/c/eager:abstract_operation"], +) diff --git a/tensorflow/c/experimental/op_handler/internal.cc b/tensorflow/c/experimental/op_handler/internal.cc new file mode 100644 index 00000000000000..b9acbf445832f0 --- /dev/null +++ b/tensorflow/c/experimental/op_handler/internal.cc @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_ +#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_ + +#include "tensorflow/c/experimental/op_handler/internal.h" + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/experimental/op_handler/wrapper_operation.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx) + : AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {} +OpHandlerContext::~OpHandlerContext() {} +void OpHandlerContext::Release() { delete this; } +Status OpHandlerContext::RegisterFunction(AbstractFunction* function) { + return parent_ctx_->RegisterFunction(function); +} + +Status OpHandlerContext::RemoveFunction(const string& function) { + return parent_ctx_->RemoveFunction(function); +} + +void OpHandlerContext::set_default_handler(OpHandler* handler) { + handler->Ref(); + default_handler_.reset(handler); +} + +OpHandlerOperation* OpHandlerContext::CreateOperation() { + OpHandlerOperation* result = + new OpHandlerOperation(parent_ctx_->CreateOperation()); + if (default_handler_ != nullptr) { + result->set_handler(default_handler_.get()); + } + return result; +} + 
+OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op) + : WrapperOperation(parent_op, kOpHandler) {} + +OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); } + +void OpHandlerOperation::set_handler(OpHandler* handler) { + if (handler != nullptr) { + handler->Ref(); + } + handler_.reset(handler); +} + +Status OpHandlerOperation::Execute(absl::Span retvals, + int* num_retvals) { + if (handler_ == nullptr) { + return WrapperOperation::Execute(retvals, num_retvals); + } else { + return handler_->Execute(this, retvals, num_retvals); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ diff --git a/tensorflow/c/experimental/op_handler/internal.h b/tensorflow/c/experimental/op_handler/internal.h new file mode 100644 index 00000000000000..de893f77a7edf4 --- /dev/null +++ b/tensorflow/c/experimental/op_handler/internal.h @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/experimental/op_handler/wrapper_operation.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class OpHandlerOperation; + +// Op handlers are a convenient way to intercept and transform computation. +// +// The implementation is currently experimental and incomplete, but aims +// eventually to support tracing and replay of function bodies, gradients +// through copy operations, and a variety of hooks for things like debug +// strings. A public C API for op handlers is planned. +class OpHandler : public core::RefCounted { + public: + // Called on operation->Execute when operation->get_handler() == this. + // + // Allows the handler to customize or inspect `operation`'s execution. + virtual Status Execute(OpHandlerOperation* operation, + absl::Span retvals, + int* num_retvals) = 0; + // Creates a new handler by merging this handler with `next_handler`. + // + // The new handler is expected to transform operations first with this handler + // and then execute the resulting operations on `next_handler` (by calling + // `OpHandlerOperation::set_handler` and passing `next_handler`). If this is + // not possible then the merge operation should fail. + virtual Status Merge(OpHandler* next_handler, + core::RefCountPtr& merged_handler) = 0; +}; + +// Keeps some handler-specific metadata, but otherwise wraps a single +// AbstractOperation in the underlying context. 
The operation is created, its +// attributes set, etc., and at execution time it is presented to its handler, +// which may choose to execute it or simply inspect it and do something else. +// +// This is somewhat different than the Context approach, where the operation's +// construction is streamed through each layered Context. The streaming approach +// would require a much larger op handler public API, one function pointer per +// attribute type, and there is some ambiguity before an op is finalized about +// whether it should be presented as-is to handlers (regular operations) or +// replayed (function calls and control flow operations). +class OpHandlerOperation : public WrapperOperation { + public: + explicit OpHandlerOperation(AbstractOperation*); + OpHandler* get_handler(); + void set_handler(OpHandler* handler); + Status Execute(absl::Span retvals, + int* num_retvals) override; + + protected: + core::RefCountPtr handler_; +}; + +// A context which allows a default handler to be set for new operations. It +// otherwise defers to the context it wraps. +// +// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar +// in some ways. Having each handler be its own context seems almost doable, +// with things like copy operations and function/control flow replay being +// somewhat tricky (since they should be generated at the top of the handler +// stack and "caught" at the bottom). After handlers have evolved for a bit we +// should re-evaluate whether the handler+context concepts can be merged. +class OpHandlerContext : public AbstractContext { + public: + explicit OpHandlerContext(AbstractContext*); + void Release() override; + OpHandlerOperation* CreateOperation() override; + Status RegisterFunction(AbstractFunction*) override; + Status RemoveFunction(const string&) override; + // For LLVM style RTTI. 
+ static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kOpHandler; + } + ~OpHandlerContext() override; + + void set_default_handler(OpHandler* handler); + + private: + AbstractContext* parent_ctx_; // Not owned. + core::RefCountPtr default_handler_; +}; + +class ReleaseOpHandlerOperation { + public: + void operator()(OpHandlerOperation* operation) { operation->Release(); } +}; + +typedef std::unique_ptr + OpHandlerOperationPtr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_ diff --git a/tensorflow/c/experimental/op_handler/internal_test.cc b/tensorflow/c/experimental/op_handler/internal_test.cc new file mode 100644 index 00000000000000..d8ac8b3b9850cd --- /dev/null +++ b/tensorflow/c/experimental/op_handler/internal_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/op_handler/internal.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class TestOpHandler : public OpHandler { + public: + TestOpHandler() : last_operation_(new std::string("")) {} + Status Execute(OpHandlerOperation* operation, + absl::Span retvals, + int* num_retvals) override { + CHECK(operation->get_handler() == this); + *last_operation_ = operation->Name(); + operation->set_handler(next_handler_.get()); + return operation->Execute(retvals, num_retvals); + } + Status Merge(OpHandler* next_handler, + core::RefCountPtr& merged_handler) override { + merged_handler.reset(new TestOpHandler(next_handler, last_operation_)); + return Status::OK(); + } + + core::RefCountPtr next_handler_ = nullptr; + // Shared between merged handlers of this type. 
+ std::shared_ptr last_operation_; + + private: + TestOpHandler(OpHandler* next_handler, + std::shared_ptr last_operation) + : next_handler_(next_handler), last_operation_(last_operation) { + next_handler->Ref(); + } +}; + +TEST(INTERNAL_TEST, UseOpHandler) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr + c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()), + TF_DeleteExecutionContext); + OpHandlerContext ctx(unwrap(c_ctx.get())); + core::RefCountPtr outer_handler(new TestOpHandler()); + core::RefCountPtr inner_handler(new TestOpHandler()); + ctx.set_default_handler(outer_handler.get()); + OpHandlerOperationPtr op(ctx.CreateOperation()); + Status s = op->Reset("NoOp", ""); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + std::vector retvals; + int num_retvals = 0; + EXPECT_EQ("", *outer_handler->last_operation_); + s = op->Execute(absl::Span(retvals), &num_retvals); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + EXPECT_EQ("NoOp", *outer_handler->last_operation_); + *outer_handler->last_operation_ = ""; + EXPECT_EQ("", *inner_handler->last_operation_); + + // This op executes on both handlers, changing the state of `inner_handler` + // since the handler has decided to preserve that state across merges. 
+ core::RefCountPtr merged; + s = inner_handler->Merge(outer_handler.get(), merged); + ctx.set_default_handler(merged.get()); + op.reset(ctx.CreateOperation()); + s = op->Reset("NoOp", ""); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + s = op->Execute(absl::Span(retvals), &num_retvals); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + EXPECT_EQ("NoOp", *inner_handler->last_operation_); + EXPECT_EQ("NoOp", *outer_handler->last_operation_); + + inner_handler.reset(); + outer_handler.reset(); + op.reset(ctx.CreateOperation()); + s = op->Reset("NoOp", ""); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + s = op->Execute(absl::Span(retvals), &num_retvals); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/op_handler/wrapper_operation.cc b/tensorflow/c/experimental/op_handler/wrapper_operation.cc new file mode 100644 index 00000000000000..018bba04b8a3d6 --- /dev/null +++ b/tensorflow/c/experimental/op_handler/wrapper_operation.cc @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/op_handler/wrapper_operation.h" + +namespace tensorflow { +WrapperOperation::WrapperOperation(AbstractOperation* parent_op, + AbstractOperationKind kind) + : AbstractOperation(kind), parent_op_(parent_op) { + // TODO(b/172003047): Consider making AbstractOperation RefCounted. + // parent_op_->Ref(); +} +void WrapperOperation::Release() { + parent_op_->Release(); + // TODO(b/172003047): Consider making AbstractOperation RefCounted. + delete this; +} + +Status WrapperOperation::Reset(const char* op, const char* raw_device_name) { + return parent_op_->Reset(op, raw_device_name); +} +const string& WrapperOperation::Name() const { return parent_op_->Name(); } +const string& WrapperOperation::DeviceName() const { + return parent_op_->DeviceName(); +} +Status WrapperOperation::SetDeviceName(const char* name) { + return parent_op_->SetDeviceName(name); +} +Status WrapperOperation::AddInput(AbstractTensorHandle* input) { + return parent_op_->AddInput(input); +} +Status WrapperOperation::AddInputList( + absl::Span inputs) { + return parent_op_->AddInputList(inputs); +} +Status WrapperOperation::SetAttrString(const char* attr_name, const char* data, + size_t length) { + return parent_op_->SetAttrString(attr_name, data, length); +} +Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) { + return parent_op_->SetAttrInt(attr_name, value); +} +Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) { + return parent_op_->SetAttrFloat(attr_name, value); +} +Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) { + return parent_op_->SetAttrBool(attr_name, value); +} +Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) { + return parent_op_->SetAttrType(attr_name, value); +} +Status WrapperOperation::SetAttrShape(const char* attr_name, + const int64_t* dims, const int num_dims) { + 
return parent_op_->SetAttrShape(attr_name, dims, num_dims); +} +Status WrapperOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { + return parent_op_->SetAttrFunction(attr_name, value); +} +Status WrapperOperation::SetAttrFunctionName(const char* attr_name, + const char* value, size_t length) { + return parent_op_->SetAttrFunctionName(attr_name, value, length); +} +Status WrapperOperation::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { + return parent_op_->SetAttrTensor(attr_name, tensor); +} +Status WrapperOperation::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) { + return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values); +} +Status WrapperOperation::SetAttrFloatList(const char* attr_name, + const float* values, int num_values) { + return parent_op_->SetAttrFloatList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrIntList(const char* attr_name, + const int64_t* values, int num_values) { + return parent_op_->SetAttrIntList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrTypeList(const char* attr_name, + const DataType* values, + int num_values) { + return parent_op_->SetAttrTypeList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { + return parent_op_->SetAttrBoolList(attr_name, values, num_values); +} +Status WrapperOperation::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, int num_values) { + return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); +} +Status WrapperOperation::SetAttrFunctionList( + const char* attr_name, absl::Span values) { + return parent_op_->SetAttrFunctionList(attr_name, values); +} +AbstractOperation* WrapperOperation::GetBackingOperation() { + return parent_op_; +} +Status 
WrapperOperation::Execute(absl::Span retvals, + int* num_retvals) { + return parent_op_->Execute(retvals, num_retvals); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/op_handler/wrapper_operation.h b/tensorflow/c/experimental/op_handler/wrapper_operation.h new file mode 100644 index 00000000000000..b0ec9f174f0d5d --- /dev/null +++ b/tensorflow/c/experimental/op_handler/wrapper_operation.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_ + +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { + +// Forwards all of the AbstractOperation's methods to its wrapped operation. +// +// Useful as a base class to default to forwarding while adding some +// customization. 
+class WrapperOperation : public AbstractOperation { + public: + explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind); + void Release() override; + Status Reset(const char* op, const char* raw_device_name) override; + const string& Name() const override; + const string& DeviceName() const override; + Status SetDeviceName(const char* name) override; + Status AddInput(AbstractTensorHandle* input) override; + Status AddInputList(absl::Span inputs) override; + Status Execute(absl::Span retvals, + int* num_retvals) override; + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrInt(const char* attr_name, int64_t value) override; + Status SetAttrFloat(const char* attr_name, float value) override; + Status SetAttrBool(const char* attr_name, bool value) override; + Status SetAttrType(const char* attr_name, DataType value) override; + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override; + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + Status SetAttrBoolList(const char* attr_name, const unsigned char* values, + int num_values) override; + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + Status SetAttrFunctionList( + const char* attr_name, + 
absl::Span values) override; + AbstractOperation* GetBackingOperation(); + + private: + AbstractOperation* parent_op_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_ diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc index debeba18edfd0b..cedd19b427b4f0 100644 --- a/tensorflow/c/experimental/ops/array_ops.cc +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -22,14 +22,13 @@ using tensorflow::tracing::MaybeSetOpName; namespace tensorflow { namespace ops { -Status Identity(AbstractContext* ctx, - absl::Span inputs, +Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, absl::Span outputs, const char* name) { AbstractOperationPtr identity_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR( identity_op->Reset("Identity", /*raw_device_name=*/nullptr)); TF_RETURN_IF_ERROR(MaybeSetOpName(identity_op.get(), name)); - TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(identity_op->AddInput(input)); int num_retvals = 1; return identity_op->Execute(outputs, &num_retvals); } @@ -81,5 +80,17 @@ Status ExpandDims(AbstractContext* ctx, return op->Execute(outputs, &num_retvals); } +Status OnesLike(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(op.get(), name)); + TF_RETURN_IF_ERROR(op->AddInput(inputs[0])); + + int num_retvals = 1; + return op->Execute(outputs, &num_retvals); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h index f63412ed248352..dae99b2c31cc70 100644 --- a/tensorflow/c/experimental/ops/array_ops.h +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -22,8 +22,7 @@ limitations under the License. 
namespace tensorflow { namespace ops { -Status Identity(AbstractContext* ctx, - absl::Span inputs, +Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, absl::Span outputs, const char* name); Status IdentityN(AbstractContext* ctx, @@ -42,6 +41,10 @@ Status ExpandDims(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status OnesLike(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index 20aab8a77d30a6..b91a1d0d33086d 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -43,9 +43,19 @@ Status Conj(AbstractContext* ctx, auto dtype = inputs[0]->DataType(); if (DataTypeIsFloating(BaseType(dtype)) || DataTypeIsInteger(BaseType(dtype))) { - TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name)); + TF_RETURN_IF_ERROR(Identity(ctx, inputs[0], outputs, name)); + } else if (DataTypeIsComplex(BaseType(dtype)) || + BaseType(dtype) == DT_VARIANT) { + AbstractOperationPtr conj_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(conj_op->Reset("Conj", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(conj_op.get(), name)); + TF_RETURN_IF_ERROR(conj_op->AddInput(inputs[0])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(conj_op->Execute(outputs, &num_retvals)); } else { - return errors::Unimplemented("Conj does not support complex types yet."); + return errors::InvalidArgument( + "Expected numeric or variant tensor, got dtype ", dtype); } return Status::OK(); } @@ -118,6 +128,19 @@ Status Sum(AbstractContext* ctx, absl::Span inputs, return Status::OK(); } +Status Div(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr div_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(div_op->Reset("Div", /*raw_device_name=*/nullptr)); + 
TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name)); + TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x + TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y + + int num_retvals = 1; + TF_RETURN_IF_ERROR(div_op->Execute(outputs, &num_retvals)); // z = x / y + return Status::OK(); +} + Status DivNoNan(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { @@ -172,5 +195,18 @@ Status SqrtGrad(AbstractContext* ctx, return s; } +Status Log1p(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr log1p_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(log1p_op->Reset("Log1p", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(log1p_op.get(), name)); + TF_RETURN_IF_ERROR(log1p_op->AddInput(inputs[0])); + + int num_retvals = 1; + Status s = log1p_op->Execute(outputs, &num_retvals); + return s; +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h index 7051e38656f0df..56707c0a60a0ff 100644 --- a/tensorflow/c/experimental/ops/math_ops.h +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -44,6 +44,9 @@ Status Sum(AbstractContext* ctx, absl::Span inputs, Status Sub(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status Div(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + Status DivNoNan(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); @@ -59,6 +62,10 @@ Status SqrtGrad(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status Log1p(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc index 6a97dbf0939926..b1cc2ffc7d6c63 100644 --- a/tensorflow/c/experimental/ops/nn_ops.cc +++ 
b/tensorflow/c/experimental/ops/nn_ops.cc @@ -69,5 +69,38 @@ Status Relu(AbstractContext* ctx, return Status::OK(); } +Status BiasAdd(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr bias_add_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + bias_add_op->Reset("BiasAdd", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_op.get(), name)); + TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[0])); // tensor input + TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[1])); // bias + + int num_retvals = 1; + TF_RETURN_IF_ERROR(bias_add_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +// Computes Bias Add gradient given upstream grads +Status BiasAddGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const char* data_format, const char* name) { + AbstractOperationPtr bias_add_grad_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + bias_add_grad_op->Reset("BiasAddGrad", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_grad_op.get(), name)); + TF_RETURN_IF_ERROR(bias_add_grad_op->SetAttrString("data_format", data_format, + strlen(data_format))); + TF_RETURN_IF_ERROR(bias_add_grad_op->AddInput(inputs[0])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(bias_add_grad_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h index 3c0e04579a11fe..d5b8cdde356b09 100644 --- a/tensorflow/c/experimental/ops/nn_ops.h +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -34,6 +34,15 @@ Status Relu(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status BiasAdd(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status BiasAddGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const char* data_format, const char* 
name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index ac168830a0efb9..63396d22b49981 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -146,6 +146,7 @@ cc_library( ":tf_signature_def_function", ":variable", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc index 1c61540564422f..b9344238b79eb0 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.cc @@ -343,7 +343,8 @@ Status InitializeCreateResourceFunctions(ImmediateExecutionContext* ctx, std::unique_ptr out; TF_RETURN_IF_ERROR(CreateConcreteFunction(ctx, *create_resource_fn, obj_graph, objects, &out)); - revived->concrete_functions[create_resource_fn->node_id] = std::move(out); + revived->concrete_functions.Insert(std::move(out), + create_resource_fn->node_id); } return Status(); } @@ -352,8 +353,6 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx, const SavedObjectGraph& obj_graph, const PartiallyRevivedObjects& objects, RevivedObjects* revived) { - gtl::FlatMap>* destination_func_map = - &revived->concrete_functions; gtl::FlatMap>* destination_sig_map = &revived->signature_def_functions; @@ -361,7 +360,7 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx, int node_id = id_and_func.first; const TFConcreteFunctionRevivalState& func = id_and_func.second; - if (destination_func_map->find(node_id) != destination_func_map->end()) { + if (revived->concrete_functions.Find(node_id)) { // The function has 
already been initialized in the destination_map, // so we can skip this node. This can occur because we initialize // CreateResource functions before calling this function. @@ -371,7 +370,7 @@ Status InitializeAllFunctions(ImmediateExecutionContext* ctx, std::unique_ptr out; TF_RETURN_IF_ERROR( CreateConcreteFunction(ctx, func, obj_graph, objects, &out)); - (*destination_func_map)[node_id] = std::move(out); + revived->concrete_functions.Insert(std::move(out), node_id); } for (const auto& id_and_func : objects.signature_def_functions) { @@ -398,20 +397,16 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx, for (auto& id_and_resource : objects->restored_resources) { RestoredResourceRevivalState& resource = id_and_resource.second; int create_resource_fn_node = resource.create_resource->node_id; - const gtl::FlatMap>& - revived_functions = revived->concrete_functions; - const auto& revived_functions_iter = - revived_functions.find(create_resource_fn_node); - if (revived_functions_iter == revived_functions.end()) { + const TFConcreteFunction* create_resource_fn = + revived->concrete_functions.Find(create_resource_fn_node); + if (create_resource_fn == nullptr) { return errors::FailedPrecondition( "ConcreteFunction at node ", create_resource_fn_node, " should have been initialized prior to being called."); } - const TFConcreteFunction& create_resource_fn = - *revived_functions_iter->second; ImmediateOpPtr function_op; - TF_RETURN_IF_ERROR(create_resource_fn.MakeCallOp({}, &function_op)); + TF_RETURN_IF_ERROR(create_resource_fn->MakeCallOp({}, &function_op)); TF_RETURN_IF_ERROR(function_op->SetDeviceName(resource.device.c_str())); AbstractTensorHandle* resource_handle = nullptr; @@ -431,21 +426,6 @@ Status CreateAllResourceHandles(ImmediateExecutionContext* ctx, return Status(); } -// Finds a ConcreteFunction with node id `node` in `objects`, and sets *out to -// point to it. 
If node doesn't exist in `objects`, out is untouched, and an -// error status is returned. -Status FindConcreteFunction(int node, RevivedObjects* objects, - TFConcreteFunction** out) { - auto func_iter = objects->concrete_functions.find(node); - if (func_iter == objects->concrete_functions.end()) { - return errors::FailedPrecondition( - "Failed to find ConcreteFunction with node id ", node, - " in revived objects"); - } - *out = func_iter->second.get(); - return Status(); -} - Status BuildResources(ImmediateExecutionContext* ctx, const SavedObjectGraph& obj_graph, PartiallyRevivedObjects* objects, @@ -460,22 +440,35 @@ Status BuildResources(ImmediateExecutionContext* ctx, // Check all the functions associated with the resource have already been // initialized in `revived` if (resource_revival_state.create_resource != nullptr) { - TF_RETURN_IF_ERROR( - FindConcreteFunction(resource_revival_state.create_resource->node_id, - revived, &create_resource)); + create_resource = revived->concrete_functions.Find( + resource_revival_state.create_resource->node_id); + if (create_resource == nullptr) { + return errors::FailedPrecondition( + "'create_resource' function with node id ", + resource_revival_state.create_resource->node_id, " not found"); + } } TFConcreteFunction* initialize = nullptr; if (resource_revival_state.initialize != nullptr) { - TF_RETURN_IF_ERROR(FindConcreteFunction( - resource_revival_state.initialize->node_id, revived, &initialize)); + initialize = revived->concrete_functions.Find( + resource_revival_state.initialize->node_id); + if (initialize == nullptr) { + return errors::FailedPrecondition( + "'initialize' function with node id ", + resource_revival_state.initialize->node_id, " not found"); + } } TFConcreteFunction* destroy_resource = nullptr; if (resource_revival_state.destroy_resource != nullptr) { - TF_RETURN_IF_ERROR( - FindConcreteFunction(resource_revival_state.destroy_resource->node_id, - revived, &destroy_resource)); + destroy_resource = 
revived->concrete_functions.Find( + resource_revival_state.destroy_resource->node_id); + if (destroy_resource == nullptr) { + return errors::FailedPrecondition( + "'destroy_resource' function with node id ", + resource_revival_state.destroy_resource->node_id, " not found"); + } } if (resource_revival_state.resource_handle == nullptr) { diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h index cc9be0b937d708..0f09c743afc27f 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" @@ -29,6 +30,43 @@ limitations under the License. namespace tensorflow { +// A container for revived saved model objects. +// +// Most of the objects will be revived from nodes in the object graph, and for +// those objects this container provides a map from node id to the revived +// objects. +// +// For objects that have to be revived but are not part of the object graph, +// this container provides a place where the objects can be stored so they are +// available to the runtime. +template +class RevivedObjectContainer { + public: + // Insert an object that is not related to a node id. This usually means the + // object was not referenced by the object_graph, but is needed by other + // objects. + void Insert(std::unique_ptr object) { + objects_.push_back(std::move(object)); + } + + // Insert an object that is tied to the given object graph node id. 
+ void Insert(std::unique_ptr object, int node_id) { + objects_by_id_[node_id] = object.get(); + Insert(std::move(object)); + } + + // Find an object by the object graph node id. + // Returns nullptr if there is no such object. + T* Find(int node_id) { + auto it = objects_by_id_.find(node_id); + return it == objects_by_id_.end() ? nullptr : it->second; + } + + private: + std::vector> objects_; + absl::flat_hash_map objects_by_id_; +}; + // RevivedObjects is mainly used as a container for all the "state" owned by // SavedModel. It stores all non-"user object" nodes from a SavedModel // (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62) @@ -37,12 +75,14 @@ namespace tensorflow { // (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29) // to the revived object of the corresponding type. struct RevivedObjects { + // Order of declaration is important here: we want the RestoredResources to be + // freed after TFConcreteFunctions, for example. 
gtl::FlatMap> variables; gtl::FlatMap> assets; gtl::FlatMap> constants; - gtl::FlatMap> concrete_functions; gtl::FlatMap> signature_def_functions; + RevivedObjectContainer concrete_functions; gtl::FlatMap restored_resources; gtl::FlatMap signatures_map; }; diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_api.h b/tensorflow/c/experimental/saved_model/core/saved_model_api.h index ff891e13ba47e1..dd06aa89682482 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_api.h @@ -46,8 +46,6 @@ class SavedModelAPI { virtual Status GetSignatureDefFunction(const std::string& signature_def_key, SignatureDefFunction** function) = 0; - virtual std::vector ListFunctions() = 0; - virtual ~SavedModelAPI() = default; }; diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index 988f7e382a82d6..2036318e2e50a1 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -45,8 +45,7 @@ EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr) { return EagerContextPtr(new EagerContext( SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, - /* async= */ false, - /* lazy_copy_function_remote_inputs= */ false, device_mgr, + /* async= */ false, device_mgr, /* device_mgr_owned= */ false, /* rendezvous= */ nullptr, /* cluster_flr= */ nullptr)); } diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index f0990235963b81..7ed614ffe16bbc 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -73,7 +73,6 @@ using FlatTensorFunctionMap = namespace { - const TrackableObjectGraph::TrackableObject::SerializedTensor* 
FindSerializedTensorInTrackable( const TrackableObjectGraph::TrackableObject& trackable_object, @@ -181,12 +180,11 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path, return errors::NotFound("No saved object found at path ", function_path); } - auto function_iter = revived_objects_.concrete_functions.find(*node); - if (function_iter == revived_objects_.concrete_functions.end()) { + *function = revived_objects_.concrete_functions.Find(*node); + if (*function == nullptr) { return errors::NotFound("No function found at path ", function_path); } - *function = function_iter->second.get(); return Status(); } @@ -211,15 +209,6 @@ Status TFSavedModelAPI::GetSignatureDefFunction( return Status(); } -std::vector TFSavedModelAPI::ListFunctions() { - std::vector result; - result.reserve(revived_objects_.concrete_functions.size()); - for (auto& index_and_function : revived_objects_.concrete_functions) { - result.push_back(index_and_function.second.get()); - } - return result; -} - Status TFSavedModelAPI::GetVariable(const std::string& variable_path, Variable** variable) { absl::optional node = @@ -263,10 +252,10 @@ Status TFSavedModelAPI::Load( // This occurs in python here: // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 - // Step 1: For each node in the graph, we should initialize an object of the + // For each node in the graph, we should initialize an object of the // corresponding type. For objects that depend on the initialization of other // objects (like functions which capture resources), we will initialize them - // in step 2. + // later. 
PartiallyRevivedObjects partially_revived_objects; TF_RETURN_IF_ERROR(internal::PartiallyReviveSavedModelObjects( bundle.meta_graph_def(), context, directory, &partially_revived_objects)); @@ -275,6 +264,22 @@ Status TFSavedModelAPI::Load( TF_RETURN_IF_ERROR(partially_revived_objects.Build( context, bundle.saved_object_graph(), &revived_objects)); + // Revive function library functions as concrete functions without captures. + // This is necessary because object graph functions may refer to functions + // _not_ in the object graph: A while loop, for example, will create two + // auxiliary `while_cond` and `while_body` functions that are only present in + // the graph def function library. + for (const FunctionDef& function : + bundle.meta_graph_def().graph_def().library().function()) { + std::unique_ptr concrete_function; + TF_RETURN_IF_ERROR(TFConcreteFunction::Create(/*function_def=*/&function, + /*captures=*/{}, + /*metadata=*/{}, + /*ctx=*/context, + /*out=*/&concrete_function)); + revived_objects.concrete_functions.Insert(std::move(concrete_function)); + } + TF_RETURN_IF_ERROR( RestoreCheckpoint(&bundle, revived_objects, directory, context)); diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index bc39a974ad2c44..45c8673e65f718 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -66,8 +66,6 @@ class TFSavedModelAPI : public SavedModelAPI { ImmediateExecutionContext* context, std::unique_ptr* out); - std::vector ListFunctions() override; - ~TFSavedModelAPI() override = default; Status GetVariable(const std::string& variable_path, Variable** variable); diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index b89fb9f6d64962..cb2e5751bed7c7 100644 --- 
a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -122,9 +122,4 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, return tensorflow::wrap(result); } -TF_ConcreteFunctionList* TF_ListSavedModelFunctions(TF_SavedModel* model) { - return new TF_ConcreteFunctionList{ - tensorflow::unwrap(model)->ListFunctions()}; -} - } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index 5a4f676ec06773..845683f2d7e2f0 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -524,6 +524,62 @@ TEST_P(CSavedModelAPITest, LoadSavedModelWithUninitializedVariable) { TFE_DeleteContext(ctx); } +TEST_P(CSavedModelAPITest, LoadSavedModelWithWhileLoop) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + bool use_tfrt = GetParam(); + if (use_tfrt) { + TFE_DeleteContextOptions(opts); + TF_DeleteStatus(status); + GTEST_SKIP(); // TODO(chky) : Enable this once TFRT is open sourced. 
+ } + + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + std::string model_dir = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), + "c/experimental/saved_model/internal/testdata/SimpleWhileLoop"); + + TF_SavedModel* saved_model = + TF_LoadSavedModel(model_dir.c_str(), ctx, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_ConcreteFunction* while_fn = + TF_GetSavedModelConcreteFunction(saved_model, "compute", status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + std::vector while_fn_inputs; + while_fn_inputs.push_back(TestScalarTensorHandle(ctx, 10.0f)); + + TFE_Op* while_fn_op = TF_ConcreteFunctionMakeCallOp( + while_fn, while_fn_inputs.data(), while_fn_inputs.size(), status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TFE_TensorHandle* while_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(while_fn_op, &while_fn_outputs[0], &num_retvals, status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(while_fn_outputs[0], status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + ASSERT_EQ(TF_NumDims(result), 0); + float output_value = *static_cast(TF_TensorData(result)); + ASSERT_FLOAT_EQ(output_value, 55); // 10+9+...+1 + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(while_fn_outputs[0]); + TFE_DeleteOp(while_fn_op); + TFE_DeleteTensorHandle(while_fn_inputs[0]); + TF_DeleteSavedModel(saved_model); + TF_DeleteStatus(status); + TFE_DeleteContext(ctx); +} + INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticSavedModelTests, CSavedModelAPITest, ::testing::Bool()); diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD index f446401ae77cbc..6d07018a78ab68 100644 --- 
a/tensorflow/c/experimental/saved_model/internal/testdata/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/testdata/BUILD @@ -12,8 +12,9 @@ py_strict_binary( srcs = ["gen_saved_models.py"], python_version = "PY3", deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:platform", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_spec", "//tensorflow/python:variables", @@ -21,7 +22,7 @@ py_strict_binary( "//tensorflow/python/eager:def_function", "//tensorflow/python/module", "//tensorflow/python/saved_model", - "//tensorflow/python/saved_model:save_options", + "@absl_py//absl:app", ], ) @@ -29,6 +30,7 @@ py_strict_binary( filegroup( name = "saved_models", srcs = glob([ + "SimpleWhileLoop/**", "UninitializedVariable/**", ]), visibility = [ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/saved_model.pb b/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/saved_model.pb new file mode 100644 index 00000000000000..b94c853029a065 Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/saved_model.pb differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.data-00000-of-00001 b/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.data-00000-of-00001 new file mode 100644 index 00000000000000..1039a8fe6dd60e Binary files /dev/null and b/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.index b/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.index new file mode 100644 index 00000000000000..71e4af3fa42e4b Binary files /dev/null and 
b/tensorflow/c/experimental/saved_model/internal/testdata/SimpleWhileLoop/variables/variables.index differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb index 81ce8fe662bff8..d03f2591fa42f5 100644 Binary files a/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb and b/tensorflow/c/experimental/saved_model/internal/testdata/UninitializedVariable/saved_model.pb differ diff --git a/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py b/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py index f2a8bd5a9a4e15..a65de68f6c0998 100644 --- a/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py +++ b/tensorflow/c/experimental/saved_model/internal/testdata/gen_saved_models.py @@ -26,16 +26,18 @@ from __future__ import print_function import os +from absl import app from tensorflow.python.compat import v2_compat from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_spec from tensorflow.python.module import module +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import app from tensorflow.python.saved_model import saved_model @@ -72,11 +74,32 @@ def compute(self, value): to_save, export_dir=os.path.join(base_dir, "UninitializedVariable")) +def _gen_simple_while_loop(base_dir): + """Generates a saved model with a while loop.""" + + class Module(module.Module): + """A module with a while loop.""" + + @def_function.function( + input_signature=[tensor_spec.TensorSpec((), dtypes.float32)]) + def compute(self, value): + acc, _ = 
control_flow_ops.while_loop( + cond=lambda acc, i: i > 0, + body=lambda acc, i: (acc + i, i - 1), + loop_vars=(constant_op.constant(0.0), value)) + return acc + + to_save = Module() + saved_model.save( + to_save, export_dir=os.path.join(base_dir, "SimpleWhileLoop")) + + def main(args): if len(args) != 2: raise app.UsageError("Expected one argument (base_dir).") _, base_dir = args _gen_uninitialized_variable(base_dir) + _gen_simple_while_loop(base_dir) if __name__ == "__main__": diff --git a/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/tensorflow/c/experimental/saved_model/public/saved_model_api.h index 80ba37bab264a0..cef7fe860e5358 100644 --- a/tensorflow/c/experimental/saved_model/public/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -100,11 +100,6 @@ TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, const char* signature_def_key, TF_Status* status); -// Returns a list of all ConcreteFunctions stored in this SavedModel. -// The lifetime of the returned list is bound to `model`. 
-TF_CAPI_EXPORT extern TF_ConcreteFunctionList* TF_ListSavedModelFunctions( - TF_SavedModel* model); - #ifdef __cplusplus } // end extern "C" #endif // __cplusplus diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 214313c960aaaf..47851b67c28f73 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -7,17 +7,28 @@ load( "tf_cc_test", ) +# buildifier: disable=same-origin-load +load("//tensorflow:tensorflow.bzl", "filegroup") + package( licenses = ["notice"], # Apache 2.0 ) +filegroup( + name = "headers", + srcs = [ + "stream_executor.h", + ], + visibility = ["//tensorflow:__subpackages__"], +) + cc_library( name = "stream_executor_hdrs", hdrs = ["stream_executor.h"], visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/c:c_api_macros", - "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_headers", ], ) @@ -49,9 +60,14 @@ cc_library( "stream_executor.h", "stream_executor_internal.h", ], + visibility = [ + "//tensorflow/c:__subpackages__", + "//tensorflow/core/common_runtime/pluggable_device:__subpackages__", + ], deps = [ "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", "//tensorflow/stream_executor:executor_cache", "//tensorflow/stream_executor/lib", ], @@ -63,6 +79,7 @@ tf_cc_test( deps = [ ":stream_executor", ":stream_executor_internal", + ":stream_executor_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/protobuf:error_codes_proto_impl_cc", @@ -71,3 +88,14 @@ tf_cc_test( "//tensorflow/stream_executor:stream_executor_pimpl", ], ) + +cc_library( + name = "stream_executor_test_util", + srcs = ["stream_executor_test_util.cc"], + hdrs = ["stream_executor_test_util.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":stream_executor_hdrs", + "//tensorflow/c:tf_status", + ], +) diff --git 
a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index ec2bada791e183..f9122d58d2a241 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -24,7 +24,6 @@ limitations under the License. #include #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" -#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" @@ -44,6 +43,7 @@ using tensorflow::StatusFromTF_Status; namespace stream_executor { using tensorflow::StringPiece; +using OwnedTFStatus = std::unique_ptr; namespace { @@ -188,41 +188,6 @@ port::Status ValidateSEPlatformRegistrationParams( } #undef VALIDATE_MEMBER -struct TFStatusDeleter { - void operator()(TF_Status* s) const { TF_DeleteStatus(s); } -}; -using OwnedTFStatus = std::unique_ptr; - -class CStream : public internal::StreamInterface { - public: - CStream(SP_Device* device, SP_StreamExecutor* stream_executor) - : device_(device), - stream_executor_(stream_executor), - stream_handle_(nullptr) {} - ~CStream() override { Destroy(); } - - port::Status Create() { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); - port::Status s = StatusFromTF_Status(c_status.get()); - return s; - } - - void Destroy() { - if (stream_handle_ != nullptr) { - stream_executor_->destroy_stream(device_, stream_handle_); - stream_handle_ = nullptr; - } - } - - SP_Stream Handle() { return stream_handle_; } - - private: - SP_Device* device_; - SP_StreamExecutor* stream_executor_; - SP_Stream stream_handle_; -}; - // Converts SE_EventStatus to Event::Status. 
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) { switch (s) { @@ -237,82 +202,6 @@ Event::Status SEEventStatusToEventStatus(SE_EventStatus s) { } } -class CEvent : public internal::EventInterface { - public: - CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) - : device_(device), - stream_executor_(stream_executor), - event_handle_(nullptr) {} - ~CEvent() override { Destroy(); } - - port::Status Create() { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->create_event(device_, &event_handle_, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - - port::Status Record(SP_Stream stream_handle) { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->record_event(device_, stream_handle, event_handle_, - c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - - void Destroy() { - if (event_handle_ != nullptr) { - stream_executor_->destroy_event(device_, event_handle_); - event_handle_ = nullptr; - } - } - - SP_Event Handle() { return event_handle_; } - - private: - SP_Device* device_; - SP_StreamExecutor* stream_executor_; - SP_Event event_handle_; -}; - -class CTimer : public internal::TimerInterface { - public: - CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, - SP_TimerFns* timer_fns) - : device_(device), - stream_executor_(stream_executor), - timer_handle_(nullptr), - timer_fns_(timer_fns) {} - ~CTimer() override { Destroy(); } - - port::Status Create() { - OwnedTFStatus c_status(TF_NewStatus()); - stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - - void Destroy() { - if (timer_handle_ != nullptr) { - stream_executor_->destroy_timer(device_, timer_handle_); - timer_handle_ = nullptr; - } - } - - SP_Timer Handle() { return timer_handle_; } - - uint64 Microseconds() const override { - return timer_fns_->nanoseconds(timer_handle_) / 1000; - } - - uint64 Nanoseconds() const override { - return 
timer_fns_->nanoseconds(timer_handle_); - } - - private: - SP_Device* device_; - SP_StreamExecutor* stream_executor_; - SP_Timer timer_handle_; - SP_TimerFns* timer_fns_; -}; - // Converts DeviceMemoryBase to a C struct. SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; @@ -321,14 +210,12 @@ SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { device_memory_base.opaque = const_cast(mem->opaque()); device_memory_base.size = mem->size(); device_memory_base.payload = mem->payload(); - // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here. return device_memory_base; } DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) { DeviceMemoryBase base(mem.opaque, mem.size); base.SetPayload(mem.payload); - // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here. return base; } @@ -426,7 +313,6 @@ class CStreamExecutor : public internal::StreamExecutorInterface { LOG(ERROR) << status.error_message(); return absl::nullopt; } - // TODO(annarev): validate SP_AllocatorStats. 
::stream_executor::AllocatorStats stats; stats.num_allocs = c_stats.num_allocs; stats.bytes_in_use = c_stats.bytes_in_use; @@ -849,15 +735,23 @@ port::StatusOr> CPlatform::GetUncachedExecutor( TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); TF_RETURN_IF_ERROR(ValidateSPDevice(device)); + // Get Device Count + int visible_device_count = 0; + platform_fns_.get_device_count(&platform_, &visible_device_count, + c_status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); + auto executor = absl::make_unique( std::move(device), &device_fns_, &stream_executor_, &platform_, - &platform_fns_, &timer_fns_, name_, platform_.visible_device_count); + &platform_fns_, &timer_fns_, name_, visible_device_count); auto result = absl::make_unique(this, std::move(executor), config.ordinal); return result; } -port::Status InitStreamExecutorPlugin(void* dso_handle) { +port::Status InitStreamExecutorPlugin(void* dso_handle, + std::string* device_type, + std::string* platform_name) { tensorflow::Env* env = tensorflow::Env::Default(); // Step 1: Load symbol for `TF_InitPlugin` @@ -867,10 +761,12 @@ port::Status InitStreamExecutorPlugin(void* dso_handle) { // Step 2: Call `TF_InitPlugin` auto init_fn = reinterpret_cast(dso_symbol); - return InitStreamExecutorPlugin(init_fn); + return InitStreamExecutorPlugin(init_fn, device_type, platform_name); } -port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) { +port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, + std::string* device_type, + std::string* platform_name) { SE_PlatformRegistrationParams params{ SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE}; SP_Platform platform{SP_PLATFORM_STRUCT_SIZE}; @@ -915,12 +811,9 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) { TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns)); - platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get()); - 
TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); - TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns)); - // Register new platform - std::string platform_name = std::string(platform.name); + *device_type = std::string(platform.type); + *platform_name = std::string(platform.name); std::unique_ptr cplatform( new stream_executor::CPlatform( std::move(platform), params.destroy_platform, std::move(platform_fns), @@ -928,8 +821,8 @@ port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn) { std::move(timer_fns))); SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( std::move(cplatform))); - - // TODO(annarev): Add pluggable device registration here. + // TODO(annarev): Return `use_bfc_allocator` value in some way so that it is + // available in `PluggableDeviceProcessState` once the latter is checked in. return port::Status::OK(); } } // namespace stream_executor diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h index bec77ef520b296..b3b56d1ce28a99 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -76,7 +76,7 @@ limitations under the License. // // Values such as `name` and `type` must outlive SE_InitPlugin call. // params->platform->name = DEVICE_NAME; // params->platform->type = DEVICE_TYPE; -// params->platform->visible_device_count = 2; +// params->platform_fns->get_device_count = get_device_count; // params->platform_fns->create_device = create_device; // params->platform_fns->destroy_device = destroy_device; // ... 
@@ -140,8 +140,9 @@ typedef enum SE_EventStatus { // https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57 typedef struct SP_DeviceMemoryBase { size_t struct_size; - void* ext; // free-form data set by plugin + void* ext; // Reserved for future use // Platform-dependent value representing allocated memory. + // Note that the pointer does not have to be to the virtual address itself. void* opaque; uint64_t size; // Size in bytes of this allocation. uint64_t payload; // Value for plugin's use @@ -427,22 +428,25 @@ typedef struct SP_Platform { // capital letters and underscores. const char* type; - // Number of visible devices - size_t visible_device_count; - // Whether this platform supports unified memory. // Unified memory is a single memory address space accessible from any device. TF_Bool supports_unified_memory; + + // Whether to wrap allocator for this device with an allocator that uses BFC + // (best-fit with coalescing) strategy. + TF_Bool use_bfc_allocator; } SP_Platform; -#define SP_PLATFORM_STRUCT_SIZE \ - TF_OFFSET_OF_END(SP_Platform, supports_unified_memory) +#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, use_bfc_allocator) typedef struct SP_PlatformFns { size_t struct_size; void* ext; // reserved for future use + // Callbacks for getting device count + void (*get_device_count)(const SP_Platform* platform, int* device_count, + TF_Status* status); // Callbacks for creating/destroying SP_Device. 
void (*create_device)(const SP_Platform* platform, SE_CreateDeviceParams* params, TF_Status* status); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 52ae4ba77e0b19..dab6939be509d1 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ #include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/platform.h" @@ -30,13 +31,25 @@ namespace stream_executor { typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, TF_Status* const); -// Registers StreamExecutor platform. -port::Status InitStreamExecutorPlugin(void* dso_handle); +// Registers StreamExecutor platform. `device_type` and `platform_name` are +// output parameters. +port::Status InitStreamExecutorPlugin(void* dso_handle, + std::string* device_type, + std::string* platform_name); // Allow registering a StreamExecutor plugin using a function (used for // testing). -port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn); +port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, + std::string* device_type, + std::string* platform_name); +struct TFStatusDeleter { + void operator()(TF_Status* s) const { TF_DeleteStatus(s); } +}; + +// This file implements core stream executor base classes in terms of +// the C API defined in stream_executor.h. A class "CSomething" represents a +// "Something" that can be manipulated via calls in the C interface. 
class CPlatform : public Platform { public: explicit CPlatform(SP_Platform platform, @@ -50,8 +63,17 @@ class CPlatform : public Platform { Id id() const override { return const_cast(&plugin_id_value_); } const std::string& Name() const override { return name_; } int VisibleDeviceCount() const override { - return platform_.visible_device_count; + int visible_device_count = 0; + std::unique_ptr c_status(TF_NewStatus()); + platform_fns_.get_device_count(&platform_, &visible_device_count, + c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return 0; + } + return visible_device_count; } + bool UseBfcAllocator() const { return platform_.use_bfc_allocator; } port::StatusOr> DescriptionForDevice( int ordinal) const override; port::StatusOr ExecutorForDevice(int ordinal) override; @@ -83,5 +105,111 @@ class CPlatform : public Platform { stream_executor::ExecutorCache executor_cache_; }; +class CStream : public internal::StreamInterface { + public: + CStream(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + stream_handle_(nullptr) {} + ~CStream() override { Destroy(); } + + port::Status Create() { + std::unique_ptr c_status(TF_NewStatus()); + stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); + port::Status s = tensorflow::StatusFromTF_Status(c_status.get()); + return s; + } + + void Destroy() { + if (stream_handle_ != nullptr) { + stream_executor_->destroy_stream(device_, stream_handle_); + stream_handle_ = nullptr; + } + } + + SP_Stream Handle() { return stream_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Stream stream_handle_; +}; + +class CEvent : public internal::EventInterface { + public: + CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + event_handle_(nullptr) {} + ~CEvent() override { Destroy(); } + + 
port::Status Create() { + std::unique_ptr c_status(TF_NewStatus()); + stream_executor_->create_event(device_, &event_handle_, c_status.get()); + return tensorflow::StatusFromTF_Status(c_status.get()); + } + + port::Status Record(SP_Stream stream_handle) { + std::unique_ptr c_status(TF_NewStatus()); + stream_executor_->record_event(device_, stream_handle, event_handle_, + c_status.get()); + return tensorflow::StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (event_handle_ != nullptr) { + stream_executor_->destroy_event(device_, event_handle_); + event_handle_ = nullptr; + } + } + + SP_Event Handle() { return event_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Event event_handle_; +}; + +class CTimer : public internal::TimerInterface { + public: + CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, + SP_TimerFns* timer_fns) + : device_(device), + stream_executor_(stream_executor), + timer_handle_(nullptr), + timer_fns_(timer_fns) {} + ~CTimer() override { Destroy(); } + + port::Status Create() { + std::unique_ptr c_status(TF_NewStatus()); + stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); + return tensorflow::StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (timer_handle_ != nullptr) { + stream_executor_->destroy_timer(device_, timer_handle_); + timer_handle_ = nullptr; + } + } + + SP_Timer Handle() { return timer_handle_; } + + uint64 Microseconds() const override { + return timer_fns_->nanoseconds(timer_handle_) / 1000; + } + + uint64 Nanoseconds() const override { + return timer_fns_->nanoseconds(timer_handle_); + } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Timer timer_handle_; + SP_TimerFns* timer_fns_; +}; + } // namespace stream_executor #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc 
b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc index 56c4ea090528ff..dec1b4e65b6595 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" +#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" @@ -24,205 +25,26 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_pimpl.h" #include "tensorflow/stream_executor/timer.h" -struct SP_Stream_st { - explicit SP_Stream_st(int id) : stream_id(id) {} - int stream_id; -}; - -struct SP_Event_st { - explicit SP_Event_st(int id) : event_id(id) {} - int event_id; -}; - -struct SP_Timer_st { - explicit SP_Timer_st(int id) : timer_id(id) {} - int timer_id; -}; - namespace stream_executor { namespace { -constexpr int kDeviceCount = 2; -constexpr char kDeviceName[] = "MY_DEVICE"; -constexpr char kDeviceType[] = "GPU"; - -/*** Create SP_StreamExecutor (with empty functions) ***/ -void allocate(const SP_Device* const device, uint64_t size, - int64_t memory_space, SP_DeviceMemoryBase* const mem) {} -void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) { -} -void* host_memory_allocate(const SP_Device* const device, uint64_t size) { - return nullptr; -} -void host_memory_deallocate(const SP_Device* const device, void* mem) {} -TF_Bool get_allocator_stats(const SP_Device* const device, - SP_AllocatorStats* const stats) { - return true; -} -TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free, - int64_t* const total) { - return true; -} -void create_stream(const 
SP_Device* const device, SP_Stream* stream, - TF_Status* const status) { - stream = nullptr; -} -void destroy_stream(const SP_Device* const device, SP_Stream stream) {} -void create_stream_dependency(const SP_Device* const device, - SP_Stream dependent, SP_Stream other, - TF_Status* const status) {} -void get_stream_status(const SP_Device* const device, SP_Stream stream, - TF_Status* const status) {} -void create_event(const SP_Device* const device, SP_Event* event, - TF_Status* const status) { - event = nullptr; -} -void destroy_event(const SP_Device* const device, SP_Event event) {} -SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) { - return SE_EVENT_UNKNOWN; -} -void record_event(const SP_Device* const device, SP_Stream stream, - SP_Event event, TF_Status* const status) {} -void wait_for_event(const SP_Device* const device, SP_Stream stream, - SP_Event event, TF_Status* const status) {} -void create_timer(const SP_Device* const device, SP_Timer* timer, - TF_Status* const status) {} -void destroy_timer(const SP_Device* const device, SP_Timer timer) {} -void start_timer(const SP_Device* const device, SP_Stream stream, - SP_Timer timer, TF_Status* const status) {} -void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer, - TF_Status* const status) {} -void memcpy_dtoh(const SP_Device* const device, SP_Stream stream, - void* host_dst, const SP_DeviceMemoryBase* const device_src, - uint64_t size, TF_Status* const status) {} -void memcpy_htod(const SP_Device* const device, SP_Stream stream, - SP_DeviceMemoryBase* const device_dst, const void* host_src, - uint64_t size, TF_Status* const status) {} -void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst, - const SP_DeviceMemoryBase* const device_src, - uint64_t size, TF_Status* const status) {} -void sync_memcpy_htod(const SP_Device* const device, - SP_DeviceMemoryBase* const device_dst, - const void* host_src, uint64_t size, - TF_Status* const status) 
{} -void block_host_for_event(const SP_Device* const device, SP_Event event, - TF_Status* const status) {} -void synchronize_all_activity(const SP_Device* const device, - TF_Status* const status) {} -TF_Bool host_callback(const SP_Device* const device, SP_Stream stream, - SE_StatusCallbackFn const callback_fn, - void* const callback_arg) { - return true; -} - -void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) { - *se = {SP_STREAMEXECUTOR_STRUCT_SIZE}; - se->allocate = allocate; - se->deallocate = deallocate; - se->host_memory_allocate = host_memory_allocate; - se->host_memory_deallocate = host_memory_deallocate; - se->get_allocator_stats = get_allocator_stats; - se->device_memory_usage = device_memory_usage; - se->create_stream = create_stream; - se->destroy_stream = destroy_stream; - se->create_stream_dependency = create_stream_dependency; - se->get_stream_status = get_stream_status; - se->create_event = create_event; - se->destroy_event = destroy_event; - se->get_event_status = get_event_status; - se->record_event = record_event; - se->wait_for_event = wait_for_event; - se->create_timer = create_timer; - se->destroy_timer = destroy_timer; - se->start_timer = start_timer; - se->stop_timer = stop_timer; - se->memcpy_dtoh = memcpy_dtoh; - se->memcpy_htod = memcpy_htod; - se->sync_memcpy_dtoh = sync_memcpy_dtoh; - se->sync_memcpy_htod = sync_memcpy_htod; - se->block_host_for_event = block_host_for_event; - se->synchronize_all_activity = synchronize_all_activity; - se->host_callback = host_callback; -} - -void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) { - *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE}; -} - -/*** Create SP_TimerFns ***/ -uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; } - -void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) { - timer_fns->nanoseconds = nanoseconds; -} - -/*** Create SP_Platform ***/ -void create_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns, - TF_Status* status) { - TF_SetStatus(status, 
TF_OK, ""); - PopulateDefaultTimerFns(timer_fns); -} -void destroy_timer_fns(const SP_Platform* platform, SP_TimerFns* timer_fns) {} - -void create_stream_executor(const SP_Platform* platform, - SE_CreateStreamExecutorParams* params, - TF_Status* status) { - TF_SetStatus(status, TF_OK, ""); - PopulateDefaultStreamExecutor(params->stream_executor); -} -void destroy_stream_executor(const SP_Platform* platform, - SP_StreamExecutor* se) {} - -void create_device(const SP_Platform* platform, SE_CreateDeviceParams* params, - TF_Status* status) { - TF_SetStatus(status, TF_OK, ""); - params->device->struct_size = {SP_DEVICE_STRUCT_SIZE}; -} -void destroy_device(const SP_Platform* platform, SP_Device* device) {} - -void create_device_fns(const SP_Platform* platform, - SE_CreateDeviceFnsParams* params, TF_Status* status) { - TF_SetStatus(status, TF_OK, ""); - params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE}; -} -void destroy_device_fns(const SP_Platform* platform, SP_DeviceFns* device_fns) { -} - -void PopulateDefaultPlatform(SP_Platform* platform, - SP_PlatformFns* platform_fns) { - *platform = {SP_PLATFORM_STRUCT_SIZE}; - platform->name = kDeviceName; - platform->type = kDeviceType; - platform->visible_device_count = kDeviceCount; - platform_fns->create_device = create_device; - platform_fns->destroy_device = destroy_device; - platform_fns->create_device_fns = create_device_fns; - platform_fns->destroy_device_fns = destroy_device_fns; - platform_fns->create_stream_executor = create_stream_executor; - platform_fns->destroy_stream_executor = destroy_stream_executor; - platform_fns->create_timer_fns = create_timer_fns; - platform_fns->destroy_timer_fns = destroy_timer_fns; -} - -void destroy_platform(SP_Platform* const platform) {} -void destroy_platform_fns(SP_PlatformFns* const platform_fns) {} /*** Registration tests ***/ TEST(StreamExecutor, SuccessfulRegistration) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> 
void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform, params->platform_fns); - params->destroy_platform = destroy_platform; - params->destroy_platform_fns = destroy_platform_fns; + test_util::PopulateDefaultPlatformRegistrationParams(params); }; - port::Status status = InitStreamExecutorPlugin(plugin_init); + std::string device_type, platform_name; + port::Status status = + InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); TF_ASSERT_OK(status); port::StatusOr maybe_platform = MultiPlatformManager::PlatformWithName("MY_DEVICE"); TF_ASSERT_OK(maybe_platform.status()); Platform* platform = maybe_platform.ConsumeValueOrDie(); - ASSERT_EQ(platform->Name(), kDeviceName); - ASSERT_EQ(platform->VisibleDeviceCount(), kDeviceCount); + ASSERT_EQ(platform->Name(), test_util::kDeviceName); + ASSERT_EQ(platform->VisibleDeviceCount(), test_util::kDeviceCount); port::StatusOr maybe_executor = platform->ExecutorForDevice(0); @@ -233,13 +55,13 @@ TEST(StreamExecutor, NameNotSet) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform, params->platform_fns); + test_util::PopulateDefaultPlatformRegistrationParams(params); params->platform->name = nullptr; - params->destroy_platform = destroy_platform; - params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = InitStreamExecutorPlugin(plugin_init); + std::string device_type, platform_name; + port::Status status = + InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); } @@ -248,13 +70,13 @@ TEST(StreamExecutor, InvalidNameWithSemicolon) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - 
PopulateDefaultPlatform(params->platform, params->platform_fns); + test_util::PopulateDefaultPlatformRegistrationParams(params); params->platform->name = "INVALID:NAME"; - params->destroy_platform = destroy_platform; - params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = InitStreamExecutorPlugin(plugin_init); + std::string device_type, platform_name; + port::Status status = + InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT( status.error_message(), @@ -265,13 +87,13 @@ TEST(StreamExecutor, InvalidNameWithSlash) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform, params->platform_fns); + test_util::PopulateDefaultPlatformRegistrationParams(params); params->platform->name = "INVALID/"; - params->destroy_platform = destroy_platform; - params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = InitStreamExecutorPlugin(plugin_init); + std::string device_type, platform_name; + port::Status status = + InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); EXPECT_THAT(status.error_message(), testing::ContainsRegex("Device name/type 'INVALID/' must match")); @@ -281,13 +103,13 @@ TEST(StreamExecutor, CreateDeviceNotSet) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform, params->platform_fns); + test_util::PopulateDefaultPlatformRegistrationParams(params); params->platform_fns->create_device = nullptr; - params->destroy_platform = destroy_platform; - params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = InitStreamExecutorPlugin(plugin_init); + std::string device_type, 
platform_name; + port::Status status = + InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.error_message(), "'create_device' field in SP_PlatformFns must be set."); @@ -297,13 +119,13 @@ TEST(StreamExecutor, UnifiedMemoryAllocateNotSet) { auto plugin_init = [](SE_PlatformRegistrationParams* const params, TF_Status* const status) -> void { TF_SetStatus(status, TF_OK, ""); - PopulateDefaultPlatform(params->platform, params->platform_fns); + test_util::PopulateDefaultPlatformRegistrationParams(params); params->platform->supports_unified_memory = true; - params->destroy_platform = destroy_platform; - params->destroy_platform_fns = destroy_platform_fns; }; - port::Status status = InitStreamExecutorPlugin(plugin_init); + std::string device_type, platform_name; + port::Status status = + InitStreamExecutorPlugin(plugin_init, &device_type, &platform_name); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ( status.error_message(), @@ -315,18 +137,18 @@ class StreamExecutorTest : public ::testing::Test { protected: StreamExecutorTest() {} void SetUp() override { - PopulateDefaultPlatform(&platform_, &platform_fns_); - PopulateDefaultDeviceFns(&device_fns_); - PopulateDefaultStreamExecutor(&se_); - PopulateDefaultTimerFns(&timer_fns_); + test_util::PopulateDefaultPlatform(&platform_, &platform_fns_); + test_util::PopulateDefaultDeviceFns(&device_fns_); + test_util::PopulateDefaultStreamExecutor(&se_); + test_util::PopulateDefaultTimerFns(&timer_fns_); } void TearDown() override {} StreamExecutor* GetExecutor(int ordinal) { if (!cplatform_) { cplatform_ = absl::make_unique( - platform_, destroy_platform, platform_fns_, destroy_platform_fns, - device_fns_, se_, timer_fns_); + platform_, test_util::DestroyPlatform, platform_fns_, + test_util::DestroyPlatformFns, device_fns_, se_, timer_fns_); } port::StatusOr maybe_executor = 
cplatform_->ExecutorForDevice(ordinal); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc new file mode 100644 index 00000000000000..a3e210bc1c2846 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.cc @@ -0,0 +1,193 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h" + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +namespace stream_executor { +namespace test_util { + +/*** Functions for creating SP_StreamExecutor ***/ +void Allocate(const SP_Device* const device, uint64_t size, + int64_t memory_space, SP_DeviceMemoryBase* const mem) {} +void Deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) { +} +void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) { + return nullptr; +} +void HostMemoryDeallocate(const SP_Device* const device, void* mem) {} +TF_Bool GetAllocatorStats(const SP_Device* const device, + SP_AllocatorStats* const stats) { + return true; +} +TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free, + int64_t* const total) { + return true; +} +void CreateStream(const SP_Device* const device, SP_Stream* stream, + TF_Status* 
const status) { + stream = nullptr; +} +void DestroyStream(const SP_Device* const device, SP_Stream stream) {} +void CreateStreamDependency(const SP_Device* const device, SP_Stream dependent, + SP_Stream other, TF_Status* const status) {} +void GetStreamStatus(const SP_Device* const device, SP_Stream stream, + TF_Status* const status) {} +void CreateEvent(const SP_Device* const device, SP_Event* event, + TF_Status* const status) { + event = nullptr; +} +void DestroyEvent(const SP_Device* const device, SP_Event event) {} +SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) { + return SE_EVENT_UNKNOWN; +} +void RecordEvent(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} +void WaitForEvent(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} +void CreateTimer(const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) {} +void DestroyTimer(const SP_Device* const device, SP_Timer timer) {} +void StartTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer, + TF_Status* const status) {} +void StopTimer(const SP_Device* const device, SP_Stream stream, SP_Timer timer, + TF_Status* const status) {} +void MemcpyDToH(const SP_Device* const device, SP_Stream stream, void* host_dst, + const SP_DeviceMemoryBase* const device_src, uint64_t size, + TF_Status* const status) {} +void MemcpyHToD(const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, const void* host_src, + uint64_t size, TF_Status* const status) {} +void SyncMemcpyDToH(const SP_Device* const device, void* host_dst, + const SP_DeviceMemoryBase* const device_src, uint64_t size, + TF_Status* const status) {} +void SyncMemcpyHToD(const SP_Device* const device, + SP_DeviceMemoryBase* const device_dst, const void* host_src, + uint64_t size, TF_Status* const status) {} +void BlockHostForEvent(const SP_Device* const device, SP_Event event, + 
TF_Status* const status) {} +void SynchronizeAllActivity(const SP_Device* const device, + TF_Status* const status) {} +TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) { + return true; +} + +void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) { + *se = {SP_STREAMEXECUTOR_STRUCT_SIZE}; + se->allocate = Allocate; + se->deallocate = Deallocate; + se->host_memory_allocate = HostMemoryAllocate; + se->host_memory_deallocate = HostMemoryDeallocate; + se->get_allocator_stats = GetAllocatorStats; + se->device_memory_usage = DeviceMemoryUsage; + se->create_stream = CreateStream; + se->destroy_stream = DestroyStream; + se->create_stream_dependency = CreateStreamDependency; + se->get_stream_status = GetStreamStatus; + se->create_event = CreateEvent; + se->destroy_event = DestroyEvent; + se->get_event_status = GetEventStatus; + se->record_event = RecordEvent; + se->wait_for_event = WaitForEvent; + se->create_timer = CreateTimer; + se->destroy_timer = DestroyTimer; + se->start_timer = StartTimer; + se->stop_timer = StopTimer; + se->memcpy_dtoh = MemcpyDToH; + se->memcpy_htod = MemcpyHToD; + se->sync_memcpy_dtoh = SyncMemcpyDToH; + se->sync_memcpy_htod = SyncMemcpyHToD; + se->block_host_for_event = BlockHostForEvent; + se->synchronize_all_activity = SynchronizeAllActivity; + se->host_callback = HostCallback; +} + +void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) { + *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE}; +} + +/*** Functions for creating SP_TimerFns ***/ +uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; } + +void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) { + timer_fns->nanoseconds = Nanoseconds; +} + +/*** Functions for creating SP_Platform ***/ +void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultTimerFns(timer_fns); +} +void DestroyTimerFns(const SP_Platform* 
platform, SP_TimerFns* timer_fns) {} + +void CreateStreamExecutor(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultStreamExecutor(params->stream_executor); +} +void DestroyStreamExecutor(const SP_Platform* platform, SP_StreamExecutor* se) { +} +void GetDeviceCount(const SP_Platform* platform, int* device_count, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + *device_count = kDeviceCount; +} +void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + params->device->struct_size = {SP_DEVICE_STRUCT_SIZE}; +} +void DestroyDevice(const SP_Platform* platform, SP_Device* device) {} + +void CreateDeviceFns(const SP_Platform* platform, + SE_CreateDeviceFnsParams* params, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE}; +} +void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {} + +void PopulateDefaultPlatform(SP_Platform* platform, + SP_PlatformFns* platform_fns) { + *platform = {SP_PLATFORM_STRUCT_SIZE}; + platform->name = kDeviceName; + platform->type = kDeviceType; + platform_fns->get_device_count = GetDeviceCount; + platform_fns->create_device = CreateDevice; + platform_fns->destroy_device = DestroyDevice; + platform_fns->create_device_fns = CreateDeviceFns; + platform_fns->destroy_device_fns = DestroyDeviceFns; + platform_fns->create_stream_executor = CreateStreamExecutor; + platform_fns->destroy_stream_executor = DestroyStreamExecutor; + platform_fns->create_timer_fns = CreateTimerFns; + platform_fns->destroy_timer_fns = DestroyTimerFns; +} + +/*** Functions for creating SE_PlatformRegistrationParams ***/ +void DestroyPlatform(SP_Platform* platform) {} +void DestroyPlatformFns(SP_PlatformFns* platform_fns) {} + +void PopulateDefaultPlatformRegistrationParams( + 
SE_PlatformRegistrationParams* const params) { + PopulateDefaultPlatform(params->platform, params->platform_fns); + params->destroy_platform = DestroyPlatform; + params->destroy_platform_fns = DestroyPlatformFns; +} + +} // namespace test_util +} // namespace stream_executor diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h new file mode 100644 index 00000000000000..0bebf6f47b2d5d --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h @@ -0,0 +1,56 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_ + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +struct SP_Stream_st { + explicit SP_Stream_st(int id) : stream_id(id) {} + int stream_id; +}; + +struct SP_Event_st { + explicit SP_Event_st(int id) : event_id(id) {} + int event_id; +}; + +struct SP_Timer_st { + explicit SP_Timer_st(int id) : timer_id(id) {} + int timer_id; +}; + +namespace stream_executor { +namespace test_util { + +constexpr int kDeviceCount = 2; +constexpr char kDeviceName[] = "MY_DEVICE"; +constexpr char kDeviceType[] = "GPU"; + +void PopulateDefaultStreamExecutor(SP_StreamExecutor* se); +void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns); +void PopulateDefaultTimerFns(SP_TimerFns* timer_fns); +void PopulateDefaultPlatform(SP_Platform* platform, + SP_PlatformFns* platform_fns); +void PopulateDefaultPlatformRegistrationParams( + SE_PlatformRegistrationParams* const params); + +void DestroyPlatform(SP_Platform* platform); +void DestroyPlatformFns(SP_PlatformFns* platform_fns); + +} // namespace test_util +} // namespace stream_executor + +#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_ diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD new file mode 100644 index 00000000000000..c13639fdd94867 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/test/BUILD @@ -0,0 +1,20 @@ +# Description: +# test for stream_executor +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_shared_object", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +tf_cc_shared_object( + name = "test_pluggable_device.so", + srcs = ["test_pluggable_device.cc"], + visibility = ["//tensorflow/c:__subpackages__"], + deps = [ + 
"//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", + "//tensorflow/c/experimental/stream_executor:stream_executor_test_util", + ], +) diff --git a/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc new file mode 100644 index 00000000000000..a63078184a8771 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/test/test_pluggable_device.cc @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h" + +extern "C" { + +void SE_InitPlugin(SE_PlatformRegistrationParams* const params, + TF_Status* const status) { + stream_executor::test_util::PopulateDefaultPlatformRegistrationParams(params); +} + +void TF_InitKernel() {} +} diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index ed501b5b10137a..329e336a008327 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -24,8 +24,15 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/types.h" +// Required for IS_MOBILE_PLATFORM definition +#include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/stream.h" +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +using tensorflow::errors::InvalidArgument; // This file forms the basis of a stable ABI for third-party kernel // implementations. It is crucial that changes to this file are made cautiously // and with a focus on maintaining both source and binary compatibility. @@ -74,6 +81,9 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, // TF_CALL_ALL_TYPES macro can find tensorflow::string as string. switch (dtype) { TF_CALL_ALL_TYPES(CASE); + TF_CALL_QUANTIZED_TYPES(CASE); + TF_CALL_quint16(CASE); + TF_CALL_qint16(CASE); default: status->status = errors::Unimplemented("Unexpected type ", dtype); return; @@ -81,9 +91,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, TF_SetStatus(status, TF_OK, ""); } #undef CASE + } // namespace } // namespace tensorflow +namespace { +const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx, + const char* attr_name, + TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + const tensorflow::AttrValue* attr = + ::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name); + if (attr == nullptr) { + status->status = InvalidArgument("Operation '", cc_ctx->def().name(), + "' has no attr named '", attr_name, "'."); + } + return attr; +} +} // namespace + void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name, const TF_DataType type, @@ -168,6 +194,35 @@ 
void TF_RegisterKernelBuilder(const char* name, TF_KernelBuilder* builder, TF_SetStatus(status, TF_OK, ""); } +// This function is only for pluggable device. +// It will return nullptr in all other cases. +// This function is experimental and subject to change. +SP_Stream TF_GetStream(TF_OpKernelContext* ctx, TF_Status* status) { +#if defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) + status->status = tensorflow::errors::Unimplemented( + "Accessing device stream is not supported on mobile. File a bug at " + "https://github.com/tensorflow/tensorflow/issues if this feature is " + "important to you"); + return nullptr; +#else + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); + if (cc_ctx->op_device_context() == nullptr) { // CPU Device + status->status = tensorflow::errors::FailedPrecondition( + "Accessing device stream is not supported for a CPU device."); + return nullptr; + } else if (!cc_ctx->op_device_context()->IsPluggableDevice()) { + status->status = tensorflow::errors::FailedPrecondition( + "Accessing device stream is only supported for pluggable devices."); + return nullptr; + } else { // Is a PluggableDevice + TF_SetStatus(status, TF_OK, ""); + auto c_stream = static_cast( + cc_ctx->op_device_context()->stream()->implementation()); + return c_stream->Handle(); + } +#endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) +} + int TF_NumInputs(TF_OpKernelContext* ctx) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); return cc_ctx->num_inputs(); @@ -222,7 +277,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { cc_ctx->CtxFailure(s); } -#define DEFINE_TF_GETATTR(func, c_type, cc_type) \ +void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx, + const char* attr_name, + int32_t* list_size, + int32_t* total_size, + TF_Status* status) { + const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); + if (!status->status.ok()) { + *list_size = -1; 
+ *total_size = -1; + return; + } + switch (attr->value_case()) { +#define SINGLE_CASE(kK, attr_type, size_expr) \ + case tensorflow::AttrValue::kK: \ + *list_size = -1; \ + *total_size = size_expr; \ + break; + + SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length()); + SINGLE_CASE(kI, TF_ATTR_INT, -1); + SINGLE_CASE(kF, TF_ATTR_FLOAT, -1); + SINGLE_CASE(kB, TF_ATTR_BOOL, -1); + SINGLE_CASE(kType, TF_ATTR_TYPE, -1); + SINGLE_CASE(kShape, TF_ATTR_SHAPE, + attr->shape().unknown_rank() ? -1 : attr->shape().dim_size()); + SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1); +#undef SINGLE_CASE + + case tensorflow::AttrValue::kList: + *list_size = 0; + *total_size = -1; +#define LIST_CASE(field, attr_type, ...) \ + if (attr->list().field##_size() > 0) { \ + *list_size = attr->list().field##_size(); \ + __VA_ARGS__; \ + break; \ + } + + LIST_CASE( + s, TF_ATTR_STRING, *total_size = 0; + for (int i = 0; i < attr->list().s_size(); + ++i) { *total_size += attr->list().s(i).size(); }); + LIST_CASE(i, TF_ATTR_INT); + LIST_CASE(f, TF_ATTR_FLOAT); + LIST_CASE(b, TF_ATTR_BOOL); + LIST_CASE(type, TF_ATTR_TYPE); + LIST_CASE( + shape, TF_ATTR_SHAPE, *total_size = 0; + for (int i = 0; i < attr->list().shape_size(); ++i) { + const auto& s = attr->list().shape(i); + *total_size += s.unknown_rank() ? 
0 : s.dim_size(); + }); + LIST_CASE(tensor, TF_ATTR_TENSOR); + LIST_CASE(tensor, TF_ATTR_FUNC); +#undef LIST_CASE + break; + + case tensorflow::AttrValue::kPlaceholder: + *list_size = -1; + *total_size = -1; + break; + + case tensorflow::AttrValue::kFunc: + *list_size = -1; + *total_size = -1; + break; + + case tensorflow::AttrValue::VALUE_NOT_SET: + status->status = + InvalidArgument("Attribute '", attr_name, "' has no value set"); + break; + } +} + +#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \ void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \ const char* attr_name, \ c_type* val, TF_Status* status) { \ @@ -234,10 +363,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { if (s.ok()) { \ *val = static_cast(v); \ } \ + } \ + void TF_OpKernelConstruction_GetAttr##func##List( \ + TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \ + int max_vals, TF_Status* status) { \ + TF_SetStatus(status, TF_OK, ""); \ + const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \ + if (!status->status.ok()) return; \ + if (attr->value_case() != tensorflow::AttrValue::kList) { \ + status->status = \ + InvalidArgument("Value for '", attr_name, "' is not a list."); \ + return; \ + } \ + status->status = \ + tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \ + if (!status->status.ok()) return; \ + const auto len = std::min(max_vals, attr->list().list_field##_size()); \ + for (int i = 0; i < len; ++i) { \ + vals[i] = static_cast(attr->list().list_field(i)); \ + } \ + } + +DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type) +DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i) +DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i) +DEFINE_TF_GETATTR(Float, float, float, "float", f) +DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b) + +void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx, + const char* 
attr_name, char* value, + size_t max_length, + TF_Status* status) { + std::string v; + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (!status->status.ok()) return; + + if (max_length <= 0) { + return; + } + std::memcpy(value, v.data(), std::min(v.length(), max_length)); +} + +void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx, + const char* attr_name, + char** values, size_t* lengths, + int max_values, void* storage, + size_t storage_size, + TF_Status* status) { + std::vector v; + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v); + ::tensorflow::Set_TF_Status_from_Status(status, s); + + if (!status->status.ok()) return; + + const auto len = std::min(max_values, static_cast(v.size())); + char* p = static_cast(storage); + for (int i = 0; i < len; ++i) { + const std::string& s = v[i]; + values[i] = p; + lengths[i] = s.size(); + if ((p + s.size()) > (static_cast(storage) + storage_size)) { + status->status = InvalidArgument( + "Not enough storage to hold the requested list of strings"); + return; + } + memcpy(values[i], s.data(), s.size()); + p += s.size(); } +} -DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) -DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t) +bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx, + const char* attr_name, TF_Status* status) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); + return cc_ctx->HasAttr(attr_name); +} TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) { auto* cc_ctx = reinterpret_cast(ctx); diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 489aa5399a5266..508d59b1223442 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -19,6 +19,7 @@ limitations under the 
License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" @@ -65,6 +66,11 @@ typedef struct TF_KernelBuilder TF_KernelBuilder; typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; typedef struct TF_OpKernelContext TF_OpKernelContext; +// TF_InitKernel to do op/kernel registration. +// Plugin should implement TF_InitKernel to register kernels. This function +// should register all kernels in a plugin. +void TF_InitKernel(); + // Allocates a new kernel builder and returns a pointer to it. // // If non-null, TensorFlow will call create_func when it needs to instantiate @@ -128,6 +134,16 @@ TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); // -------------------------------------------------------------------------- // OpKernelContext routines +// TF_GetStream returns the SP_Stream available in ctx. +// This function returns a stream only for devices registered using the +// StreamExecutor C API +// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return +// nullptr and set error status in all other cases. +// Experimental: this function doesn't have compatibility guarantees and subject +// to change at any time. +TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx, + TF_Status* status); + // TF_NumInputs returns the number of inputs available in ctx. TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); @@ -168,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( // Returns the step ID of the given context. TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); +// Get the list_size and total_size of the attribute `attr_name` of `oper`. +// list_size - the length of the list. +// total_size - total size of the list. 
+// (1) If attr_type == TF_ATTR_STRING +// then total_size is the cumulative byte size +// of all the strings in the list. +// (3) If attr_type == TF_ATTR_SHAPE +// then total_size is the number of dimensions +// of the shape valued attribute, or -1 +// if its rank is unknown. +// (4) If attr_type == TF_ATTR_SHAPE +// then total_size is the cumulative number +// of dimensions of all shapes in the list. +// (5) Otherwise, total_size is undefined. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size, + int32_t* total_size, TF_Status* status); + // Interprets the named kernel construction attribute as a TF_DataType and // places it into *val. *status is set to TF_OK. // @@ -186,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, TF_Status* status); +// Interprets the named kernel construction attribute as int64_t and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// int64, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as float and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// float, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat( + TF_OpKernelConstruction* ctx, const char* attr_name, float* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as bool and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// bool, *status is populated with an error. 
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as string and +// places it into *val. `val` must +// point to an array of length at least `max_length` (ideally set to +// total_size from TF_OpKernelConstruction_GetAttrSize(ctx, +// attr_name, list_size, total_size)). *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// string, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString( + TF_OpKernelConstruction* ctx, const char* attr_name, char* val, + size_t max_length, TF_Status* status); + +// Interprets the named kernel construction attribute as a TF_DataType array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as int32_t array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as int64_t array and +// places it into *vals. *status is set to TF_OK. 
+// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as float array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList( + TF_OpKernelConstruction* ctx, const char* attr_name, float* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as bool array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as string array and fills +// in `vals` and `lengths`, each of which must point to an array of length at +// least `max_values`. *status is set to TF_OK. The elements of values will +// point to addresses in `storage` which must be at least `storage_size` bytes +// in length. Ideally, max_values would be set to list_size and `storage` would +// be at least total_size, obtained from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size). 
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList( + TF_OpKernelConstruction* ctx, const char* attr_name, char** vals, + size_t* lengths, int max_values, void* storage, size_t storage_size, + TF_Status* status); + +// Return true if the kernel construction has the attr_name +TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); + // Returns the unique operation name for this OpKernel. TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName( TF_OpKernelConstruction* ctx); diff --git a/tensorflow/c/kernels/bitcast_op.cc b/tensorflow/c/kernels/bitcast_op.cc index c194dcd686bd47..c6468e0ab80f6b 100644 --- a/tensorflow/c/kernels/bitcast_op.cc +++ b/tensorflow/c/kernels/bitcast_op.cc @@ -148,7 +148,7 @@ void RegisterBitcastOpKernel() { << "Error while registering bitcast kernel"; } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM { auto* builder = TF_NewKernelBuilder("Bitcast", tensorflow::DEVICE_GPU, &BitcastOp_Create, &BitcastOp_Compute, diff --git a/tensorflow/c/kernels/summary_op_benchmark_test.cc b/tensorflow/c/kernels/summary_op_benchmark_test.cc index 887a86066d3e2e..b65862063f81a6 100644 --- a/tensorflow/c/kernels/summary_op_benchmark_test.cc +++ b/tensorflow/c/kernels/summary_op_benchmark_test.cc @@ -49,14 +49,12 @@ Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, float value) { constexpr char longTagParam[] = "LONGTAG____________________________"; constexpr float largeValueParam = 2352352.2623433; -#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ - void BM_ScalarSummary##name##device(int iters) { \ - testing::StopTiming(); \ - TensorShape tensorshape(DIMARGS dims); \ - auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ - } \ +#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ + void 
BM_ScalarSummary##name##device(::testing::benchmark::State& state) { \ + TensorShape tensorshape(DIMARGS dims); \ + auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \ + test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state); \ + } \ BENCHMARK(BM_ScalarSummary##name##device); BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2); diff --git a/tensorflow/c/kernels/summary_op_test.cc b/tensorflow/c/kernels/summary_op_test.cc index 68c8deb5eab1a7..fede040f2f39d3 100644 --- a/tensorflow/c/kernels/summary_op_test.cc +++ b/tensorflow/c/kernels/summary_op_test.cc @@ -44,7 +44,7 @@ class DummyDevice : public DeviceBase { } }; -// Helper for comparing ouput and expected output +// Helper for comparing output and expected output void ExpectSummaryMatches(const Summary& actual, const string& expected_str) { Summary expected; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected)); diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index c9df2cc34d13ec..4fc5e46c1352d8 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/strings/str_format.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/c/c_api.h" #include "tensorflow/c/tf_datatype.h" @@ -161,6 +162,336 @@ TEST(TestKernel, TestRegisterKernelBuilder) { ASSERT_TRUE(delete_called); } +// REGISTER_OP for TF_OpKernelConstruction_GetAttr* test cases. +// Registers two ops, each with a single attribute called 'Attr'. +// The attribute in one op will have a type 'type', the other +// will have list(type). 
+#define ATTR_TEST_REGISTER_OP(name, type) \ + REGISTER_OP("TestKernelAttr" #name) \ + .Attr("Attr: " #type) \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape); \ + REGISTER_OP("TestKernelAttr" #name "List") \ + .Attr("Attr: list(" #type ")") \ + .SetShapeFn(tensorflow::shape_inference::UnknownShape) +ATTR_TEST_REGISTER_OP(String, string); +ATTR_TEST_REGISTER_OP(Int, int); +ATTR_TEST_REGISTER_OP(Float, float); +ATTR_TEST_REGISTER_OP(Bool, bool); +ATTR_TEST_REGISTER_OP(Type, type); +#undef ATTR_TEST_REGISTER_OP + +// Helper macros for the TF_OpKernelConstruction_GetAttr* tests. +#define EXPECT_TF_SIZE(attr_name, expected_list_size, expected_total_size) \ + do { \ + int32_t list_size, total_size; \ + TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, \ + &total_size, status); \ + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); \ + EXPECT_EQ(expected_list_size, list_size); \ + EXPECT_EQ(expected_total_size, total_size); \ + } while (0) + +typedef void* (*MyCreateFuncWithAttr)(TF_OpKernelConstruction*); +class TestKernelAttr : public ::testing::Test { + public: + TestKernelAttr() {} + ~TestKernelAttr() override {} + + std::unique_ptr GetFakeKernelWithAttr(const char* op_name, + AttrValue v, Status* status) { + NodeDef def; + def.set_op(op_name); + def.set_name("FakeNode"); + def.set_device("FakeDevice"); + (*def.mutable_attr())["Attr"] = v; + return CreateOpKernel(DeviceType("FakeDevice"), nullptr, nullptr, def, 1, + status); + } + + void CreateAndCallKernelWithAttr(MyCreateFuncWithAttr MyCreateFuncAttr, + const char* op_name, AttrValue& v) { + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, "FakeDevice", MyCreateFuncAttr, &MyComputeFunc, &MyDeleteFunc); + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder("FakeNode", builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + Status status; + std::unique_ptr kernel = + GetFakeKernelWithAttr(op_name, v, &status); + 
TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + kernel->Compute(nullptr); + + ASSERT_TRUE(delete_called); + } +}; + +TEST_F(TestKernelAttr, String) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + std::unique_ptr val(new char[5]); + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ 5); + TF_OpKernelConstruction_GetAttrString(ctx, "Attr", val.get(), + /*max_length*/ 5, status); + + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ("bunny", string(static_cast(val.get()), 5)); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_s("bunny"); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrString", v); +} + +TEST_F(TestKernelAttr, StringList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + std::vector list = {"bugs", "bunny", "duck"}; + int list_total_size = 0; + for (const auto& s : list) { + list_total_size += s.size(); + } + + TF_Status* status = TF_NewStatus(); + std::unique_ptr values(new char*[list.size()]); + std::unique_ptr lens(new size_t[list.size()]); + std::unique_ptr storage(new char[list_total_size]); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list.size(), + /*expected_total_size*/ list_total_size); + TF_OpKernelConstruction_GetAttrStringList( + ctx, "Attr", values.get(), lens.get(), list.size(), storage.get(), + list_total_size, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + for (size_t i = 0; i < list.size(); ++i) { + EXPECT_EQ(list[i].size(), lens[i]) << i; + EXPECT_EQ(list[i], string(static_cast(values[i]), lens[i])) + << i; + } + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + std::string 
attr_in[] = {"bugs", "bunny", "duck"}; + SetAttrValue(gtl::ArraySlice(attr_in, 3), &v); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrStringList", v); +} + +TEST_F(TestKernelAttr, Int) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + int64_t val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrInt64(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(1234, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_i(1234); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrInt", v); +} + +TEST_F(TestKernelAttr, IntList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const int64_t list[] = {1, 2, 3, 4}; + const size_t list_size = TF_ARRAYSIZE(list); + int64_t values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrInt64List(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + int64 attr_in[] = {1, 2, 3, 4}; + SetAttrValue(gtl::ArraySlice(attr_in, 4), &v); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrIntList", v); +} + +TEST_F(TestKernelAttr, Float) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + float val; + TF_Status* 
status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrFloat(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_FLOAT_EQ(2.718, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_f(2.718); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloat", v); +} + +TEST_F(TestKernelAttr, FloatList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const float list[] = {1.414, 2.718, 3.1415}; + const size_t list_size = TF_ARRAYSIZE(list); + float values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrFloatList(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + float attr_in[] = {1.414, 2.718, 3.1415}; + SetAttrValue(gtl::ArraySlice(attr_in, 3), &v); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrFloatList", v); +} + +TEST_F(TestKernelAttr, Bool) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + unsigned char val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrBool(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(1, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_b(true); + 
CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBool", v); +} + +TEST_F(TestKernelAttr, BoolList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const unsigned char list[] = {1, 0, 1, 0}; + const size_t list_size = TF_ARRAYSIZE(list); + unsigned char values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrBoolList(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + bool attr_in[] = {true, false, true, false}; + SetAttrValue(gtl::ArraySlice(attr_in, 4), &v); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrBoolList", v); +} + +TEST_F(TestKernelAttr, Type) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + TF_DataType val; + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ -1, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrType(ctx, "Attr", &val, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_EQ(TF_FLOAT, val); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + v.set_type(DT_FLOAT); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrType", v); +} + +TEST_F(TestKernelAttr, TypeList) { + auto my_create_func = [](TF_OpKernelConstruction* ctx) { + struct MyCustomKernel* s = new struct MyCustomKernel; + s->created = true; + s->compute_called = false; + + const TF_DataType list[] = {TF_FLOAT, TF_DOUBLE, TF_HALF, TF_COMPLEX128}; + 
const size_t list_size = TF_ARRAYSIZE(list); + TF_DataType values[list_size]; + + TF_Status* status = TF_NewStatus(); + EXPECT_TF_SIZE(/*attr_name*/ "Attr", /*expected_list_size*/ list_size, + /*expected_total_size*/ -1); + TF_OpKernelConstruction_GetAttrTypeList(ctx, "Attr", values, list_size, + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + EXPECT_TRUE( + std::equal(std::begin(list), std::end(list), std::begin(values))); + TF_DeleteStatus(status); + return static_cast(s); + }; + + AttrValue v; + DataType attr_in[] = {DT_FLOAT, DT_DOUBLE, DT_HALF, DT_COMPLEX128}; + SetAttrValue(gtl::ArraySlice(attr_in, 4), &v); + CreateAndCallKernelWithAttr(my_create_func, "TestKernelAttrTypeList", v); +} +#undef EXPECT_TF_SIZE + class DummyDevice : public DeviceBase { public: explicit DummyDevice(Env* env) : DeviceBase(env) {} @@ -259,50 +590,74 @@ TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) { TF_DeleteKernelBuilder(nullptr); } -TEST(TestKernel, TestTypeConstraint) { - const char* node_name = "SomeNodeName"; - const char* op_name = "TypeOp"; - const char* device_name = "FakeDeviceName1"; - - REGISTER_OP(op_name) - .Input("input1: double") - .Input("input2: uint8") - .Output("output1: uint8") - .Attr("T: type"); - - TF_KernelBuilder* builder = TF_NewKernelBuilder( - op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); - TF_Status* status = TF_NewStatus(); - TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_RegisterKernelBuilder(node_name, builder, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - - TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - KernelList list; - list.ParseFromArray(buf->data, buf->length); - const auto expected_str = R"str(kernel { - op: "TypeOp" +std::string ExpectedString(const char* type) { + const auto format_str = R"str(kernel { + op: "TypeOp%s" device_type: "FakeDeviceName1" 
constraint { name: "T" allowed_values { list { - type: DT_INT32 + type: %s } } } } )str"; - ASSERT_EQ(expected_str, list.DebugString()); - - TF_DeleteBuffer(buf); - TF_DeleteStatus(status); - TF_DeleteKernelBuilder(builder); - ASSERT_TRUE(delete_called); + return absl::StrFormat(format_str, type, type); } +#define TEST_KERNEL_TYPE_CONSTRAINT(tf_type, dtype) \ + TEST(TestKernel, TestTypeConstraint##tf_type) { \ + const char* node_name = "SomeNodeName"; \ + const char* op_name = "TypeOp" #dtype; \ + const char* device_name = "FakeDeviceName1"; \ + \ + REGISTER_OP(op_name) \ + .Input("input1: double") \ + .Input("input2: uint8") \ + .Output("output1: uint8") \ + .Attr("T: type"); \ + \ + TF_KernelBuilder* builder = TF_NewKernelBuilder( \ + op_name, device_name, &MyCreateFunc, &MyComputeFunc, &MyDeleteFunc); \ + TF_Status* status = TF_NewStatus(); \ + TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::tf_type, \ + status); \ + EXPECT_EQ(TF_OK, TF_GetCode(status)); \ + TF_RegisterKernelBuilder(node_name, builder, status); \ + EXPECT_EQ(TF_OK, TF_GetCode(status)); \ + \ + TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); \ + EXPECT_EQ(TF_OK, TF_GetCode(status)); \ + KernelList list; \ + list.ParseFromArray(buf->data, buf->length); \ + ASSERT_EQ(ExpectedString(#dtype), list.DebugString()); \ + \ + TF_DeleteBuffer(buf); \ + TF_DeleteStatus(status); \ + TF_DeleteKernelBuilder(builder); \ + ASSERT_TRUE(delete_called); \ + } + +TEST_KERNEL_TYPE_CONSTRAINT(TF_HALF, DT_HALF); +TEST_KERNEL_TYPE_CONSTRAINT(TF_BFLOAT16, DT_BFLOAT16); +TEST_KERNEL_TYPE_CONSTRAINT(TF_FLOAT, DT_FLOAT); +TEST_KERNEL_TYPE_CONSTRAINT(TF_DOUBLE, DT_DOUBLE); +TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT64, DT_UINT64); +TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT32, DT_UINT32); +TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT16, DT_UINT16); +TEST_KERNEL_TYPE_CONSTRAINT(TF_UINT8, DT_UINT8); +TEST_KERNEL_TYPE_CONSTRAINT(TF_INT8, DT_INT8); +TEST_KERNEL_TYPE_CONSTRAINT(TF_INT32, DT_INT32); 
+TEST_KERNEL_TYPE_CONSTRAINT(TF_COMPLEX64, DT_COMPLEX64); +TEST_KERNEL_TYPE_CONSTRAINT(TF_COMPLEX128, DT_COMPLEX128); +TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT8, DT_QINT8); +TEST_KERNEL_TYPE_CONSTRAINT(TF_QUINT8, DT_QUINT8); +TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT32, DT_QINT32); +TEST_KERNEL_TYPE_CONSTRAINT(TF_QINT16, DT_QINT16); +TEST_KERNEL_TYPE_CONSTRAINT(TF_QUINT16, DT_QUINT16); + TEST(TestKernel, TestHostMemory) { const char* node_name = "SomeNodeName"; const char* op_name = "HostMemoryOp"; @@ -352,7 +707,7 @@ class DeviceKernelOpTest : public OpsTestBase { EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_DeleteStatus(status); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM std::unique_ptr device( DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0")); OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); @@ -361,7 +716,7 @@ class DeviceKernelOpTest : public OpsTestBase { TF_ASSERT_OK(InitOp()); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM const char* device_name_ = tensorflow::DEVICE_GPU; #else const char* device_name_ = tensorflow::DEVICE_CPU; @@ -378,6 +733,23 @@ template void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes, TF_OpKernelContext* ctx); +REGISTER_OP("StreamOp").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestStream) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + TF_Status* s = TF_NewStatus(); + SP_Stream stream = TF_GetStream(ctx, s); + // Stream is always null if device is not a pluggable device. More test + // cases will be added when pluggable device mechanism is supported. 
+ EXPECT_EQ(stream, nullptr); + EXPECT_NE(TF_OK, TF_GetCode(s)); + TF_DeleteStatus(s); + }; + + SetupOp("StreamOp", "StreamOp", my_compute_func); + TF_ASSERT_OK(RunOpKernel()); +} + REGISTER_OP("AllocateOutputOp1").Output("output1: float"); TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) { @@ -468,7 +840,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) { int64_t dim = 1; TF_AllocatorAttributes alloc_attrs; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM alloc_attrs.on_host = 0; #else alloc_attrs.on_host = 1; @@ -505,7 +877,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) { int64_t dim = 0; TF_AllocatorAttributes alloc_attrs; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM alloc_attrs.on_host = 0; #else alloc_attrs.on_host = 1; @@ -538,7 +910,7 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) { int64_t dim[2] = {2, 3}; TF_AllocatorAttributes alloc_attrs; alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM alloc_attrs.on_host = 0; #else alloc_attrs.on_host = 1; @@ -646,7 +1018,7 @@ template void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes, TF_OpKernelContext* ctx) { T* data = reinterpret_cast(TF_TensorData(tensor)); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM OpKernelContext* cc_ctx = reinterpret_cast(ctx); cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values, tensor_size_bytes); diff --git a/tensorflow/c/tensor_interface.h b/tensorflow/c/tensor_interface.h index d165c84980cb14..0b352f561f7b59 100644 --- a/tensorflow/c/tensor_interface.h +++ b/tensorflow/c/tensor_interface.h @@ -50,6 +50,8 @@ class AbstractTensorInterface { // Returns if their is sole ownership of this Tensor and thus it can be moved. 
virtual bool CanMove() const = 0; + virtual std::string SummarizeValue() const = 0; + protected: virtual ~AbstractTensorInterface() {} }; diff --git a/tensorflow/c/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc index e0097e88019ab3..7abd28b25a43e1 100644 --- a/tensorflow/c/tf_status_helper.cc +++ b/tensorflow/c/tf_status_helper.cc @@ -79,6 +79,7 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status) { assert(0); break; } + tf_status->status.ReplaceAllPayloads(status.GetAllPayloads()); } Status StatusFromTF_Status(const TF_Status* tf_status) { diff --git a/tensorflow/c/tf_status_helper_test.cc b/tensorflow/c/tf_status_helper_test.cc index 60780d74b2143d..0bd9d1e4e3c747 100644 --- a/tensorflow/c/tf_status_helper_test.cc +++ b/tensorflow/c/tf_status_helper_test.cc @@ -24,6 +24,8 @@ namespace { TEST(StatusHelper, TestStatusHelper) { TF_Status* s = TF_NewStatus(); Status cc_status(errors::InvalidArgument("some error")); + cc_status.SetPayload("key1", "value1"); + cc_status.SetPayload("key2", "value2"); Set_TF_Status_from_Status(s, cc_status); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); ASSERT_EQ(std::string("some error"), TF_Message(s)); @@ -32,6 +34,9 @@ TEST(StatusHelper, TestStatusHelper) { ASSERT_FALSE(another_cc_status.ok()); ASSERT_EQ(std::string("some error"), another_cc_status.error_message()); ASSERT_EQ(error::INVALID_ARGUMENT, another_cc_status.code()); + // Ensure the payloads are not lost during conversions + ASSERT_EQ(cc_status.GetPayload("key1"), another_cc_status.GetPayload("key1")); + ASSERT_EQ(cc_status.GetPayload("key2"), another_cc_status.GetPayload("key2")); TF_DeleteStatus(s); } diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 39d2683226fcb8..35f308c2a4c6e1 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -196,6 +196,10 @@ bool TensorInterface::CanMove() const { return false; } +std::string TensorInterface::SummarizeValue() const { + return 
tensor_.SummarizeValue(/*max_entries=*/3, /*print_v2=*/true); +} + DataType TensorInterface::Type() const { return tensor_.dtype(); } int TensorInterface::NumDims() const { return tensor_.dims(); } diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h index 7a896dc5d11c2b..fafcafa7ab8391 100644 --- a/tensorflow/c/tf_tensor_internal.h +++ b/tensorflow/c/tf_tensor_internal.h @@ -104,6 +104,7 @@ class TensorInterface : public AbstractTensorInterface { void* Data() const override; bool IsAligned() const override; bool CanMove() const override; + std::string SummarizeValue() const override; Status ToTensor(tensorflow::Tensor* dst) const; Status BitcastFrom(const TensorInterface& from, DataType type, diff --git a/tensorflow/c/tf_tstring.cc b/tensorflow/c/tf_tstring.cc new file mode 100644 index 00000000000000..f5f32bf3d0cd15 --- /dev/null +++ b/tensorflow/c/tf_tstring.cc @@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/tf_tstring.h" + +#include "tensorflow/core/platform/ctstring_internal.h" + +void TF_StringInit(TF_TString *tstr) { TF_TString_Init(tstr); } + +void TF_StringCopy(TF_TString *dst, const char *src, size_t size) { + TF_TString_Copy(dst, src, size); +} + +void TF_StringAssignView(TF_TString *dst, const char *src, size_t size) { + TF_TString_AssignView(dst, src, size); +} + +const char *TF_StringGetDataPointer(const TF_TString *tstr) { + return TF_TString_GetDataPointer(tstr); +} + +TF_TString_Type TF_StringGetType(const TF_TString *str) { + return TF_TString_GetType(str); +} + +size_t TF_StringGetSize(const TF_TString *tstr) { + return TF_TString_GetSize(tstr); +} + +size_t TF_StringGetCapacity(const TF_TString *str) { + return TF_TString_GetCapacity(str); +} + +void TF_StringDealloc(TF_TString *tstr) { TF_TString_Dealloc(tstr); } diff --git a/tensorflow/c/tf_tstring.h b/tensorflow/c/tf_tstring.h index 8b576ff8197bc5..5dc29f23d59193 100644 --- a/tensorflow/c/tf_tstring.h +++ b/tensorflow/c/tf_tstring.h @@ -15,6 +15,48 @@ limitations under the License. 
#ifndef TENSORFLOW_C_TF_TSTRING_H_ #define TENSORFLOW_C_TF_TSTRING_H_ +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/platform/ctstring.h" +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // _WIN32 +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +TF_CAPI_EXPORT extern void TF_StringInit(TF_TString *t); + +TF_CAPI_EXPORT extern void TF_StringCopy(TF_TString *dst, const char *src, + size_t size); + +TF_CAPI_EXPORT extern void TF_StringAssignView(TF_TString *dst, const char *src, + size_t size); + +TF_CAPI_EXPORT extern const char *TF_StringGetDataPointer( + const TF_TString *tstr); + +TF_CAPI_EXPORT extern TF_TString_Type TF_StringGetType(const TF_TString *str); + +TF_CAPI_EXPORT extern size_t TF_StringGetSize(const TF_TString *tstr); + +TF_CAPI_EXPORT extern size_t TF_StringGetCapacity(const TF_TString *str); + +TF_CAPI_EXPORT extern void TF_StringDealloc(TF_TString *tstr); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + #endif // THIRD_PARTY_TENSORFLOW_C_TF_TSTRING_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 8f7e447d32268a..2aaf8e62ab4689 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -6,7 +6,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", "cc_library_with_android_deps", - "tf_cc_binary", "tf_cc_test", "tf_copts", "transitive_hdrs", @@ -650,14 +649,6 @@ tf_gen_op_wrappers_cc( pkg = "//tensorflow/core", ) -tf_gen_op_wrappers_cc( - name = "remote_fused_graph_ops", - op_lib_names = [ - "remote_fused_graph_ops", - ], - pkg = "//tensorflow/core", -) - tf_gen_op_wrappers_cc( name = "tpu_ops", include_internal_ops = 1, @@ -748,36 +739,6 @@ tf_gen_op_wrappers_cc( ], ) -tf_cc_binary( - name = 
"tutorials_example_trainer", - srcs = ["tutorials/example_trainer.cc"], - copts = tf_copts(), - linkopts = select({ - "//tensorflow:windows": [], - "//tensorflow:macos": [ - "-lm", - "-lpthread", - ], - "//tensorflow:ios": [ - "-lm", - "-lpthread", - ], - "//conditions:default": [ - "-lm", - "-lpthread", - "-lrt", - ], - }), - deps = [ - ":cc_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", - ], -) - cc_library( name = "queue_runner", srcs = ["training/queue_runner.cc"], @@ -854,9 +815,7 @@ transitive_hdrs( ":gradients", ":ops", ":queue_runner", - ":remote_fused_graph_ops", ":scope", - "//tensorflow/cc/profiler", "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader", "//tensorflow/cc/saved_model:reader", diff --git a/tensorflow/cc/experimental/base/public/tensor.h b/tensorflow/cc/experimental/base/public/tensor.h index fc447262ce16df..7aab1ccef18930 100644 --- a/tensorflow/cc/experimental/base/public/tensor.h +++ b/tensorflow/cc/experimental/base/public/tensor.h @@ -76,7 +76,7 @@ class Tensor { // unknown rank. int dims() const; - // Returns the number of elements in in demension `d`. + // Returns the number of elements in dimension `d`. // REQUIRES: `0 <= d < dims()` int64_t dim_size(int d) const; @@ -154,7 +154,7 @@ inline Tensor Tensor::FromBuffer(TF_DataType dtype, // 1. Only a function pointer is sent across the C API (&DeleterFunction) // 2. DeleterFunction is defined in the same build artifact that constructed // the std::function (so there isn't confusion about std::function ABI). - // Note that 2. is satisifed by the fact that this is a header-only API, where + // Note that 2. is satisfied by the fact that this is a header-only API, where // the function implementations are inline. 
DeleterStruct* deleter_struct = new DeleterStruct{deleter}; diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 13e666ddaad4ba..467202250c8313 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -60,7 +60,7 @@ string GetPath(const string& dot_h_fname) { if (result.size() > sizeof("external/") && result.compare(0, sizeof("external/") - 1, "external/") == 0) { result = result.substr(sizeof("external/") - 1); - pos = result.find("/"); + pos = result.find('/'); if (pos != string::npos) { result = result.substr(pos + 1); } @@ -586,7 +586,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, if (!api_def.description().empty()) { strings::StrAppend(&comment, "\n", api_def.description(), "\n"); } - strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n"); + strings::StrAppend(&comment, "\nArgs:\n* scope: A Scope object\n"); // Process inputs for (int i = 0; i < api_def.arg_order_size(); ++i) { diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 1414e861002487..649c979ecc67d0 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -73,7 +73,9 @@ class Output { Node* node() const { return op().node(); } int32 index() const { return index_; } DataType type() const { return op_.output_type(index_); } - string name() const { return strings::StrCat(node()->name(), ":", index()); } + std::string name() const { + return strings::StrCat(node()->name(), ":", index()); + } bool operator==(const Output& other) const { return op_ == other.op_ && index_ == other.index_; } @@ -107,7 +109,7 @@ class Input { /// be converted to a string (eg. a string literal). 
template ::value || - std::is_convertible::value>::type> + std::is_convertible::value>::type> Initializer(const T& v) { // NOLINT(runtime/explicit) typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), TensorShape()); @@ -120,7 +122,7 @@ class Input { /// Construct from a scalar value and an explicit shape template ::value || - std::is_convertible::value>::type> + std::is_convertible::value>::type> Initializer(const T& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); @@ -133,7 +135,7 @@ class Input { /// Construct from a initializer list of scalars (a one-dimensional tensor). template ::value || - std::is_convertible::value>::type> + std::is_convertible::value>::type> Initializer( const std::initializer_list& v) { // NOLINT(runtime/explicit) typedef typename RealType::type RealT; @@ -146,7 +148,7 @@ class Input { /// Construct from a initializer list of scalars and an explicit shape. template ::value || - std::is_convertible::value>::type> + std::is_convertible::value>::type> Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); @@ -168,7 +170,7 @@ class Input { Initializer(const std::initializer_list& v); // START_SKIP_DOXYGEN - template ::value> + template ::value> struct RealType { typedef tstring type; }; @@ -205,7 +207,7 @@ class Input { template ::value || - std::is_convertible::value>::type> + std::is_convertible::value>::type> Input(const T& v) // NOLINT(runtime/explicit) : Input(Initializer(v)) {} @@ -230,11 +232,11 @@ class Input { /// Constructor specifying a node name, index and datatype. This should only /// be used for specifying a backward edge, needed by control flow. 
- Input(const string& name, int32 i, DataType dt) + Input(const std::string& name, int32 i, DataType dt) : node_name_(name), index_(i), data_type_(dt) {} Node* node() const { return output_.node(); } - string node_name() const { return node_name_; } + std::string node_name() const { return node_name_; } int32 index() const { return node_name_.empty() ? output_.index() : index_; } DataType data_type() const { return data_type_; } Status status() const { return status_; } @@ -244,7 +246,7 @@ class Input { Status status_; Output output_ = Output(Operation(nullptr), 0); Tensor tensor_; - const string node_name_ = ""; + const std::string node_name_ = ""; int32 index_ = 0; DataType data_type_ = DT_INVALID; }; diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index d329b999a5cd29..86b659e7601691 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -67,7 +67,7 @@ bool IsZero(const Scope& scope, const Output& grad) { // mat: A 2-D tensor of dimension [D0, D1] // // Returns: -// A tensor of dimension [D0, D1], the result fo vec * mat. +// A tensor of dimension [D0, D1], the result for vec * mat. 
Output BroadcastMul(const Scope& scope, const Output& vec, const Output& mat) { auto reshaped = ExpandDims(scope, vec, -1); return Multiply(scope, reshaped, mat); diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD deleted file mode 100644 index 43240506f8ca60..00000000000000 --- a/tensorflow/cc/profiler/BUILD +++ /dev/null @@ -1,40 +0,0 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -tf_cuda_cc_test( - name = "profiler_test", - srcs = ["profiler_test.cc"], - tags = [ - "no_gpu", # b/77649654 - "no_rocm", # stream level tracing not supported on ROCm - ], - deps = [ - ":profiler", - "//tensorflow/cc:cc_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "profiler", - srcs = ["profiler.cc"], - hdrs = ["profiler.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/profiler:protos_all_cc", - "//tensorflow/core/profiler:tfprof_options", - "//tensorflow/core/profiler/internal:tfprof_stats", - ], -) diff --git a/tensorflow/cc/profiler/profiler.cc b/tensorflow/cc/profiler/profiler.cc deleted file mode 100644 index 3e55bac73e6d32..00000000000000 --- a/tensorflow/cc/profiler/profiler.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/profiler/profiler.h" - -namespace tensorflow { -namespace tfprof { - -Profiler::Profiler(const GraphDef& graph) { - std::unique_ptr graph_ptr(new GraphDef()); - *graph_ptr = graph; - stats_.reset(new TFStats(std::move(graph_ptr), nullptr, nullptr, nullptr)); -} - -void Profiler::AddStep(int64 step, const RunMetadata& run_meta) { - std::unique_ptr run_meta_ptr(new RunMetadata()); - *run_meta_ptr = run_meta; - stats_->AddRunMeta(step, std::move(run_meta_ptr)); -} - -GraphNodeProto Profiler::ProfileGraph(const Options& options) { - stats_->BuildView(kCmds[1]); - return stats_->ShowGraphNode(kCmds[1], options); -} - -GraphNodeProto Profiler::ProfileNameScope(const Options& options) { - stats_->BuildView(kCmds[0]); - return stats_->ShowGraphNode(kCmds[0], options); -} - -MultiGraphNodeProto Profiler::ProfileOperations(const Options& options) { - stats_->BuildView(kCmds[3]); - return stats_->ShowMultiGraphNode(kCmds[3], options); -} - -Status Profiler::SerializeToString(string* content) { - if (!content) { - return Status(error::Code::INVALID_ARGUMENT, - "Cannot use null string pointer for SerializeToString."); - } - stats_->SerializeToString(content); - return Status::OK(); -} - -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h deleted file mode 100644 index dc60fd5fb37a91..00000000000000 --- a/tensorflow/cc/profiler/profiler.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2017 The 
TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_ -#define TENSORFLOW_CC_PROFILER_PROFILER_H_ - -#include -#include - -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/profiler/internal/tfprof_stats.h" -#include "tensorflow/core/profiler/tfprof_options.h" -#include "tensorflow/core/profiler/tfprof_output.pb.h" - -namespace tensorflow { -namespace tfprof { - -/// @addtogroup core -/// @{ - -/// A `Profiler` object lets the caller profile the execution of a graph. -/// -/// Example: -/// // First build a graph and run tracing. -/// Scope root = Scope::NewRootScope(); -/// auto a = Placeholder(root, DT_INT32); -/// auto c = Add(root, a, {41}); -/// -/// ClientSession session(root); -/// std::vector outputs; -/// RunOptions run_options; -/// run_options.set_trace_level(RunOptions::FULL_TRACE); -/// RunMetadata run_meta; -/// Status s = session.Run(run_options, { {a, {1}} }, {c}, &outputs, -/// &run_meta); -/// if (!s.ok()) { ... } -/// -/// // Then create profiler to do profiling. -/// GraphDef graph; -/// root.ToGraphDef(&graph); -/// Profiler profiler(graph); -/// profiler.AddStep(0, run_meta); -/// Options opts = ... // TODO(xpan): Support option building API. 
-/// MultiGraphNodeProto r = profiler.ProfileOperations(opts); -/// -class Profiler { - public: - /// `graph` is the model's GraphDef. - explicit Profiler(const GraphDef& graph); - - /// Adds tracing information `run_meta` to profiler. A `run_meta` is - /// generated by a TensorFlow session run call. `step` is the key - /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to selectively profile the corresponding `run_meta`. - /// Multiple different `run_meta` can be keyed by the same `step` in order - /// to group them together. - void AddStep(int64 step, const RunMetadata& run_meta); - - /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are connected by the op inputs/outputs. - GraphNodeProto ProfileGraph(const Options& options); - - /// Profiles the model by organizing nodes in name scope structure. - /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a file system tree. - /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. - GraphNodeProto ProfileNameScope(const Options& options); - - /// Profiles the model by organizing nodes by operation types. - /// Each node is an operation type (e.g. Conv2D or MatMul), containing all - /// ops belonging to that type in the model. - MultiGraphNodeProto ProfileOperations(const Options& options); - - /// Serialize the profile content (ProfileProto) into a binary string, - /// User can write the string to file for offline analysis by - /// tfprof command-line tools or graphical user interface. 
- Status SerializeToString(string* content); - - private: - std::unique_ptr stats_; -}; -/// @} - -} // namespace tfprof -} // namespace tensorflow - -#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_ diff --git a/tensorflow/cc/profiler/profiler_test.cc b/tensorflow/cc/profiler/profiler_test.cc deleted file mode 100644 index 280cd74827fc8a..00000000000000 --- a/tensorflow/cc/profiler/profiler_test.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/platform/test.h" - -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/cc/profiler/profiler.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/graph/default_device.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { -namespace tfprof { - -class ProfilerTest : public ::testing::Test { - protected: - ProfilerTest() {} -}; - -GraphDef CreateGraphDef() { - Scope root = Scope::NewRootScope(); - - auto a = ops::Const(root, {{3, 2}, {-1, 0}}); - - auto x = ops::Const(root.WithOpName("x"), {{1.f}, {1.f}}); - - auto y = ops::MatMul(root.WithOpName("y"), a, x); - - auto y2 = ops::Square(root, y); - - auto y2_sum = ops::Sum(root, y2, 0); - - auto y_norm = ops::Sqrt(root, y2_sum); - - auto y_div = ops::Div(root.WithOpName("y_normalized"), y, y_norm); - - GraphDef def; - TF_CHECK_OK(root.ToGraphDef(&def)); - - return def; -} - -Options Default() { - Options opts(1000, /* max_depth */ - 0, /* min_bytes */ - 0, /* min_peak_bytes */ - 0, /* min_residual_bytes */ - 0, /* min_output_bytes */ - 0, /* min_micros */ - 0, /* min_accelerator_micros */ - 0, /* min_cpu_micros */ - 0, /* min_params */ - 0, /* min_float_ops */ - 0, /* min_occurrence */ - 0, /* step */ - "name", /* order_by */ - {".*"}, /* account_type_regexes */ - {".*"}, /* start_name_regexes */ - {}, /* trim_name_regexes */ - {".*"}, {}, /* hide_name_regexes */ - false, /* account_displayed_op_only */ - {"micros"}, /* select */ - {"none"}, /* output_type */ - {}); - return opts; -} - -template -const T* ExtractNode(const T& pb, const string& name) { - if (pb.name() == name) { - return &pb; - } - for (const T& c : pb.children()) { - const T* ret = ExtractNode(c, name); - if (ret) return ret; - } - return nullptr; -} - 
-TEST_F(ProfilerTest, Basics) { - SessionOptions options; - options.config.set_allow_soft_placement(true); - std::unique_ptr session(NewSession(options)); - GraphDef def = CreateGraphDef(); - if (options.target.empty()) { - graph::SetDefaultDevice("/gpu:0", &def); - } - - TF_CHECK_OK(session->Create(def)); - - Tensor x(DT_FLOAT, TensorShape({2, 1})); - auto x_flat = x.flat(); - x_flat.setRandom(); - Eigen::Tensor inv_norm = - x_flat.square().sum().sqrt().inverse(); - x_flat = x_flat * inv_norm(); - - std::vector outputs; - RunOptions run_options; - run_options.set_trace_level(RunOptions::FULL_TRACE); - RunMetadata run_metadata; - outputs.clear(); - - Profiler profiler(def); - for (int i = 0; i < 2; ++i) { - TF_CHECK_OK(session->Run(run_options, {{"x", x}}, {"y:0", "y_normalized:0"}, - {}, &outputs, &run_metadata)); - profiler.AddStep(i, run_metadata); - CHECK_EQ(size_t{2}, outputs.size()); - } - - std::vector resp; - TF_CHECK_OK(session->ListDevices(&resp)); - bool has_gpu = false; - for (const auto& dev : resp) { - if (dev.device_type() == "GPU") { - has_gpu = true; - } - } - - GraphNodeProto ret = profiler.ProfileNameScope(Default()); - const GraphNodeProto* matmul = ExtractNode(ret, "y"); - EXPECT_TRUE(matmul); - EXPECT_GT(matmul->exec_micros(), 0); - if (has_gpu) { - EXPECT_GT(matmul->accelerator_exec_micros(), 0); - } else { - EXPECT_EQ(matmul->accelerator_exec_micros(), 0); - } - const GraphNodeProto* square = ExtractNode(ret, "Square"); - EXPECT_TRUE(square); - EXPECT_GT(square->exec_micros(), 0); - if (has_gpu) { - EXPECT_GT(square->accelerator_exec_micros(), 0); - } else { - EXPECT_EQ(square->accelerator_exec_micros(), 0); - } - - Options opts2 = Default(); - opts2.output_type = "timeline"; - string timeline_file = io::JoinPath(testing::TmpDir(), "timeline"); - opts2.output_options["outfile"] = timeline_file; - GraphNodeProto ret2 = profiler.ProfileGraph(opts2); - string s; - TF_CHECK_OK(ReadFileToString(Env::Default(), timeline_file + "_0", &s)); - 
EXPECT_TRUE(s.find("Square") != s.npos); - - MultiGraphNodeProto ret3 = profiler.ProfileOperations(Default()); - const MultiGraphNodeProto* matmul2 = ExtractNode(ret3, "MatMul"); - EXPECT_TRUE(matmul2); - EXPECT_GT(matmul2->exec_micros(), 0); - if (has_gpu) { - EXPECT_GT(matmul2->accelerator_exec_micros(), 0); - } else { - EXPECT_EQ(matmul2->accelerator_exec_micros(), 0); - } - - TF_CHECK_OK(session->Close()); -} - -} // namespace tfprof -} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 056c99eed8e809..92e834aea0b0b9 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -38,6 +38,14 @@ cc_library( hdrs = ["tag_constants.h"], ) +# copybara:uncomment_begin(google-only) +# cc_library( +# name = "mobile_only_deps", +# visibility = ["//visibility:private"], +# deps = if_mobile(["//tensorflow/core:portable_tensorflow_lib"]), +# ) +# copybara:uncomment_end + cc_library( name = "reader", srcs = ["reader.cc"], @@ -45,6 +53,7 @@ cc_library( deps = [ ":constants", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util/tensor_bundle", ] + if_not_mobile([ # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate diff --git a/tensorflow/cc/saved_model/bundle_v2.cc b/tensorflow/cc/saved_model/bundle_v2.cc index b6daece84abb9d..e164352c8482ff 100644 --- a/tensorflow/cc/saved_model/bundle_v2.cc +++ b/tensorflow/cc/saved_model/bundle_v2.cc @@ -114,18 +114,27 @@ Status SavedModelV2Bundle::Load(const std::string& export_dir, TF_RETURN_IF_ERROR( ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_)); - // Load the variables checkpoint reader. 
- const std::string variables_prefix = io::JoinPath( - export_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename); - bundle->variable_reader_.reset( - new BundleReader(Env::Default(), variables_prefix)); - TF_RETURN_WITH_CONTEXT_IF_ERROR( - bundle->variable_reader_->status(), - "Unable to load SavedModel variables checkpoint from ", variables_prefix); + const std::string variables_dir = + io::JoinPath(export_dir, kSavedModelVariablesDirectory); + if (!Env::Default()->FileExists(variables_dir).ok()) { + LOG(INFO) + << "No checkpoint found, assuming this is a program-only SavedModel"; + } else { + // Load the variables checkpoint reader. + const std::string variables_prefix = + io::JoinPath(variables_dir, kSavedModelVariablesFilename); + bundle->variable_reader_.reset( + new BundleReader(Env::Default(), variables_prefix)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + bundle->variable_reader_->status(), + "Unable to load SavedModel variables checkpoint from ", + variables_prefix); + + // Deserialize the object graph proto from the tensor bundle. + TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph( + bundle->variable_reader_.get(), &bundle->trackable_object_graph_)); + } - // Deserialize the object graph proto from the tensor bundle. - TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph( - bundle->variable_reader_.get(), &bundle->trackable_object_graph_)); return Status::OK(); } diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h index c2bfb4dcf83075..9d30a4a20add44 100644 --- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -84,9 +84,6 @@ class SavedModelAPI { SignatureDefFunction* GetSignatureDefFunction( const std::string& function_path, Status* status); - // Lists all Conrete Functions available from the SavedModel. 
- std::vector ListFunctions(); - // SavedModelAPI is movable, but not copyable. SavedModelAPI(SavedModelAPI&&) = default; SavedModelAPI& operator=(SavedModelAPI&&) = default; @@ -151,11 +148,6 @@ inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction( return SignatureDefFunction::wrap(function); } -inline std::vector SavedModelAPI::ListFunctions() { - ConcreteFunctionList list(TF_ListSavedModelFunctions(saved_model_.get())); - return list.ToVector(); -} - } // namespace cc } // namespace experimental } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index c1d4736f6b98b4..b5831a1bd5e961 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -19,11 +19,17 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/util/tensor_bundle/byte_swap.h" namespace tensorflow { namespace { @@ -49,6 +55,35 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { export_dir); } +// Swap tensor_content field of Const Op Tensors in the named functions +static Status SwapTensorContent(MetaGraphDef* meta_graph_def) { + GraphDef graph_def = *meta_graph_def->mutable_graph_def(); + for (auto& function : *meta_graph_def->mutable_graph_def() + ->mutable_library() + ->mutable_function()) { + for (auto& node : (*function.mutable_node_def())) { + if (node.op() != "Const") continue; + auto 
node_iterator = node.mutable_attr()->find("value"); + if (node_iterator == node.mutable_attr()->end()) continue; + AttrValue node_value = node_iterator->second; + if (!node_value.has_tensor()) continue; + + auto tsize = node_value.mutable_tensor()->tensor_content().size(); + auto p_type = node_value.mutable_tensor()->dtype(); + // Swap only when there is something in tensor_content field + if (tsize != 0 && DataTypeCanUseMemcpy(p_type)) { + Tensor parsed(p_type); + DCHECK(parsed.FromProto(*node_value.mutable_tensor())); + TF_RETURN_IF_ERROR(ByteSwapTensor(&parsed)); + (*node.mutable_attr())["value"].mutable_tensor()->set_tensor_content( + string(reinterpret_cast(parsed.tensor_data().data()), + parsed.tensor_data().size())); + } + } + } + return Status::OK(); +} + Status FindMetaGraphDef(const std::unordered_set& tags, SavedModel* saved_model_proto, MetaGraphDef* meta_graph_def) { @@ -63,6 +98,10 @@ Status FindMetaGraphDef(const std::unordered_set& tags, // Match with the set of tags provided. if (graph_tags == tags) { *meta_graph_def = std::move(graph_def); + // Correct the endiness of Tensor content on big-endian system + if (!port::kLittleEndian) { + TF_RETURN_IF_ERROR(SwapTensorContent(meta_graph_def)); + } return Status::OK(); } } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 274a1630a05bba..d7f79c510bde95 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -138,7 +138,7 @@ class FreezeTest : public ::testing::Test { } TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. + // "c" isn't dependent on the variable, so nothing should be frozen. 
TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( graph_def, {"c:0"}, "assign", &saved_model_bundle)); @@ -183,7 +183,7 @@ class FreezeTest : public ::testing::Test { } Output c = ops::Mul(scope.WithOpName("c"), a, read_var); TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. + // "c" isn't dependent on the variable, so nothing should be frozen. TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( graph_def, {"c:0"}, "assign", &saved_model_bundle)); @@ -244,7 +244,7 @@ class FreezeTest : public ::testing::Test { Output c = ops::Mul(scope.WithOpName("c"), a, read_var); TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); - // "c" isnt dependent on the variable, so nothing should be frozen. + // "c" isn't dependent on the variable, so nothing should be frozen. TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( graph_def, {"c:0"}, "assign", &saved_model_bundle)); diff --git a/tensorflow/cc/tutorials/example_trainer.cc b/tensorflow/cc/tutorials/example_trainer.cc deleted file mode 100644 index 789662f84d00ba..00000000000000 --- a/tensorflow/cc/tutorials/example_trainer.cc +++ /dev/null @@ -1,234 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/graph/default_device.h" -#include "tensorflow/core/graph/graph_def_builder.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/public/session.h" - -using tensorflow::string; -using tensorflow::int32; - -namespace tensorflow { -namespace example { - -struct Options { - int num_concurrent_sessions = 1; // The number of concurrent sessions - int num_concurrent_steps = 10; // The number of concurrent steps - int num_iterations = 100; // Each step repeats this many times - bool use_gpu = false; // Whether to use gpu in the training -}; - -// A = [3 2; -1 0]; x = rand(2, 1); -// We want to compute the largest eigenvalue for A. -// repeat x = y / y.norm(); y = A * x; end -GraphDef CreateGraphDef() { - // TODO(jeff,opensource): This should really be a more interesting - // computation. Maybe turn this into an mnist model instead? - Scope root = Scope::NewRootScope(); - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - - // A = [3 2; -1 0]. Using Const means the result will be a - // float tensor even though the initializer has integers. - auto a = Const(root, {{3, 2}, {-1, 0}}); - - // x = [1.0; 1.0] - auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}}); - - // y = A * x - auto y = MatMul(root.WithOpName("y"), a, x); - - // y2 = y.^2 - auto y2 = Square(root, y); - - // y2_sum = sum(y2). Note that you can pass constants directly as - // inputs. 
Sum() will automatically create a Const node to hold the - // 0 value. - auto y2_sum = Sum(root, y2, 0); - - // y_norm = sqrt(y2_sum) - auto y_norm = Sqrt(root, y2_sum); - - // y_normalized = y ./ y_norm - Div(root.WithOpName("y_normalized"), y, y_norm); - - GraphDef def; - TF_CHECK_OK(root.ToGraphDef(&def)); - - return def; -} - -string DebugString(const Tensor& x, const Tensor& y) { - CHECK_EQ(x.NumElements(), 2); - CHECK_EQ(y.NumElements(), 2); - auto x_flat = x.flat(); - auto y_flat = y.flat(); - // Compute an estimate of the eigenvalue via - // (x' A x) / (x' x) = (x' y) / (x' x) - // and exploit the fact that x' x = 1 by assumption - Eigen::Tensor lambda = (x_flat * y_flat).sum(); - return strings::Printf("lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]", - lambda(), x_flat(0), x_flat(1), y_flat(0), y_flat(1)); -} - -void ConcurrentSteps(const Options* opts, int session_index) { - // Creates a session. - SessionOptions options; - std::unique_ptr session(NewSession(options)); - GraphDef def = CreateGraphDef(); - if (options.target.empty()) { - graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def); - } - - TF_CHECK_OK(session->Create(def)); - - // Spawn M threads for M concurrent steps. - const int M = opts->num_concurrent_steps; - std::unique_ptr step_threads( - new thread::ThreadPool(Env::Default(), "trainer", M)); - - for (int step = 0; step < M; ++step) { - step_threads->Schedule([&session, opts, session_index, step]() { - // Randomly initialize the input. - Tensor x(DT_FLOAT, TensorShape({2, 1})); - auto x_flat = x.flat(); - x_flat.setRandom(); - Eigen::Tensor inv_norm = - x_flat.square().sum().sqrt().inverse(); - x_flat = x_flat * inv_norm(); - - // Iterations. 
- std::vector outputs; - for (int iter = 0; iter < opts->num_iterations; ++iter) { - outputs.clear(); - TF_CHECK_OK( - session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs)); - CHECK_EQ(size_t{2}, outputs.size()); - - const Tensor& y = outputs[0]; - const Tensor& y_norm = outputs[1]; - // Print out lambda, x, and y. - std::printf("%06d/%06d %s\n", session_index, step, - DebugString(x, y).c_str()); - // Copies y_normalized to x. - x = y_norm; - } - }); - } - - // Delete the threadpool, thus waiting for all threads to complete. - step_threads.reset(nullptr); - TF_CHECK_OK(session->Close()); -} - -void ConcurrentSessions(const Options& opts) { - // Spawn N threads for N concurrent sessions. - const int N = opts.num_concurrent_sessions; - - // At the moment our Session implementation only allows - // one concurrently computing Session on GPU. - CHECK_EQ(1, N) << "Currently can only have one concurrent session."; - - thread::ThreadPool session_threads(Env::Default(), "trainer", N); - for (int i = 0; i < N; ++i) { - session_threads.Schedule(std::bind(&ConcurrentSteps, &opts, i)); - } -} - -} // end namespace example -} // end namespace tensorflow - -namespace { - -bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - int32* dst) { - if (absl::ConsumePrefix(&arg, flag) && absl::ConsumePrefix(&arg, "=")) { - char extra; - return (sscanf(arg.data(), "%d%c", dst, &extra) == 1); - } - - return false; -} - -bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, - bool* dst) { - if (absl::ConsumePrefix(&arg, flag)) { - if (arg.empty()) { - *dst = true; - return true; - } - - if (arg == "=true") { - *dst = true; - return true; - } else if (arg == "=false") { - *dst = false; - return true; - } - } - - return false; -} - -} // namespace - -int main(int argc, char* argv[]) { - tensorflow::example::Options opts; - std::vector unknown_flags; - for (int i = 1; i < argc; ++i) { - if (string(argv[i]) == "--") { - while (i < 
argc) { - unknown_flags.push_back(argv[i]); - ++i; - } - break; - } - - if (ParseInt32Flag(argv[i], "--num_concurrent_sessions", - &opts.num_concurrent_sessions) || - ParseInt32Flag(argv[i], "--num_concurrent_steps", - &opts.num_concurrent_steps) || - ParseInt32Flag(argv[i], "--num_iterations", &opts.num_iterations) || - ParseBoolFlag(argv[i], "--use_gpu", &opts.use_gpu)) { - continue; - } - - fprintf(stderr, "Unknown flag: %s\n", argv[i]); - return -1; - } - - // Passthrough any unknown flags. - int dst = 1; // Skip argv[0] - for (char* f : unknown_flags) { - argv[dst++] = f; - } - argv[dst++] = nullptr; - argc = static_cast(unknown_flags.size() + 1); - tensorflow::port::InitMain(argv[0], &argc, &argv); - tensorflow::example::ConcurrentSessions(opts); -} diff --git a/tensorflow/compat_template.__init__.py b/tensorflow/compat_template.__init__.py index f695e58e6a1758..b880b0417727fe 100644 --- a/tensorflow/compat_template.__init__.py +++ b/tensorflow/compat_template.__init__.py @@ -52,13 +52,21 @@ _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "estimator", estimator) -try: - from tensorflow.python.keras.api._v2 import keras - _current_module.__path__ = ( - [_module_util.get_parent_dir(keras)] + _current_module.__path__) +if _os.environ.get("_PREFER_OSS_KERAS", False): + _keras_module = "keras.api._v2.keras" + keras = _LazyLoader("keras", globals(), _keras_module) + _module_dir = _module_util.get_parent_dir_for_name(_keras_module) + if _module_dir: + _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "keras", keras) -except ImportError: - pass +else: + try: + from tensorflow.python.keras.api._v2 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) + except ImportError: + pass # Explicitly import lazy-loaded modules to support autocompletion. 
# pylint: disable=g-import-not-at-top @@ -79,11 +87,30 @@ # Add module aliases if hasattr(_current_module, 'keras'): - losses = keras.losses - metrics = keras.metrics - optimizers = keras.optimizers - initializers = keras.initializers - setattr(_current_module, "losses", losses) - setattr(_current_module, "metrics", metrics) - setattr(_current_module, "optimizers", optimizers) - setattr(_current_module, "initializers", initializers) + # It is possible that keras is a lazily loaded module, which might break when + # actually trying to import it. Have a Try-Catch to make sure it doesn't break + # when it doing some very initial loading, like tf.compat.v2, etc. + if _os.environ.get("_PREFER_OSS_KERAS", False): + try: + _keras_package = "keras.api._v2.keras." + losses = _LazyLoader("losses", globals(), _keras_package + "losses") + metrics = _LazyLoader("metrics", globals(), _keras_package + "metrics") + optimizers = _LazyLoader( + "optimizers", globals(), _keras_package + "optimizers") + initializers = _LazyLoader( + "initializers", globals(), _keras_package + "initializers") + setattr(_current_module, "losses", losses) + setattr(_current_module, "metrics", metrics) + setattr(_current_module, "optimizers", optimizers) + setattr(_current_module, "initializers", initializers) + except ImportError: + pass + else: + losses = keras.losses + metrics = keras.metrics + optimizers = keras.optimizers + initializers = keras.initializers + setattr(_current_module, "losses", losses) + setattr(_current_module, "metrics", metrics) + setattr(_current_module, "optimizers", optimizers) + setattr(_current_module, "initializers", initializers) diff --git a/tensorflow/compat_template_v1.__init__.py b/tensorflow/compat_template_v1.__init__.py index c216ef62df4567..10929243d8cc26 100644 --- a/tensorflow/compat_template_v1.__init__.py +++ b/tensorflow/compat_template_v1.__init__.py @@ -42,13 +42,21 @@ _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, 
"estimator", estimator) -try: - from tensorflow.python.keras.api._v1 import keras - _current_module.__path__ = ( - [_module_util.get_parent_dir(keras)] + _current_module.__path__) +if _os.environ.get("_PREFER_OSS_KERAS", False): + _keras_module = "keras.api._v1.keras" + keras = _LazyLoader("keras", globals(), _keras_module) + _module_dir = _module_util.get_parent_dir_for_name(_keras_module) + if _module_dir: + _current_module.__path__ = [_module_dir] + _current_module.__path__ setattr(_current_module, "keras", keras) -except ImportError: - pass +else: + try: + from tensorflow.python.keras.api._v1 import keras + _current_module.__path__ = ( + [_module_util.get_parent_dir(keras)] + _current_module.__path__) + setattr(_current_module, "keras", keras) + except ImportError: + pass # Explicitly import lazy-loaded modules to support autocompletion. # pylint: disable=g-import-not-at-top diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 06745de647bc0a..b45e21f33af393 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -61,6 +61,7 @@ test_suite( ":test_graph_tfvariable_test", ":tfcompile_test", ], + visibility = ["//visibility:public"], ) py_binary( @@ -82,6 +83,7 @@ py_binary( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variables", + "@absl_py//absl:app", "@six_archive//:six", ], ) @@ -115,8 +117,8 @@ genrule( # have control of the full GPU. 
cmd = "CUDA_VISIBLE_DEVICES='' " + "$(location :make_test_graphs) --out_dir $(@D)", - exec_tools = [":make_test_graphs"], tags = ["manual"], + tools = [":make_test_graphs"], ) tf_library( diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 532d64c5a3e702..6ae5631eec027c 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -23,6 +23,7 @@ import os import sys +from absl import app import six from six.moves import range @@ -39,7 +40,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import app from tensorflow.python.training import saver as saver_lib FLAGS = None diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 742cb308b3cb02..c94d95fa3e393b 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -127,7 +127,7 @@ def tf_library( "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), - exec_tools = [tfcompile_tool], + tools = [tfcompile_tool], # Run tfcompile on the build host, rather than forge, since it's # typically way faster on the local machine. 
local = 1, @@ -162,7 +162,7 @@ def tf_library( "//tensorflow/python/tools:freeze_graph)" + freeze_args ), - exec_tools = ["//tensorflow/python/tools:freeze_graph"], + tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) tfcompile_graph = freeze_file @@ -242,7 +242,7 @@ def tf_library( " --out_function_object=$(@D)/" + function_object_file + " " + flags + " " + profiling_flag + " " + mlir_flag + " " + traceme_flag ), - exec_tools = [tfcompile_tool], + tools = [tfcompile_tool], visibility = visibility, testonly = testonly, # Run tfcompile on the build host since it's typically faster on the @@ -281,7 +281,7 @@ def tf_library( " --out_session_module=$(@D)/" + session_module_pb + " " + flags ), - exec_tools = [tfcompile_tool], + tools = [tfcompile_tool], visibility = visibility, testonly = testonly, local = 1, @@ -432,7 +432,8 @@ def target_llvm_triple(): "//tensorflow:ios": "arm64-none-ios", "//tensorflow:ios_x86_64": "x86_64-apple-ios", "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:macos": "x86_64-none-darwin", + "//tensorflow:macos_x86_64": "x86_64-none-darwin", + "//tensorflow:macos_arm64": "aarch64-none-darwin", "//tensorflow:windows": "x86_64-none-windows", "//tensorflow:linux_s390x": "systemz-none-linux-gnu", "//conditions:default": "x86_64-pc-linux", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index deb3396d89cb01..5bdf309280eb49 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -4,7 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "tf_cc_test") # buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "if_libtpu", "tf_copts") +load("//tensorflow:tensorflow.bzl", "if_libtpu", "if_with_tpu_support", "tf_copts") load("//tensorflow/stream_executor:build_defs.bzl", "if_cuda_or_rocm") # buildifier: disable=same-origin-load @@ -19,8 +19,11 @@ 
load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", load("//tensorflow/core/platform:build_config_root.bzl", "tf_cuda_tests_tags") package( - default_visibility = [":internal"], - licenses = ["notice"], # Apache 2.0 + default_visibility = [ + ":internal", + "//third_party/cloud_tpu/inference_converter:__pkg__", + ], + licenses = ["notice"], ) package_group( @@ -67,6 +70,9 @@ cc_library( ] + if_cuda_or_rocm([ ":xla_gpu_device", ":xla_gpu_jit", + ]) + if_with_tpu_support([ + ":xla_tpu_device", + ":xla_tpu_jit", ]), alwayslink = 1, ) @@ -101,6 +107,16 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_tpu_jit", + visibility = ["//visibility:public"], + deps = if_libtpu([ + "//tensorflow/core/tpu/graph_rewrite:tpu_rewrite_pass_registration", + "//tensorflow/stream_executor/tpu:tpu_transfer_manager", + ]), + alwayslink = 1, +) + cc_library( name = "xla_cpu_device", srcs = ["xla_cpu_device.cc"], @@ -153,6 +169,42 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_tpu_device", + srcs = ["xla_tpu_device.cc"], + hdrs = ["xla_tpu_device.h"], + visibility = [":friends"], + deps = [ + ":jit_compilation_passes", + ":xla_device", + ":xla_kernel_creator", # buildcleaner: keep + "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core/common_runtime:copy_tensor", + "//tensorflow/core/common_runtime:device", + "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/common_runtime:dma_helper", + "//tensorflow/core/platform:status", + "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_defs", + "//tensorflow/core/tpu:tpu_node_device_util", + "//tensorflow/core/tpu:virtual_device", + 
"//tensorflow/stream_executor/tpu:c_api_conversions", + "//tensorflow/stream_executor/tpu:status_helper", + "//tensorflow/stream_executor/tpu:tpu_executor_base", + "//tensorflow/stream_executor/tpu:tpu_node_context", + "//tensorflow/stream_executor/tpu:tpu_platform_interface", + "//tensorflow/stream_executor/tpu:tpu_stream_interface", + ], + alwayslink = 1, +) + cc_library( name = "xla_tensor", srcs = ["xla_tensor.cc"], @@ -184,6 +236,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", @@ -215,10 +268,12 @@ XLA_DEVICE_DEPS = [ "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", + "//tensorflow/core/kernels/data:finalize_dataset_op", "//tensorflow/core/kernels/data:generator_dataset_op", "//tensorflow/core/kernels/data:iterator_ops", "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", + "//tensorflow/core/kernels/data:options_dataset_op", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor/platform", @@ -244,6 +299,7 @@ cc_library( # Public visibility is needed for external TF/XLA backends. 
visibility = ["//visibility:public"], deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"], + alwayslink = 1, ) cc_library( @@ -289,6 +345,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -302,12 +359,17 @@ cc_library( "//tensorflow/compiler/xla:parse_flags_from_env", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) cc_header_only_library( name = "flags_headers_only", + features = [ + "-parse_headers", # buildifier: disable=no-parse-headers + ], deps = [":flags_headers"], ) @@ -364,12 +426,9 @@ cc_library( ":flags", ":xla_activity_listener", ":xla_activity_proto_cc", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", + "//tensorflow/compiler/mlir:array_container_utils", + "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_context", @@ -384,13 +443,13 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", - ] + if_libtpu( - if_false = [ - "//tensorflow/compiler/mlir:array_container_utils", - "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", - ], - if_true = [], - ), + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + 
"@com_google_absl//absl/types:span", + ], ) tf_cc_test( @@ -429,7 +488,6 @@ cc_library( hdrs = ["get_compiler_ir.h"], visibility = [ ":internal", - "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", "//tensorflow/core/common_runtime/eager:__pkg__", ], deps = [ @@ -460,7 +518,6 @@ cc_library( textual_hdrs = ["get_compiler_ir.h"], visibility = [ ":internal", - "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", "//tensorflow/core/common_runtime/eager:__pkg__", ], deps = [ @@ -474,6 +531,9 @@ cc_library( cc_header_only_library( name = "get_compiler_ir_hdrs_only", + features = [ + "-parse_headers", # buildifier: disable=no-parse-headers + ], deps = [":get_compiler_ir_hdrs"], ) @@ -497,7 +557,6 @@ cc_library( ], visibility = [ ":internal", - "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", "//tensorflow/core/common_runtime/eager:__pkg__", ], deps = [ @@ -507,6 +566,7 @@ cc_library( ":flags", ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration", + "//tensorflow/compiler/tf2xla:mlir_bridge_pass", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:core_cpu_internal", @@ -546,8 +606,8 @@ cc_library( hdrs = ["resource_operation_safety_analysis.h"], deps = [ ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", @@ -718,7 +778,6 @@ cc_library( "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/cc:scope_internal", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:side_effect_util", @@ -731,6 +790,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", 
"//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -758,9 +818,9 @@ cc_library( deps = [ ":flags", ":xla_activity_proto_cc", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -868,9 +928,12 @@ tf_cc_test( "partially_decluster_pass_test.cc", "rearrange_function_argument_pass_test.cc", ], - # TODO(b/141643254) Re-enable msan after fixing use-of-uninitialized-value - # error. - tags = ["nomsan"] + tf_cuda_tests_tags(), + tags = [ + # TODO(b/141643254) Re-enable msan after fixing + # use-of-uninitialized-value error. + "nomsan", + "no_cuda_asan", # TODO(b/171317460): re-enable. + ] + tf_cuda_tests_tags(), deps = [ ":common", ":compilability_check_util", @@ -991,13 +1054,13 @@ cc_library( ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", - "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/graphcycles", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:graph", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index a340b9d3f4579b..be81fa86fcf877 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -309,9 +309,13 @@ xla::StatusOr ReplaceFunctionCallWithPartitionedCall( } } - ops::PartitionedCall call( - 
root.WithOpName("partitioned_call"), args, n->output_types(), func, - ops::PartitionedCall::Attrs{}.ConfigProto(config_string)); + // In theory we can use PartitionedCall if the XLA cluster does not have any + // stateful operations. However, for now we choose to be conservative since + // we don't have any evidence that choosing a stateless partitioned call helps + // for performance. + ops::StatefulPartitionedCall call( + root.WithOpName("stateful_partitioned_call"), args, n->output_types(), + func, ops::StatefulPartitionedCall::Attrs{}.ConfigProto(config_string)); for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) { diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 160ea83585d1aa..869d869fdb42d8 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -194,7 +194,7 @@ TEST_F(BuildXlaOpsTest, OnNonXlaDevice) { auto xla_run = NodeWith(Op("_XlaRun"), Inputs(Out(1, predicated_compilation_key))); auto tf_call = - NodeWith(Op("PartitionedCall"), + NodeWith(Op("StatefulPartitionedCall"), CtrlDeps(NodeWith(Op("Identity"), Inputs(Out(0, predicated_compilation_key))))); auto merge = NodeWith(Op("_XlaMerge"), Inputs(Out(tf_call), Out(xla_run))); @@ -252,9 +252,10 @@ TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) { TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); Node* sink_node = graph->sink_node(); - EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")), - NodeWith(Op("PartitionedCall")), - NodeWith(Op("NoOp"))))); + EXPECT_THAT(sink_node, + NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")), + NodeWith(Op("StatefulPartitionedCall")), + NodeWith(Op("NoOp"))))); } #ifdef GOOGLE_CUDA @@ -298,15 +299,15 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) { std::unique_ptr graph; TF_ASSERT_OK(BuildXlaOps(root, fdef_lib, &graph)); - Node* partitioned_call_op = nullptr; + Node* 
stateful_partitioned_call_op = nullptr; for (Node* n : graph->op_nodes()) { - if (n->type_string() == "PartitionedCall") { - ASSERT_EQ(partitioned_call_op, nullptr); - partitioned_call_op = n; + if (n->type_string() == "StatefulPartitionedCall") { + ASSERT_EQ(stateful_partitioned_call_op, nullptr); + stateful_partitioned_call_op = n; } } - ASSERT_NE(partitioned_call_op, nullptr); + ASSERT_NE(stateful_partitioned_call_op, nullptr); auto xla_compile = NodeWith(Op("_XlaCompile")); auto switch_on_compilation_pred = NodeWith(Op("Switch"), Inputs(Out(0, xla_compile), Out(1, xla_compile))); @@ -315,7 +316,7 @@ TEST_F(BuildXlaOpsTest, NoDeviceToHostCopiesForClustersWithInt32Inputs) { // Check that we pipe int32 inputs through an IdentityN to avoid extra D2H // copies. EXPECT_THAT( - partitioned_call_op, + stateful_partitioned_call_op, NodeWith(Inputs(Out(NodeWith(Op("IdentityN"), CtrlDeps(ctrl_dep)))))); } #endif diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 62e121420c3b0e..7ff1d76aa2f502 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" @@ -42,6 +41,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" @@ -151,10 +151,12 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( // not considered uncompilable. if (node_stack_trace != nullptr) { for (const auto& frame : *node_stack_trace) { - stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name}); + stack_trace.emplace_back( + StackFrameView{frame.name, frame.function_name, frame.stack_trace}); } } - stack_trace.emplace_back(StackFrameView{node.name(), ""}); + stack_trace.emplace_back( + StackFrameView{node.name(), "", node.GetStackTrace()}); RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes; IsCompilableNode(node, lib_runtime, &stack_trace, @@ -162,28 +164,6 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( return uncompilable_nodes; } -RecursiveCompilabilityChecker::UncompilableNodesMap -RecursiveCompilabilityChecker::FindUncompilableNodes( - const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, - const std::vector* - node_stack_trace) const { - // If `node_stack_trace` is provided, that means `call_def` is inside - // a function body, and therefore, arg nodes and retval nodes are - // not considered uncompilable. 
- std::vector stack_trace; - if (node_stack_trace != nullptr) { - for (const auto& frame : *node_stack_trace) { - stack_trace.emplace_back(StackFrameView{frame.name, frame.function_name}); - } - } - stack_trace.emplace_back(StackFrameView{call_def.name(), ""}); - - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes; - IsCompilableCall(call_def, lib_runtime, &stack_trace, - /*encapsulating_function=*/nullptr, &uncompilable_nodes); - return uncompilable_nodes; -} - bool RecursiveCompilabilityChecker::HasXLAKernel( const Node& node, string* uncompilable_reason) const { // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient @@ -194,12 +174,11 @@ bool RecursiveCompilabilityChecker::HasXLAKernel( "SymbolicGradient should be handled by IsCompilableCall()."; return false; } + if (node.type_string() == "Const") { - // Skip Const op with type DT_STRING, since XLA doesn't support it, but the - // registered Const KernelDef says that it does, to support no-op Assert for - // tfcompile. 
const AttrValue* attr = node.attrs().Find("dtype"); - if (attr != nullptr && attr->type() == DT_STRING) { + if (!op_filter_.allow_string_consts && attr != nullptr && + attr->type() == DT_STRING) { *uncompilable_reason = "Const op with type DT_STRING is not supported by XLA."; return false; @@ -359,7 +338,8 @@ bool RecursiveCompilabilityChecker::IsCompilableCall( const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); bool is_compilable = true; for (const Node* node : fbody->graph->op_nodes()) { - stack_trace->emplace_back(StackFrameView{node->name(), function.name()}); + stack_trace->emplace_back( + StackFrameView{node->name(), function.name(), node->GetStackTrace()}); is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace, &function, uncompilable_nodes); stack_trace->pop_back(); @@ -491,6 +471,15 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (!op_filter_.allow_collective_reduce_v2 && + node.type_string() == "CollectiveReduceV2") { + absl::string_view uncompilable_reason = "Collective op"; + MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, + encapsulating_function, uncompilable_nodes); + LogNotCompilable(node, uncompilable_reason); + return false; + } + if (!op_filter_.allow_ops_producing_or_consuming_variant && OpProducesOrConsumesVariant(node)) { absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer"; @@ -583,7 +572,8 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( [](const StackFrameView& stack_element) { return StackFrame{ std::string(stack_element.name), - std::string(stack_element.function_name)}; + std::string(stack_element.function_name), + stack_element.stack_trace}; }); node_info.name = std::string(stack_trace.back().name); @@ -690,8 +680,10 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes( static auto const ops_triggering_xla_compilation = new absl::flat_hash_set{"XlaBroadcastHelper", "XlaConv", + "XlaConvV2", "XlaDequantize", "XlaDot", + 
"XlaDotV2", "XlaDynamicSlice", "XlaDynamicUpdateSlice", "XlaEinsum", diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 65da072483b557..99a9e97b5b341b 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -24,11 +24,11 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" @@ -62,6 +62,7 @@ class RecursiveCompilabilityChecker { struct StackFrame { std::string name; std::string function_name; + std::shared_ptr stack_trace; }; // Contains information about uncompilable node inside a function body. @@ -128,6 +129,12 @@ class RecursiveCompilabilityChecker { // Require the function to be always compilable, regardless whether some // control flow branches might be dead for a given input. bool require_always_compilable = false; + + // Whether string constants are compilable. + bool allow_string_consts = true; + + // Whether to allow the compilation of CollectiveReduceV2Op. 
+ bool allow_collective_reduce_v2 = true; }; RecursiveCompilabilityChecker(OperationFilter op_filter, @@ -153,20 +160,6 @@ class RecursiveCompilabilityChecker { const Node& node, FunctionLibraryRuntime* lib_runtime, const std::vector* node_stack_trace = nullptr) const; - // Returns a map where the key is the function identifier(short debug - // string) of the function encapsulating the uncompilable nodes, and the - // value is a pair of NameAttrList of the function and a vector of - // uncompilable node info. When uncompilable node is not inside any - // function call nodes, then key is a ShortDebugString() of an empty - // NameAttrList. - // - // Also, when `node` is inside a function body, users can set - // `node_stack_trace` to provide an additional context for `node`'s - // placement within the outer most graph. - UncompilableNodesMap FindUncompilableNodes( - const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, - const std::vector* node_stack_trace = nullptr) const; - // Returns true if `node` can be compiled by XLA. bool IsCompilableNode(const Node& node, FunctionLibraryRuntime* lib_runtime) const { @@ -175,15 +168,6 @@ class RecursiveCompilabilityChecker { return IsCompilableNode(node, lib_runtime, &stack_trace); } - // Returns true if `call_def` can be compiled by XLA. It is assumed that - // `call_def` is a call operation. - bool IsCompilableCall(const NodeDef& call_def, - FunctionLibraryRuntime* lib_runtime) { - std::vector stack_trace; - stack_trace.emplace_back(StackFrameView{call_def.name(), ""}); - return IsCompilableCall(call_def, lib_runtime, &stack_trace); - } - // Returns true if XLA supports this Op, but we don't want to cluster it (ie: // due to performance or correctness concerns). 
bool OpIsInaccurate(const Node& node) const; @@ -193,6 +177,7 @@ class RecursiveCompilabilityChecker { struct StackFrameView { absl::string_view name; absl::string_view function_name; + std::shared_ptr stack_trace; }; bool IsCompilableNode( @@ -270,7 +255,7 @@ class RecursiveCompilabilityChecker { UncompilableNodesMap* uncompilable_nodes_map); // Make sure we don't recurse infinitely on recursive functions. - const size_t kMaxRecursionDepth = 10; + const size_t kMaxRecursionDepth = 50; const OperationFilter op_filter_; const DeviceType jit_device_type_; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index d482642b44cfc6..fd55cab637c2e9 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -27,11 +27,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index efd2ef24c3bf05..f4bb9ca4271e95 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -933,6 +933,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", 
"host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1102,6 +1104,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT, DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph2}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1120,6 +1124,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph1}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1258,6 +1264,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1289,6 +1297,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1424,6 +1434,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1452,6 +1464,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", 
"host_compute_channel_F2_F2_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1566,6 +1580,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1660,6 +1676,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1769,6 +1787,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -1881,6 +1901,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2017,6 +2039,8 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph1}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2033,6 +2057,8 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, + {"send_key", ""}, + {"recv_key", ""}, 
{"shape_inference_graph", shape_inference_graph2}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2165,6 +2191,8 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2183,6 +2211,8 @@ TEST(EncapsulateSubgraphsTest, {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2312,6 +2342,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2328,6 +2360,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2345,6 +2379,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"Toutputs", absl::Span({})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O3"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", NameAttrList()}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, @@ -2473,6 +2509,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, 
{"cost_estimate_ns", 1000000}, @@ -2591,6 +2629,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"Toutputs", absl::Span({DT_FLOAT})}, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, + {"send_key", ""}, + {"recv_key", ""}, {"shape_inference_graph", shape_inference_graph}, {"tpu_core", 0}, {"cost_estimate_ns", 1000000}, diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 4a5c79c02d98fa..9e209f3342e6e0 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -43,7 +43,7 @@ bool IsCpuGpuCompile(const Graph* graph) { for (Node* n : graph->nodes()) { string name; // Only consider nodes being compiled. - if (!GetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name).ok()) continue; + if (!TryGetNodeAttr(n->attrs(), kXlaClusterIdAttr, &name)) continue; // Early return for any node with a device that is not a CPU or GPU. DeviceNameUtils::ParsedName parsed; if (DeviceNameUtils::ParseFullName(n->requested_device(), &parsed)) { @@ -58,8 +58,8 @@ bool IsCpuGpuCompile(const Graph* graph) { // Checks if a graph node is marked to be a guaranteed constant. 
bool is_guaranteed_constant(const Node& n) { bool guaranteed_constant = false; - if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant) - .ok()) { + if (!TryGetNodeAttr(n.attrs(), "_is_guaranteed_constant", + &guaranteed_constant)) { return false; } return guaranteed_constant; diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index fef43eb8730bee..75708a772e3ce0 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -565,6 +565,20 @@ void ReplaceLiftedArgNodePlaceholderWithArg( function_body.graph->RemoveNode(lifted_arg_node); } +// Adds function def to function definition library and update the function +// callsite operation `callsite_node` to invoke new function instead. +Status AddFunctionWithNewName(const std::string& new_name, + const std::string& func_attr_name, + const FunctionDef& function_def, + NameAttrList* func_attr, Node* callsite_node, + FunctionLibraryDefinition* fld) { + TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def)); + func_attr->set_name(new_name); + callsite_node->ClearAttr(func_attr_name); + callsite_node->AddAttr(func_attr_name, *func_attr); + return Status::OK(); +} + // Reconnect outside compilation lifted arguments in a functional While node to // its outside compilation tensor sources. 
Status PostprocessLiftedArgsForWhile( @@ -633,12 +647,15 @@ Status PostprocessLiftedArgsForWhile( *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node); } + const auto new_body_function_name = + fld->UniqueFunctionName(absl::StrCat(body_func.name(), "_lifted_arg_")); FunctionDef rewritten_body_function_def; TF_RETURN_IF_ERROR(GraphToFunctionDef( - *body_function_body->graph, body_func.name(), HostGraphControlRetMapping, - &rewritten_body_function_def)); - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(body_func.name(), rewritten_body_function_def)); + *body_function_body->graph, new_body_function_name, + HostGraphControlRetMapping, &rewritten_body_function_def)); + TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_body_function_name, "body", + rewritten_body_function_def, + &body_func, n, fld)); // In cond_graph, just add new _Arg nodes. NameAttrList cond_func; @@ -657,13 +674,15 @@ Status PostprocessLiftedArgsForWhile( TF_RETURN_IF_ERROR(arg_node_or.status()); } + const auto new_cond_function_name = + fld->UniqueFunctionName(absl::StrCat(cond_func.name(), "_lifted_arg_")); FunctionDef rewritten_cond_function_def; TF_RETURN_IF_ERROR(GraphToFunctionDef( - *cond_function_body->graph, cond_func.name(), HostGraphControlRetMapping, - &rewritten_cond_function_def)); - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(cond_func.name(), rewritten_cond_function_def)); - + *cond_function_body->graph, new_cond_function_name, + HostGraphControlRetMapping, &rewritten_cond_function_def)); + TF_RETURN_IF_ERROR(AddFunctionWithNewName(new_cond_function_name, "cond", + rewritten_cond_function_def, + &cond_func, n, fld)); return Status::OK(); } @@ -779,19 +798,25 @@ Status PostprocessLiftedArgsForIf( else_branch_lifted_arg_nodes, else_branch_arg_node); } + const auto new_then_function_name = fld->UniqueFunctionName( + absl::StrCat(then_branch_func.name(), "_lifted_arg_")); FunctionDef rewritten_then_branch_function_def; TF_RETURN_IF_ERROR(GraphToFunctionDef( - 
*then_branch_function_body->graph, then_branch_func.name(), + *then_branch_function_body->graph, new_then_function_name, HostGraphControlRetMapping, &rewritten_then_branch_function_def)); - TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(), - rewritten_then_branch_function_def)); + TF_RETURN_IF_ERROR(AddFunctionWithNewName( + new_then_function_name, "then_branch", rewritten_then_branch_function_def, + &then_branch_func, n, fld)); + const auto new_else_function_name = fld->UniqueFunctionName( + absl::StrCat(else_branch_func.name(), "_lifted_arg_")); FunctionDef rewritten_else_branch_function_def; TF_RETURN_IF_ERROR(GraphToFunctionDef( - *else_branch_function_body->graph, else_branch_func.name(), + *else_branch_function_body->graph, new_else_function_name, HostGraphControlRetMapping, &rewritten_else_branch_function_def)); - TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(), - rewritten_else_branch_function_def)); + TF_RETURN_IF_ERROR(AddFunctionWithNewName( + new_else_function_name, "else_branch", rewritten_else_branch_function_def, + &else_branch_func, n, fld)); return Status::OK(); } @@ -852,11 +877,19 @@ Status PostprocessLiftedArgsForCall( TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(), HostGraphControlRetMapping, &rewritten_fdef)); - TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef)); + const auto new_function_name = + fld->UniqueFunctionName(absl::StrCat(n->type_string(), "_lifted_arg_")); + rewritten_fdef.mutable_signature()->set_name(new_function_name); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(rewritten_fdef)); // We need to recreate the node. Otherwise TF will not know n->num_inputs() // has increased. NodeDef node_def = n->def(); + + // Function name is represented via the Op's type. 
Reset the op type to new + // function def name; + *node_def.mutable_op() = new_function_name; + for (int i = original_arg_count, end = data_types.size(); i < end; i++) { Node* outside_compilation_node = lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count] @@ -1439,14 +1472,15 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode( // Rewrites loop cond to add a node which sends loop cond to host. TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( - FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func, - const string& while_node_name, const string& host_transfer_key) { + const string& cond_xla_func_name, const string& host_transfer_key, + NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld, + Node* while_node) { // Instantiate the loop cond function. std::unique_ptr fbody; - const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func.name()); + const FunctionDef* loop_cond_fdef = fld->Find(loop_cond_func->name()); TF_RET_CHECK(loop_cond_fdef); TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( - *loop_cond_fdef, AttrSlice(&loop_cond_func.attr()), fld, &fbody)); + *loop_cond_fdef, AttrSlice(&loop_cond_func->attr()), fld, &fbody)); Graph* g = fbody->graph; // Find the _Retval node and the loop cond node. @@ -1455,7 +1489,7 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( if (n->type_string() == "_Retval") { if (ret_node) { return errors::Internal("Multiple return node for loop cond function ", - loop_cond_func.name(), ": ", + loop_cond_func->name(), ": ", ret_node->DebugString(), " and ", n->DebugString()); } else { @@ -1465,14 +1499,14 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( } if (!ret_node) { return errors::Internal("No _Retval node for loop cond function ", - loop_cond_func.name()); + loop_cond_func->name()); } Node* loop_cond; TF_RETURN_IF_ERROR(ret_node->input_node(0, &loop_cond)); // Build the XlaSendToHost node. 
NodeDefBuilder send_loop_cond_builder( - absl::StrCat("send_oc_while_cond_", while_node_name), "XlaSendToHost"); + absl::StrCat("send_oc_while_cond_", while_node->name()), "XlaSendToHost"); send_loop_cond_builder.Attr("Tinput", DT_BOOL); send_loop_cond_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); @@ -1488,11 +1522,26 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( TF_RETURN_IF_ERROR(s); g->AddEdge(loop_cond, 0, send_loop_cond_node, 0); - // Replace original function. + // Replace original function if loop_cond_func already has been re-written + // for outside compilation. FunctionDef replace_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*g, loop_cond_func.name(), &replace_fdef)); - TF_RETURN_IF_ERROR(fld->ReplaceFunction(loop_cond_func.name(), replace_fdef)); + if (loop_cond_func->name() == cond_xla_func_name) { + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, loop_cond_func->name(), &replace_fdef)); + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(loop_cond_func->name(), replace_fdef)); + } else { + // If original while cond function has not been modified, add a new function + // with send loop predicated added and update the while node callsite + // operation. + const auto new_name = fld->UniqueFunctionName( + absl::StrCat(loop_cond_func->name(), "_send_pred_added_")); + TF_RETURN_IF_ERROR(GraphToFunctionDef(*g, new_name, &replace_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef)); + loop_cond_func->set_name(new_name); + while_node->ClearAttr("cond"); + while_node->AddAttr("cond", *loop_cond_func); + } return Status::OK(); } @@ -2011,8 +2060,8 @@ Status ExtractOutsideCompilationForWhileNode( // XLA computation: rewrite cond function to add a SendToHost node to send // loop predicate. 
- TF_RETURN_IF_ERROR( - AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); + TF_RETURN_IF_ERROR(AddSendLoopPredToLoopCond( + cond_xla_func_name, host_transfer_key, &cond, fld, n)); n->AddAttr(kXlaTokenInputNodesAttrName, std::vector{kXlaTokenArgNodeName}); diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 52d8fb94ff6e2f..b287b5fddc8e6c 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -163,6 +163,7 @@ void AllocateAndParseFlags() { ops_flags = new XlaOpsCommonFlags; ops_flags->tf_xla_always_defer_compilation = false; + ops_flags->tf_xla_async_compilation = false; jitter_flags = new IntroduceFloatingPointJitterPassFlags; jitter_flags->jitter_amount = 1e-5; @@ -177,6 +178,7 @@ void AllocateAndParseFlags() { // bridge, on a per-graph basis). bool enable_mlir_bridge = false; bool enable_mlir_bridge_is_explicit = false; + bool mlir_bridge_safe_mode = false; auto setter_for_jitter_tensor_names = [](string sequence) { jitter_flags->tensor_names = absl::StrSplit(sequence, ','); @@ -192,11 +194,11 @@ void AllocateAndParseFlags() { "XLA clusters."), Flag("tf_xla_check_cluster_input_numerics", &build_ops_flags->tf_xla_check_cluster_input_numerics, - "If true then insert CheckNumerics nodes to to check all cluster " + "If true then insert CheckNumerics nodes to check all cluster " "inputs."), Flag("tf_xla_check_cluster_output_numerics", &build_ops_flags->tf_xla_check_cluster_output_numerics, - "If true then insert CheckNumerics nodes to to check all cluster " + "If true then insert CheckNumerics nodes to check all cluster " "outputs."), Flag("tf_xla_disable_constant_folding", &build_ops_flags->tf_xla_disable_constant_folding, @@ -215,6 +217,10 @@ void AllocateAndParseFlags() { Flag("tf_xla_always_defer_compilation", &ops_flags->tf_xla_always_defer_compilation, ""), + Flag("tf_xla_async_compilation", &ops_flags->tf_xla_async_compilation, + "When lazy compilation is enabled, asynchronous 
compilation starts " + "the cluster compilation in the background, and the fallback path " + "is executed until the compilation has finished."), Flag("tf_introduce_floating_point_jitter_to_tensors", setter_for_jitter_tensor_names, "", @@ -227,7 +233,13 @@ void AllocateAndParseFlags() { Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge, "Enables experimental MLIR-Based TensorFlow Compiler Bridge.", - &enable_mlir_bridge_is_explicit)}); + &enable_mlir_bridge_is_explicit), + Flag( + "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode, + "When tf_mlir_enable_mlir_bridge is true, this field can enable " + "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, " + "it only runs for graphs that use features MLIR bridge currently " + "supports.")}); AppendMarkForCompilationPassFlagsInternal(flag_list); xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list); @@ -235,10 +247,15 @@ void AllocateAndParseFlags() { mlir_flags = new MlirCommonFlags; if (!enable_mlir_bridge_is_explicit) { mlir_flags->tf_mlir_enable_mlir_bridge = - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + (mlir_bridge_safe_mode) + ? ConfigProto::Experimental:: + MLIR_BRIDGE_ROLLOUT_SAFE_MODE_FALLBACK_ENABLED + : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; } else if (enable_mlir_bridge) { mlir_flags->tf_mlir_enable_mlir_bridge = - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + (mlir_bridge_safe_mode) + ? 
ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED + : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; } else { mlir_flags->tf_mlir_enable_mlir_bridge = ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED; @@ -283,6 +300,37 @@ MlirCommonFlags* GetMlirCommonFlags() { return mlir_flags; } +ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + absl::optional config_proto) { + // TF1 graphs that do not override Sessions's ConfigProto and TF2 graphs + // can enable/disable the graph via tf_mlir_enable_mlir_bridge. + auto tf_mlir_enable_mlir_bridge = + GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; + if (tf_mlir_enable_mlir_bridge != + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED) { + return tf_mlir_enable_mlir_bridge; + } + + // If a ConfigProto was not passed in, we can assume the caller is + // checking if TF2 graph should have the bridge enabled / disabled. In that + // case, we have already checked tf_mlir_enable_mlir_bridge so it is safe to + // return UNSPECIFIED here. + if (!config_proto.has_value()) { + return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED; + } + + // TF1 graphs that do override Session's ConfigProto and set + // ConfigProto's enable_mlir_bridge or mlir_bridge_rollout fields will not + // update tf_mlir_enable_mlir_bridge so check their values. + + // ConfigProto's enable_mlir_bridge defaults to false so only respect it + // when it is true. 
+ if (config_proto.value().experimental().enable_mlir_bridge()) { + return ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + } + return config_proto.value().experimental().mlir_bridge_rollout(); +} + void AppendMarkForCompilationPassFlags(std::vector* flag_list) { absl::call_once(flags_init, &AllocateAndParseFlags); AppendMarkForCompilationPassFlagsInternal(flag_list); diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index a0860da7b04149..1981eed1b0afae 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/command_line_flags.h" @@ -39,7 +40,7 @@ struct XlaAutoJitFlag { int32 optimization_level_general; }; -// Sets the xla_auto_jit_flag based on the given flag sting. Supported syntax +// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax // is: // : sets general and single_gpu setting to the provided number. // single-gpu(): sets the single_gpu setting to the provided number. @@ -98,6 +99,9 @@ struct XlaOpsCommonFlags { // If true, _XlaCompile always refuses to compile the cluster, which means the // XLA clusters always run in the TF executor. Defaults to false. bool tf_xla_always_defer_compilation; + // If true, _XlaCompile compiles the cluster asynchronously with respect to + // the main execution. The fallback path is taken while compilation happens. + bool tf_xla_async_compilation; }; // Flags for the build_xla_ops pass. @@ -156,6 +160,11 @@ GetIntroduceFloatingPointJitterPassFlags(); MlirCommonFlags* GetMlirCommonFlags(); +// Returns the effective MLIR bridge rollout state based on the flags and the +// optional configuration. 
+ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + absl::optional config_proto); + // Appends the flag definitions associated with // MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. // diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 08b3bea1084c5c..7cbf427edd86f5 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -37,7 +37,8 @@ static xla::StatusOr GetLocalExecutable( const XlaCompiler::Options& options, const XlaCompiler::CompileOptions& compile_options, const NameAttrList& function, XlaCompilationCache* cache, - absl::Span args, const XlaCompiler& compiler) { + const std::vector& args, + const XlaCompiler& compiler) { const XlaCompiler::CompilationResult* compilation_result = nullptr; xla::LocalExecutable* executable = nullptr; TF_RETURN_IF_ERROR(cache->Compile(options, function, args, compile_options, @@ -100,12 +101,10 @@ xla::StatusOr GetCompilerIr( })); core::ScopedUnref cache_ref(cache); - absl::optional tf_allocator_adapter; - XlaCompiler::Options options = GenerateCompilerOptions(*cache, *flr, dev, /*stream=*/nullptr, platform_info, - /*has_ref_vars=*/false, &tf_allocator_adapter); + /*has_ref_vars=*/false); XlaCompiler::CompileOptions compile_options; compile_options.always_return_tuple = false; @@ -115,11 +114,12 @@ xla::StatusOr GetCompilerIr( xla::StatusOr> args = XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_arg_indices, inputs, variable_infos); + constant_arg_indices, inputs, variable_infos, dev); TF_RETURN_IF_ERROR(args.status()); switch (stage) { - case IrExportStage::HLO: { + case IrExportStage::HLO: + case IrExportStage::HLO_SERIALIZED: { XlaCompiler::CompilationResult result; TF_RETURN_IF_ERROR( compiler.CompileFunction(compile_options, function, *args, &result)); @@ -131,13 +131,23 @@ xla::StatusOr GetCompilerIr( std::unique_ptr new_module, 
xla::HloModule::CreateFromProto(result.computation->proto(), config)); - return new_module->ToString(); + if (stage == IrExportStage::HLO_SERIALIZED) { + return new_module->ToProto().SerializeAsString(); + } else { + return new_module->ToString(); + } } - case IrExportStage::OPTIMIZED_HLO: { + case IrExportStage::OPTIMIZED_HLO: + case IrExportStage::OPTIMIZED_HLO_SERIALIZED: { xla::StatusOr executable = GetLocalExecutable( options, compile_options, function, cache, *args, compiler); TF_RETURN_IF_ERROR(executable.status()); - return (*executable)->executable()->module().ToString(); + xla::Executable* new_executable = (*executable)->executable(); + if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) { + return new_executable->module().ToProto().SerializeAsString(); + } else { + return new_executable->module().ToString(); + } } case IrExportStage::OPTIMIZED_HLO_DOT: { xla::StatusOr executable = GetLocalExecutable( diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h index 0a0a1a44271475..db46cbcac837a6 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.h +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -27,10 +27,16 @@ class Tensor; class TensorHandle; class EagerContext; -enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT }; - -// Returns HLO text for a given function `func_name` using library runtime -// `runtime` on a device `dev` with given `inputs`. +enum class IrExportStage { + HLO, + HLO_SERIALIZED, + OPTIMIZED_HLO, + OPTIMIZED_HLO_SERIALIZED, + OPTIMIZED_HLO_DOT +}; + +// Returns the IR format of the selected stage for a given function `func_name` +// using library runtime `runtime` on a device `dev` with given `inputs`. 
xla::StatusOr GetCompilerIr( IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, absl::string_view func_name, Device* dev, EagerContext* context, diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD deleted file mode 100644 index 23d994c27c52f9..00000000000000 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ /dev/null @@ -1,57 +0,0 @@ -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -package( - default_visibility = [ - "//tensorflow/compiler/tf2xla:internal", - ], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "graphcycles", - srcs = ["graphcycles.cc"], - hdrs = ["graphcycles.h"], - deps = [ - ":ordered_set", - "//tensorflow/core:lib", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "ordered_set", - hdrs = ["ordered_set.h"], - deps = [ - "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "graphcycles_test", - srcs = ["graphcycles_test.cc"], - deps = [ - ":graphcycles", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "ordered_set_test", - srcs = ["ordered_set_test.cc"], - deps = [ - ":ordered_set", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 1f400137f5b59e..e459dc14cb174b 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -11,6 +11,7 @@ package( XLA_OPS_DEPS = [ "@com_google_absl//absl/container:flat_hash_map", 
"@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/jit:flags", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 0f0f43cbad6667..ba359d75aeb7d9 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/synchronization/notification.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/flags.h" @@ -36,6 +37,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -166,8 +168,9 @@ static Status CompileToLocalExecutable( const XlaPlatformInfo& platform_info, absl::Span inputs, absl::Span variable_infos, - absl::Span constants, bool lazy, bool may_alias_resource_update, - xla::LocalClient** client, + absl::Span constants, + XlaCompilationCache::CompileMode compile_mode, + bool may_alias_resource_update, xla::LocalClient** client, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { // We store information about the JIT-compiled XLA computation @@ -190,11 +193,10 @@ static Status CompileToLocalExecutable( *client = static_cast(cache->client()); - absl::optional tf_allocator_adapter; XlaCompiler::Options options = GenerateCompilerOptions( *cache, *ctx->function_library(), ctx->device(), ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr, - platform_info, has_ref_vars, &tf_allocator_adapter); + platform_info, has_ref_vars); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; @@ -202,19 +204,80 @@ static Status CompileToLocalExecutable( // rather than a one-element tuple. compile_options.always_return_tuple = false; compile_options.alias_resource_update = !has_ref_vars && - !platform_info.is_on_xla_device() && may_alias_resource_update; xla::StatusOr> args = - XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs, - variable_infos); + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constants, inputs, variable_infos, + static_cast(ctx->device())); TF_RETURN_IF_ERROR(args.status()); - return cache->Compile(options, function, *args, compile_options, - lazy ? XlaCompilationCache::CompileMode::kLazy - : XlaCompilationCache::CompileMode::kStrict, + return cache->Compile(options, function, *args, compile_options, compile_mode, compilation_result, executable); } +// Resolve the device assignment for the TF single-host MirroredStrategy by +// calling into TF runtime which in turn would start a rendezvous. +static xla::StatusOr ResolveDeviceAssignment( + OpKernelContext* ctx, + const absl::optional< + XlaCompiler::CompilationResult::CollectiveReduceV2OpInfo>& + collective_reduce_info) { + static const int kTimeoutSeconds = 30; + if (!collective_reduce_info) { + // An empty device assignment is sufficient for the case where no + // collectives are present. 
+ return xla::DeviceAssignment{}; + } + + CollectiveParams params; + params.name = "xla-reduction-compilation"; + params.group.device_type = + DeviceType{static_cast(ctx->device())->device_type()}; + params.group.group_size = collective_reduce_info->group_size; + params.group.group_key = collective_reduce_info->group_key; + params.instance.type = REDUCTION_COLLECTIVE; + params.instance.impl_details.communication_hint = "nccl"; + params.instance.impl_details.timeout_seconds = kTimeoutSeconds; + params.instance.impl_details.collective_name = "NcclReduce"; + // TODO(cheshire): Avoid passing a dummy shape, TF runtime does not resolve + // devices otherwise. + params.instance.shape = TensorShape({1}); + + Status st; + absl::Notification n; + ctx->collective_executor()->CompleteParamsAsync( + ctx->device()->attributes(), ¶ms, ctx->cancellation_manager(), + [&](const Status& s) { + st = s; + n.Notify(); + }); + if (!n.WaitForNotificationWithTimeout(absl::Seconds(kTimeoutSeconds))) { + return errors::InvalidArgument("Timeout reached"); + } + TF_RETURN_IF_ERROR(st); + const std::vector& devices = params.group.device_names; + + xla::DeviceAssignment out(devices.size(), 1); + for (int device_idx = 0; device_idx < devices.size(); device_idx++) { + const std::string& device_name = devices[device_idx]; + Device* resolved_device = nullptr; + TF_RETURN_IF_ERROR(ctx->function_library()->device_mgr()->LookupDevice( + device_name, &resolved_device)); + + // TODO(cheshire): CPU support. 
+ const DeviceBase::GpuDeviceInfo* gpu_device_info = + resolved_device->tensorflow_gpu_device_info(); + if (!gpu_device_info || !gpu_device_info->stream) { + return errors::Internal( + "CollectiveReduceV2Op compilation is only supported on GPUs"); + } + + out(device_idx, 0) = gpu_device_info->stream->parent()->device_ordinal(); + } + + return out; +} + void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); @@ -232,7 +295,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); Status s = CompileToLocalExecutable( ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, inputs, - variable_infos, constants_, /*lazy=*/false, + variable_infos, constants_, XlaCompilationCache::CompileMode::kStrict, /*may_alias_resource_update=*/true, &client, &compilation_result, &executable); OP_REQUIRES_OK(ctx, s); @@ -245,14 +308,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - - VLOG(1) << "Executing XLA Computation..."; - - absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = GetAllocator( - &tf_allocator_adapter, ctx->device(), - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, - platform_info_); + std::shared_ptr allocator_ptr = + GetAllocator(ctx->device(), stream, platform_info_); + se::DeviceMemoryAllocator* allocator = allocator_ptr.get(); int device_ordinal = stream ? stream->parent()->device_ordinal() : client->default_device_ordinal(); XlaComputationLaunchContext launch_context( @@ -269,11 +327,23 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { // Execute the computation. 
VLOG(2) << "Executing computation."; + xla::StatusOr device_assignment = + ResolveDeviceAssignment(ctx, compilation_result->collective_reduce_info); + OP_REQUIRES_OK(ctx, device_assignment.status()); + xla::ExecutableRunOptions run_options; + run_options.set_device_assignment(&*device_assignment); run_options.set_stream(stream); run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); + + // Hardcode run id to always be zero: TF distributed strategy differentiates + // between subsequent runs using dependency edges. + // This is safe, as only TF dist-strat can produce distributed ops, and we can + // rely on TF dist-strat invariants. + xla::RunId run_id(0); + run_options.set_run_id(run_id); Env* env = Env::Default(); auto start_time = env->NowMicros(); @@ -382,6 +452,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { mutex_lock guard(cannot_compile_cluster_mu_); cannot_compile_cluster = cannot_compile_cluster_; } + XlaCompilationCache::CompileMode compile_mode = [&] { + if (must_compile_) { + return XlaCompilationCache::CompileMode::kStrict; + } + return GetXlaOpsCommonFlags().tf_xla_async_compilation + ? XlaCompilationCache::CompileMode::kAsync + : XlaCompilationCache::CompileMode::kLazy; + }(); if (GetXlaOpsCommonFlags().tf_xla_always_defer_compilation || cannot_compile_cluster) { @@ -397,12 +475,12 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { // unlocking them in XlaRun may lead to deadlocks. 
Status status = CompileToLocalExecutable( ctx, function_, has_ref_vars_, platform_info_, inputs, variable_infos, - constants_, - /*lazy=*/!must_compile_, - /*may_alias_resource_update=*/false, &client, &kernel, &executable); + constants_, compile_mode, /*may_alias_resource_update=*/false, &client, + &kernel, &executable); OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_, variable_infos, &variables)); - if (must_compile_ || status.code() != error::UNIMPLEMENTED) { + if (compile_mode != XlaCompilationCache::CompileMode::kLazy || + status.code() != error::UNIMPLEMENTED) { OP_REQUIRES_OK(ctx, status); } @@ -424,6 +502,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { host_alloc_attrs.set_on_host(true); Allocator* cpu_allocator = ctx->device()->GetAllocator(host_alloc_attrs); + // Async compilation returns nullptr executable without an error. if (!executable) { DCHECK(!must_compile_); Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); @@ -464,13 +543,11 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { XlaExecutableClosure closure = XlaExecutableClosureStore::Global()->Consume(key); - absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = GetAllocator( - &tf_allocator_adapter, ctx->device(), - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, - platform_info_); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + std::shared_ptr allocator_ptr = + GetAllocator(ctx->device(), stream, platform_info_); + se::DeviceMemoryAllocator* allocator = allocator_ptr.get(); int device_ordinal = stream ? 
stream->parent()->device_ordinal() : closure.client()->default_device_ordinal(); XlaComputationLaunchContext launch_context( diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index ada7766fcbb399..a172a81766525f 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -30,12 +30,12 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/device_util.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" @@ -1199,6 +1199,8 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { RecursiveCompilabilityChecker::OperationFilter filter = CreateOperationFilter(*registration); filter.require_always_compilable = true; + filter.allow_string_consts = false; + filter.allow_collective_reduce_v2 = false; RecursiveCompilabilityChecker checker( filter, DeviceType{registration->compilation_device_name}); @@ -1207,6 +1209,15 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } + if (node->type_string() == "Const") { + // Skip Const op with type DT_STRING, since XLA autoclustering doesn't + // support it. 
+ const AttrValue* attr = node->attrs().Find("dtype"); + if (attr != nullptr && attr->type() == DT_STRING) { + continue; + } + } + if (!allowlist.empty() && !allowlist.contains(node->def().op())) { VLOG(1) << "Rejecting TF operation " << node->def().op() << " as it is not listed in --tf_xla_ops_to_cluster."; @@ -1775,7 +1786,7 @@ absl::flat_hash_map>* GetAllowlistTable() { "Identity", "IdentityN", "Relu", "Relu6", "ReluGrad", "Relu6Grad", "LeakyReluGrad", "Elu", "EluGrad", "Selu", "SeluGrad", "Select", "SelectV2", "Transpose", "ConjugateTranspose", - "_UnaryOpsComposition", + "_UnaryOpsComposition", "CollectiveReduceV2", // The following 4 operations are converted to identity "PlaceholderWithDefault", "PreventGradient", "StopGradient", "Snapshot"}}, @@ -1801,11 +1812,11 @@ absl::flat_hash_map>* GetAllowlistTable() { "Range", "Rank", "Reshape", "Shape", "ShapeN", "Size", "Squeeze", "Transpose", "ZerosLike", "OnesLike", "BiasAdd" /*PW + Broadcast*/, "BroadcastArgs", "BroadcastGradientArgs", "OneHot", "Concat", "ConcatV2", - "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", - "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", - "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex", - "TensorStridedSliceUpdate", + "ConcatOffset", "Const", "MirrorPad", "MirrorPadGrad", "Pack", "Pad", + "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", + "SplitV", "StridedSlice", "StridedSliceGrad", + "ResourceStridedSliceAssign", "Tile", "Transpose", "InvertPermutation", + "Unpack", "DeviceIndex", "TensorStridedSliceUpdate", }}}; // clang-format on return result; @@ -1990,6 +2001,8 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "StatelessCase", "StatelessIf", "StatelessMultinomial", + "StatelessRandomGetAlg", + "StatelessRandomGetKeyCounter", "StatelessRandomGetKeyCounterAlg", "StatelessRandomNormal", "StatelessRandomNormalV2", @@ -2033,6 +2046,7 @@ 
absl::flat_hash_set GetKnownXLAAllowlistOp() { "TensorScatterUpdate", "TridiagonalSolve", "TruncatedNormal", + "Unique", "UpperBound", "UnsortedSegmentMax", "UnsortedSegmentMin", @@ -2040,11 +2054,14 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "UnsortedSegmentSum", "VarIsInitializedOp", "VariableShape", + "Where", "While", "XlaBroadcastHelper", "XlaConv", + "XlaConvV2", "XlaDequantize", "XlaDot", + "XlaDotV2", "XlaDynamicSlice", "XlaDynamicUpdateSlice", "XlaEinsum", @@ -2061,12 +2078,14 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "XlaSelfAdjointEig", "XlaSend", "XlaSetBound", + "XlaSetDynamicDimensionSize", "XlaSharding", "XlaSort", "XlaSpmdFullToShardShape", "XlaSpmdShardToFullShape", "XlaSvd", "XlaVariadicReduce", + "XlaVariadicSort", "XlaWhile", "Zeta", "_Arg", diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 6ca8fd0e34a14f..4bbc8fba3c0755 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -22,5 +22,6 @@ tf_gen_op_wrapper_py( py_library( name = "xla_ops_grad", srcs = ["xla_ops_grad.py"], + srcs_version = "PY3", deps = ["//tensorflow/python:framework_ops"], ) diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h index c652e5fe216447..3931ae6c7cc079 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.h +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -16,7 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ #define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/graph/graph.h" diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index e2a1d159336c7c..bf6dd5ab9f4951 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/compiler/xla/service/graphcycles/graphcycles.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 461a6692c84474..112287b80fb07d 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -17,12 +17,15 @@ limitations under the License. 
#include +#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h" #include "absl/base/call_once.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/utils/array_container_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -47,14 +50,12 @@ limitations under the License. #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/dump_graph.h" -#if !defined(LIBTPU_ON_GCE) -#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" -#include "tensorflow/compiler/mlir/utils/array_container_utils.h" -#endif - namespace tensorflow { constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold; +constexpr int64 XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads; +constexpr int64 + XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations; XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client, DeviceType device_type) @@ -70,6 +71,12 @@ XlaCompilationCache::~XlaCompilationCache() { "programs to complete"; } } + // Wait for all outstanding compilations to finish. + // Resetting the pointer explicitly in the top level destructor. + // Without this, the pointer would be reset when the AsyncCompilationState + // is destructed, which is dependent on the order of the members in the + // XlaCompilationCache class, which is error prone if the order changes. + async_compilation_state_.compiler_threads.reset(); // TODO(b/110813685): Think about the program ownership model. Programs are // currently owned by the compilation cache which means we must wait for // program completion in the destructor. 
There are multiple compilation caches @@ -139,6 +146,7 @@ XlaCompilationCache::BuildSignature( for (const XlaCompiler::Argument& arg : args) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: + case XlaCompiler::Argument::kConstantResource: signature.arg_values.push_back(arg.constant_value); break; case XlaCompiler::Argument::kParameter: @@ -167,13 +175,17 @@ Status XlaCompilationCache::BuildExecutable( argument_layouts[i] = &result.xla_input_shapes[i]; } xla::ExecutableBuildOptions build_options; + if (result.collective_reduce_info) { + build_options.set_num_replicas(result.collective_reduce_info->group_size); + } build_options.set_device_ordinal(options.device_ordinal != -1 ? options.device_ordinal : client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); - build_options.set_device_allocator(options.device_allocator); + build_options.set_device_allocator(options.device_allocator.get()); build_options.set_alias_passthrough_params(options.alias_passthrough_params); - + build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping( + options.detailed_logging); TF_ASSIGN_OR_RETURN( auto executables, client_->Compile(*result.computation, argument_layouts, build_options)); @@ -184,21 +196,22 @@ Status XlaCompilationCache::BuildExecutable( Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, - absl::Span args, + const std::vector& args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { - absl::optional compile_threshold; - if (compile_mode == CompileMode::kLazy) { - compile_threshold = kDefaultCompilationThreshold; - } - auto compile_fn = [&](XlaCompiler* compiler, + // !!Pay attention when additional variables must be captured by this + // lambda!! compile_fn can run asynchronously after this funcion has + // exited. 
Make sure that any variable needed inside compile_fn is + // either passed as an argument, or captured by value right here. + auto compile_fn = [compile_options, function]( + XlaCompiler* compiler, + const std::vector& args, XlaCompiler::CompilationResult* result) { return compiler->CompileFunction(compile_options, function, args, result); }; - return CompileImpl(options, function, args, compile_fn, - /*compile_threshold=*/compile_threshold, + return CompileImpl(options, function, args, compile_fn, compile_mode, out_compilation_result, out_executable); } @@ -261,7 +274,7 @@ static xla::StatusOr> CreateGraph( Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, - absl::Span args, OpKernelContext* ctx, + const std::vector& args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { @@ -274,6 +287,7 @@ Status XlaCompilationCache::CompileSingleOp( // and causes false uniqueness between nodes. name.mutable_attr()->erase("_class"); auto compile_op = [&](XlaCompiler* compiler, + const std::vector& args, XlaCompiler::CompilationResult* result) { std::vector result_dtypes(ctx->num_outputs()); for (int i = 0, end = result_dtypes.size(); i < end; ++i) { @@ -283,23 +297,15 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - // TODO(b/155596779): Support TensorList args. - bool has_tensor_list_arg = - absl::c_any_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kTensorList; - }); const ConfigProto* config = ctx->function_library()->config_proto(); // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR. 
- bool use_mlir = config && config->experimental().enable_mlir_bridge() && - !has_tensor_list_arg && + bool use_mlir = config && + GetMlirBridgeRolloutPolicy( + *graph, /*function_library=*/nullptr, + *config, /*uses_uninitialized_resource_args=*/ + AnyUninitializedResourceArg(args)) == + MlirBridgeRolloutPolicy::kEnabledByUser && node_def.op() != "VarIsInitializedOp"; -#ifdef LIBTPU_ON_GCE - if (use_mlir) { - LOG(WARNING) << "MLIR is not supported in this environment."; - } - return compiler->CompileGraph(compile_options, node_def.name(), - std::move(graph), args, result); -#else if (!use_mlir) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); @@ -315,10 +321,8 @@ Status XlaCompilationCache::CompileSingleOp( *graph, mlir::SpanToArrayRef(args), control_rets, options.device_type.type_string(), compile_options.use_tuple_arg, *options.flib_def, debug_info, options.shape_representation_fn, result); -#endif }; - return CompileImpl(options, name, args, compile_op, - /*compile_threshold=*/absl::nullopt, + return CompileImpl(options, name, args, compile_op, CompileMode::kStrict, out_compilation_result, out_executable); } @@ -336,12 +340,113 @@ void LogOnceXlaCompiledFirstCluster() { } } // namespace +Status XlaCompilationCache::CompileStrict( + Entry* entry, const XlaCompiler::Options& options, + const std::vector& args, const string& function_name, + const std::function& args, + XlaCompiler::CompilationResult*)>& compile_fn) { + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); + + XlaCompiler compiler(options); + entry->compile_state = CompileState::kCompiled; + + entry->compilation_status = + compile_fn(&compiler, args, &entry->compilation_result); + TF_RETURN_IF_ERROR(entry->compilation_status); + TF_RET_CHECK(entry->executable.get() == nullptr); + entry->compilation_status = + BuildExecutable(options, entry->compilation_result, &entry->executable); + + const uint64 
compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + metrics::UpdateXlaCompilationTime(compile_time_us); + { + mutex_lock lock(cluster_compile_stats_mu_); + auto it = cluster_compile_stats_.find(function_name); + const uint64 compile_time_s = compile_time_us / 1.0e6; + it->second.compile_count++; + it->second.cumulative_compile_time_us += compile_time_us; + + LogOnceXlaCompiledFirstCluster(); + VLOG(1) << "compiled " << function_name << " " << it->second.compile_count + << " times, compile time: " << compile_time_us + << " us, cumulative: " << it->second.cumulative_compile_time_us + << " us (" + << tensorflow::strings::HumanReadableElapsedTime(compile_time_s) + << " / " + << tensorflow::strings::HumanReadableElapsedTime( + it->second.cumulative_compile_time_us / 1.0e6) + << ")"; + + XlaJitCompilationActivity jit_compilation_activity; + jit_compilation_activity.set_cluster_name(function_name); + jit_compilation_activity.set_compile_count(it->second.compile_count); + jit_compilation_activity.set_compile_time_us(compile_time_us); + jit_compilation_activity.set_cumulative_compile_time_us( + it->second.cumulative_compile_time_us); + TF_RETURN_IF_ERROR( + BroadcastXlaActivity(std::move(jit_compilation_activity))); + } + + return Status::OK(); +} + +Status XlaCompilationCache::CompileAsynchronous( + Entry* entry, const XlaCompiler::Options& options, + const std::vector& args, const string& function_name, + const std::function& args, + XlaCompiler::CompilationResult*)>& compile_fn) { + // Explicitly capture all required data by value for async compilation. + entry->compile_state = CompileState::kCompiling; + { + mutex_lock lock(async_compilation_state_.async_compilation_state_mu); + async_compilation_state_.num_ongoing_compilations++; + } + // Don't move the above code into the thread function as it synchronously + // updates the async compilation state! 
+ + // When the ThreadPool for the compilation cache is destroyed, it waits for + // compilations to have finished. This means that both 'entry' and 'this' will + // be alive for the duration of the compilation. + // !!Pay attention when additional variables must be captured by this lambda!! + // All values are captured by value. Make sure that all pointer values (like + // entry) do not get freed until the lambda has finished,\. + async_compilation_state_.compiler_threads->Schedule([=] { + Entry local_entry; + VLOG(2) << "Starting asynchronous compilation of cluster " << function_name + << '.'; + // We don't need to lock local_entry.mu, but do it anyway to satisfy + // thread safety analysis. + mutex_lock entry_lock(local_entry.mu); + (void)CompileStrict(&local_entry, options, args, function_name, compile_fn); + + VLOG(2) << "Finished asynchronous compililation of cluster " + << function_name << '.'; + { + mutex_lock lock(async_compilation_state_.async_compilation_state_mu); + async_compilation_state_.num_ongoing_compilations--; + } + { // Populate original entry with compilation result. 
+ mutex_lock entry_lock(entry->mu); + entry->compilation_result = local_entry.compilation_result; + entry->compile_state = local_entry.compile_state; + entry->compilation_status = local_entry.compilation_status; + entry->executable = std::move(local_entry.executable); + } + }); + return Status::OK(); +} + Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, - absl::Span args, + const std::vector& args, const std::function& args, XlaCompiler::CompilationResult*)>& compile_fn, - absl::optional compile_threshold, + CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable) { if (FailOnXlaCompilation()) { @@ -357,9 +462,20 @@ Status XlaCompilationCache::CompileImpl( VLOG(3) << i << ": " << args[i].HumanString(); } } + absl::optional compile_threshold; + if (compile_mode == CompileMode::kLazy) { + compile_threshold = kDefaultCompilationThreshold; + } else if (compile_mode == CompileMode::kAsync) { + compile_threshold = 0; // for now, always compile right away. + } TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args)); - VLOG(2) << "Signature: " << signature.HumanString(); + + string human_signature; + if (VLOG_IS_ON(2)) { + human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name(); + VLOG(2) << "Signature: " << human_signature; + } // The outer lock protects the existence of the cache entry. It does not // protect the contents of the cache entry. @@ -411,14 +527,18 @@ Status XlaCompilationCache::CompileImpl( // cache eviction. 
mutex_lock entry_lock(entry->mu); int64 current_request_count = ++entry->request_count; - VLOG(2) << "Compilation cache entry hit: " << entry->compiled - << " signature: " << signature.HumanString() << " with request count " + VLOG(2) << "Compilation cache entry hit: " + << static_cast(entry->compile_state) + << " signature: " << human_signature << " with request count " << current_request_count << " and compile threshold " << compile_threshold.value_or(0); - if (!entry->compiled) { + // TODO(sanjoy): Refactor this code into helper functions. + bool return_null = false; + CompileState state = entry->compile_state; + if (state == CompileState::kUncompiled) { XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable"); const bool should_compile = [&] { - if (!compile_threshold.has_value()) { + if (compile_mode == CompileMode::kStrict) { // Lazy compilation is disabled. return true; } @@ -427,7 +547,7 @@ Status XlaCompilationCache::CompileImpl( BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION, function.name()) .IgnoreError(); - VLOG(3) << "Not compiling cluster " << function.name() + VLOG(2) << "Not compiling cluster " << function.name() << " because it is megamorphic."; return false; } @@ -436,10 +556,21 @@ Status XlaCompilationCache::CompileImpl( return true; } + if (compile_mode == CompileMode::kAsync) { + // Asynchronous compilation is enabled. 
+ mutex_lock lock(async_compilation_state_.async_compilation_state_mu); + if (async_compilation_state_.num_ongoing_compilations >= + async_compilation_state_.kMaxNumOngoingCompilations) { + VLOG(2) << "Not asynchronously compiling cluster " << function.name() + << " because of too many ongoing compilations."; + return false; + } + } + bool reached_compile_threshold = current_request_count >= *compile_threshold; if (!reached_compile_threshold) { - VLOG(3) + VLOG(2) << "Not compiling cluster " << function.name() << " because it has not reached compile threshold; threshold is " << *compile_threshold << " execution count " @@ -449,62 +580,34 @@ Status XlaCompilationCache::CompileImpl( }(); if (!should_compile) { - VLOG(2) << "Not compiling for signature: " << signature.HumanString(); - *out_compilation_result = nullptr; - *out_executable = nullptr; - return Status::OK(); - } - - tensorflow::Env* env = tensorflow::Env::Default(); - const uint64 compile_start_us = env->NowMicros(); - // Do the actual JIT compilation without holding the lock (it can take - // a long time.) 
- - XlaCompiler compiler(options); - entry->compiled = true; - - entry->compilation_status = - compile_fn(&compiler, &entry->compilation_result); - TF_RETURN_IF_ERROR(entry->compilation_status); - CHECK_EQ(entry->executable.get(), nullptr); - entry->compilation_status = - BuildExecutable(options, entry->compilation_result, &entry->executable); - - const uint64 compile_end_us = env->NowMicros(); - const uint64 compile_time_us = compile_end_us - compile_start_us; - metrics::UpdateXlaCompilationTime(compile_time_us); - { - mutex_lock lock(cluster_compile_stats_mu_); - auto it = cluster_compile_stats_.find(function.name()); - it->second.compile_count++; - it->second.cumulative_compile_time_us += compile_time_us; - LogOnceXlaCompiledFirstCluster(); - VLOG(1) << "compiled " << function.name() << " " - << it->second.compile_count - << " times, compile time: " << compile_time_us - << " us, cumulative: " << it->second.cumulative_compile_time_us - << " us (" - << tensorflow::strings::HumanReadableElapsedTime(compile_time_us / - 1.0e6) - << " / " - << tensorflow::strings::HumanReadableElapsedTime( - it->second.cumulative_compile_time_us / 1.0e6) - << ")"; - - XlaJitCompilationActivity jit_compilation_activity; - jit_compilation_activity.set_cluster_name(function.name()); - jit_compilation_activity.set_compile_count(it->second.compile_count); - jit_compilation_activity.set_compile_time_us(compile_time_us); - jit_compilation_activity.set_cumulative_compile_time_us( - it->second.cumulative_compile_time_us); - + VLOG(2) << "Not compiling for signature: " << human_signature; + return_null = true; + } else if (compile_mode == CompileMode::kAsync) { + VLOG(2) << "Queueing asynchronous compilation for signature: " + << human_signature; + TF_RETURN_IF_ERROR(CompileAsynchronous(entry, options, args, + function.name(), compile_fn)); + return_null = true; + } else { + VLOG(2) << "Instantly compiling for signature: " << human_signature; TF_RETURN_IF_ERROR( - 
BroadcastXlaActivity(std::move(jit_compilation_activity))); + CompileStrict(entry, options, args, function.name(), compile_fn)); } + } else if (state == CompileState::kCompiling) { + VLOG(2) << "Ongoing asynchronous compilation for signature: " + << human_signature; + return_null = true; + } else if (state == CompileState::kCompiled) { + VLOG(2) << "Already Compiled for signature: " << human_signature; + } + if (return_null) { + *out_compilation_result = nullptr; + *out_executable = nullptr; + } else { + TF_RETURN_IF_ERROR(entry->compilation_status); + *out_compilation_result = &entry->compilation_result; + *out_executable = entry->executable.get(); } - TF_RETURN_IF_ERROR(entry->compilation_status); - *out_compilation_result = &entry->compilation_result; - *out_executable = entry->executable.get(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index cd58cf31988f9e..c84bc6ddebf982 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -50,6 +50,13 @@ class XlaCompilationCache : public ResourceBase { enum class CompileMode { kLazy, kStrict, + kAsync, + }; + + enum class CompileState { + kUncompiled, + kCompiling, + kCompiled, }; // Compiles a function into a XlaCompiler::CompilationResult that can be used @@ -62,7 +69,9 @@ class XlaCompilationCache : public ResourceBase { // heuristics, the compilation cache may decide not to compile the cluster at // this time. In this case it returns null into both `out_compilation_result` // and `out_executable`. If `compile_mode` is `kStrict` then the compilation - // cache always attempts the compilation on a cache miss. + // cache always attempts the compilation on a cache miss. If compilation mode + // is 'kAsync' compilation of the cluster happens in the background while the + // fallback path executes. 
// // The result of compilation is written to `*out_compilation_result`, which // must be non-null. If `out_executable` is non-null, also builds an @@ -71,7 +80,7 @@ class XlaCompilationCache : public ResourceBase { // non-constant outputs. Status Compile(const XlaCompiler::Options& options, const NameAttrList& function, - absl::Span args, + const std::vector& args, const XlaCompiler::CompileOptions& compile_options, CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, @@ -83,7 +92,7 @@ class XlaCompilationCache : public ResourceBase { // XlaCompiler, if possible. Status CompileSingleOp( const XlaCompiler::Options& options, - absl::Span args, OpKernelContext* ctx, + const std::vector& args, OpKernelContext* ctx, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -126,10 +135,11 @@ class XlaCompilationCache : public ResourceBase { // Common implementation of Compile and CompileSingleOp. Status CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, - absl::Span args, + const std::vector& args, const std::function& args, XlaCompiler::CompilationResult*)>& compile_fn, - absl::optional compile_threshold, + CompileMode compile_mode, const XlaCompiler::CompilationResult** out_compilation_result, xla::LocalExecutable** out_executable); @@ -146,8 +156,8 @@ class XlaCompilationCache : public ResourceBase { struct Entry { mutex mu; - // Have we tried compiling this entry? - bool compiled = false; + // The current compilation state for this entry. + CompileState compile_state = CompileState::kUncompiled; // The number of times a compilation with this signature has been requested. 
int64 request_count = 0; @@ -163,6 +173,22 @@ class XlaCompilationCache : public ResourceBase { std::unique_ptr executable TF_GUARDED_BY(mu); }; + Status CompileStrict( + Entry* entry, const XlaCompiler::Options& options, + const std::vector& args, + const string& function_name, + const std::function& args, + XlaCompiler::CompilationResult*)>& compile_fn) + TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu); + Status CompileAsynchronous( + Entry* entry, const XlaCompiler::Options& options, + const std::vector& args, + const string& function_name, + const std::function& args, + XlaCompiler::CompilationResult*)>& compile_fn); + mutex compile_cache_mu_; absl::flat_hash_map, Signature::Hash> cache_ TF_GUARDED_BY(compile_cache_mu_); @@ -189,6 +215,30 @@ class XlaCompilationCache : public ResourceBase { absl::flat_hash_map cluster_compile_stats_ TF_GUARDED_BY(cluster_compile_stats_mu_); + struct AsyncCompilationState { + mutex async_compilation_state_mu; + + // Number of threads for asynchronous compilations. + static constexpr int64 kNumCompilerThreads = 10; + + // Maximum number of ongoing compilations. + static constexpr int64 kMaxNumOngoingCompilations = kNumCompilerThreads; + + // Number of ongoing compilations. + int64 num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) = + 0; + + // Pool of threads for asynchronous compilations. + std::unique_ptr compiler_threads; + + AsyncCompilationState() { + compiler_threads = absl::make_unique( + tensorflow::Env::Default(), "async_compiler_threads", + kNumCompilerThreads); + } + + } async_compilation_state_; + // The number of times a lazy compilation must be requested for a specific // signature before we attempt to compile it. 
static constexpr int64 kDefaultCompilationThreshold = 2; diff --git a/tensorflow/compiler/jit/xla_compilation_cache_test.cc b/tensorflow/compiler/jit/xla_compilation_cache_test.cc index 5578925b7901a6..e40d6221324bcf 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache_test.cc @@ -78,7 +78,9 @@ TEST(XlaCompilationCacheTest, TestDisabledXlaCompilation) { absl::StrContains(status.error_message(), "XLA compilation disabled")); } -static void BM_BuildSignature(int iters, int n_args) { +void BM_BuildSignature(::testing::benchmark::State& state) { + const int n_args = state.range(0); + NameAttrList fn; fn.set_name("afunction"); for (int i = 0; i < n_args; i++) { @@ -93,7 +95,7 @@ static void BM_BuildSignature(int iters, int n_args) { args[i].constant_value = Tensor(DT_INT32, {4, 0}); } - while (--iters > 0) { + for (auto i : state) { xla::StatusOr s = XlaCompilationCache::BuildSignature(fn, args); CHECK(s.ok()); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index d092508eccf811..f1df174612f4d5 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -48,11 +48,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const ResourceVarsSnapshot& variable_args) { xla::LocalClient* client = static_cast(cache->client()); - absl::optional tf_allocator_adapter; - se::DeviceMemoryAllocator* allocator = GetAllocator( - &tf_allocator_adapter, ctx->device(), - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, - platform_info_); + se::Stream* stream = + ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; + std::shared_ptr allocator_ptr = + GetAllocator(ctx->device(), stream, platform_info_); + se::DeviceMemoryAllocator* allocator = allocator_ptr.get(); XlaComputationLaunchContext launch_context( client, allocator, client->default_device_ordinal(), /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr, @@ -74,9 +74,6 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, input_output_alias); TF_RETURN_IF_ERROR(execution_inputs.status()); - se::Stream* stream = - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - VLOG(2) << "Executing computation: " << name(); xla::ExecutableRunOptions run_options; run_options.set_stream(stream); @@ -126,13 +123,12 @@ Status XlaCompileOnDemandOp::Compile( write_into_cache); })); - absl::optional tf_allocator_adapter; XlaCompiler::Options options = GenerateCompilerOptions( **cache, *ctx->function_library(), ctx->device(), ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr, - platform_info_, - /*has_ref_vars=*/true, &tf_allocator_adapter); - + platform_info_, /*has_ref_vars=*/true); + // No detailed logging from on demand op. 
+ options.detailed_logging = false; XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; // Optimization: where possible, have the computation return a naked array @@ -152,7 +148,8 @@ Status XlaCompileOnDemandOp::Compile( ctx, variables_indices, variable_infos, variable_args)); args = XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_input_indices, inputs, variable_infos); + constant_input_indices, inputs, variable_infos, + static_cast(ctx->device())); TF_RETURN_IF_ERROR(args.status()); } diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index dd1ddb616f59ad..c4edd86f015c03 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -38,7 +38,7 @@ class XlaCpuDeviceFactory : public DeviceFactory { Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 089d22dca03537..f0e236de511fde 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -398,7 +398,7 @@ static void ShowXlaDeviceDeprecationWarning( absl::call_once(once, [] { LOG(INFO) << "XLA_GPU and XLA_CPU devices are deprecated and will be " "removed in subsequent releases. 
Instead, use either " - "@tf.function(experimental_compile=True) for must-compile " + "@tf.function(jit_compile=True) for must-compile " "semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 " "for auto-clustering best-effort compilation."; }); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 17e4226405a271..d811089d3c6bbf 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -21,9 +21,11 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/kernels/constant_op.h" +#include "tensorflow/core/kernels/data/finalize_dataset_op.h" #include "tensorflow/core/kernels/data/generator_dataset_op.h" #include "tensorflow/core/kernels/data/iterator_ops.h" #include "tensorflow/core/kernels/data/optional_ops.h" +#include "tensorflow/core/kernels/data/options_dataset_op.h" #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" #include "tensorflow/core/kernels/fifo_queue.h" #include "tensorflow/core/kernels/function_ops.h" @@ -117,6 +119,18 @@ class XlaAssignVariableOp : public OpKernel { .TypeConstraint("out_type") \ .TypeConstraint("T", TYPES), \ ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("VariableShape") \ + .Device(DEVICE) \ + .TypeConstraint("out_type") \ + .HostMemory("output") \ + .HostMemory("input"), \ + VariableShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("VariableShape") \ + .Device(DEVICE) \ + .TypeConstraint("out_type") \ + .HostMemory("output") \ + .HostMemory("input"), \ + VariableShapeOp); \ REGISTER_KERNEL_BUILDER(Name("Size") \ .Device(DEVICE) \ .HostMemory("output") \ @@ -172,6 +186,16 @@ class XlaAssignVariableOp : public OpKernel { .HostMemory("input_dataset") \ .HostMemory("handle"), \ data::PrefetchDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionsDataset") \ + .Device(DEVICE) \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + 
data::OptionsDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("FinalizeDataset") \ + .Device(DEVICE) \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + data::FinalizeDatasetOp); \ \ REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ data::IteratorHandleOp); \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 99ba565881940e..d43c98f8c79bee 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -43,15 +43,15 @@ static xla::StatusOr>> ParseVisibleDeviceList( } const std::vector visible_devices = absl::StrSplit(visible_device_list, ','); - for (const string& platform_gpu_id_str : visible_devices) { - int32 platform_gpu_id; - if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) { + for (const string& platform_device_id_str : visible_devices) { + int32 platform_device_id; + if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) { return errors::InvalidArgument( "Could not parse entry in 'visible_device_list': '", - platform_gpu_id_str, + platform_device_id_str, "'. visible_device_list = ", visible_device_list); } - gpu_ids.insert(platform_gpu_id); + gpu_ids.insert(platform_device_id); } return {{gpu_ids}}; } @@ -96,7 +96,7 @@ Status XlaGpuDeviceFactory::CreateDevices( std::vector>* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices) { - LOG(INFO) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; + VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set"; return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index b90f8b7b99060a..52991c5312b962 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -31,44 +32,6 @@ limitations under the License. namespace tensorflow { -// Returns true iff 'ndef' is a call to a function that is compilable. A -// function is compilable iff every operator in the function body is -// compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not -// null, we will populate 'uncompilable_node_info' with uncompilable node info. -static bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, - RecursiveCompilabilityChecker::UncompilableNodesMap* - uncompilable_node_info) { - Device* device = flr->device(); - const XlaOpRegistry::DeviceRegistration* registration; - CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), - ®istration)); - - // We can always *compile* resource operations, stateful RNGs and dummy ops, - // even if we are sometimes unable to auto-cluster them. - RecursiveCompilabilityChecker::OperationFilter op_filter; - op_filter.allow_resource_ops_in_called_functions = true; - op_filter.allow_stack_ops = true; - op_filter.allow_tensor_array_ops = true; - op_filter.allow_stateful_rng_ops = true; - op_filter.allow_control_trigger = true; - op_filter.allow_eliding_assert_and_checknumerics_ops = true; - op_filter.allow_ops_producing_or_consuming_variant = true; - op_filter.allow_slow_ops = true; - op_filter.allow_inaccurate_ops = true; - - RecursiveCompilabilityChecker checker{ - op_filter, DeviceType{registration->compilation_device_name}}; - if (!uncompilable_node_info) { - // We do not need uncompilable node info. Just return the result. 
- return checker.IsCompilableCall(ndef, flr); - } - - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = - checker.FindUncompilableNodes(ndef, flr); - uncompilable_node_info->swap(uncompilable_node_result); - return uncompilable_node_info->empty(); -} - bool XlaKernelCreator::CanCreateKernel( const FunctionLibraryRuntime& flr, const std::shared_ptr& props) const { @@ -88,37 +51,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, // Make sure that kernels have been registered on the JIT device. XlaOpRegistry::RegisterCompilationKernels(); - // Only check for compilability if the MLIR bridge is not enabled. - if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge != - ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED) { - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; - if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { - std::vector - uncompilable_node_info; - for (const auto& it : uncompilable_nodes_map) { - for (const auto& info : it.second.second) { - uncompilable_node_info.emplace_back(info); - } - } - string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); - absl::StrAppend(&message, "Uncompilable nodes:"); - for (const auto& node_info : uncompilable_node_info) { - string node_message = absl::StrCat("\n", node_info.name, ": ", - node_info.uncompilable_reason, "\n", - "\tStacktrace:\n"); - for (const auto& stack_frame : node_info.stack_trace) { - absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", - stack_frame.name, stack_frame.function_name); - } - absl::StrAppend(&message, node_message); - } - VLOG(1) << message; - return errors::InvalidArgument(message); - } - } - // Get function body, constant args, and resource args. 
NameAttrList function; TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index a0e60b1eafea15..ffec1d1ce31416 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -44,6 +44,17 @@ namespace { using xla::ScopedShapedBuffer; using xla::ShapedBuffer; +// Fetch the platform Id from device. +se::Platform::Id XlaPlatformInfoFromDevice(DeviceBase* device_base) { + auto device = static_cast(device_base); + se::Platform::Id platform_id = nullptr; + if (device->device_type() == DEVICE_CPU) { + platform_id = se::host::kHostPlatformId; + } + + return platform_id; +} + } // anonymous namespace VariableInfo::VariableInfo(int index, absl::string_view name, Var* var) @@ -89,9 +100,25 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, Var* variable = nullptr; ResourceHandle handle = inputs[var_idx]->flat()(0); if (handle.device() != dev->attributes().name()) { - return errors::InvalidArgument( - "Trying to access resource ", handle.name(), " located in device ", - handle.device(), " from device ", dev->attributes().name()); + std::string definition_location = [&]() -> std::string { + if (handle.definition_stack_trace()) { + std::vector stack_frames = + handle.definition_stack_trace()->ToStackFrames( + {}, IsInternalFrameForFilename, + /*reverse_traversal=*/true, + /*limit=*/1); + if (!stack_frames.empty()) { + const StackFrame& last_frame = stack_frames[0]; + return absl::StrCat(" (defined @ ", last_frame.file_name, ":", + last_frame.line_number, ")"); + } + } + return ""; + }(); + return errors::InvalidArgument("Trying to access resource ", + handle.name(), definition_location, + " located in device ", handle.device(), + " from device ", dev->attributes().name()); } TF_RETURN_IF_ERROR(rm->LookupOrCreate( handle.container(), handle.name(), &variable, [](Var** ptr) { @@ -187,14 +214,18 @@ 
XlaComputationLaunchContext::XlaComputationLaunchContext( // Fills in `execution_input` with `buffer` for `index`. static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input, xla::ShapeIndex index, - se::DeviceMemoryBase& buffer, + se::DeviceMemoryBase buffer, bool donate_buffer, int device_ordinal, se::DeviceMemoryAllocator* allocator) { xla::MaybeOwningDeviceMemory* in_buffer = execution_input.MutableBuffer(index); if (donate_buffer) { + // Here we pass ownership of the buffer to execution_input without releasing + // ownership from the caller of PopulateExecutionInputBuffer. If execution + // succeeds, we'll take back that duplicate ownership in + // GetOrCreateTensorForOutput. If execution fails, the ExecutionInput will + // release that duplicate ownership automatically. *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator); - buffer = se::DeviceMemoryBase(); } else { *in_buffer = buffer; } @@ -281,18 +312,21 @@ static Tensor MakeTensor(DataType dtype, const TensorShape& shape, return t; } -// Get aliased tensor, or make a new one for the corresponding output operation. -static Tensor GetOrCreateTensorForOutput( - int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix, +// Get aliased tensor from output, or make a new one for the corresponding +// output operation. Transfers ownership of the buffer from output to the +// returned tensor. 
+static xla::StatusOr GetOrCreateTensorForOutput( + xla::ScopedShapedBuffer& output, int output_num, OpKernelContext* ctx, + int missing_ctx_input_prefix, const xla::HloInputOutputAliasConfig& input_output_alias, absl::Span input_mapping, const std::map& resource_vars_snapshots, DataType output_dtype, const TensorShape& output_shape, - se::DeviceMemoryBase output_buffer, Allocator* output_allocator) { + Allocator* output_allocator, bool allocate_xla_tensors, se::Stream* stream, + bool use_multiple_streams, std::shared_ptr definition_event) { xla::ShapeIndex output_index = input_output_alias.shape().IsTuple() ? xla::ShapeIndex({output_num}) : xla::ShapeIndex({}); - CHECK(input_output_alias.shape().IsTuple() || output_num == 0); if (absl::optional alias = input_output_alias.GetAliasedParameter(output_index)) { @@ -303,24 +337,39 @@ static Tensor GetOrCreateTensorForOutput( ctx->input(tf_param).dtype() != DT_RESOURCE ? ctx->input(tf_param) : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param); - if (output_buffer.opaque() == input_tensor.data()) { + se::DeviceMemoryBase input_buffer = + XlaTensor::DeviceMemoryFromTensor(input_tensor); + se::DeviceMemoryBase output_buffer = output.buffer({output_num}); + if (input_buffer.opaque() == output_buffer.opaque()) { + // In the case of a donated buffer, both input_tensor and output think + // they have ownership of the buffer (see comment in + // PopulateExecutionInputBuffer). Release ownership from output to avoid + // double free. 
+ output.set_buffer(se::OwningDeviceMemory(), {output_num}); return input_tensor; } } - return MakeTensor(output_dtype, output_shape, output_buffer, - output_allocator); -} -static void PopulateXlaTensor(Tensor* output_tensor, - xla::ScopedShapedBuffer* output, int output_num, - se::Stream* stream, bool use_multiple_streams, - std::shared_ptr definition_event) { - XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - CHECK(xla_tensor); - xla_tensor->set_shaped_buffer(output->TakeSubTree({output_num})); - if (use_multiple_streams) { - xla_tensor->ResetDefinitionEvent(definition_event, stream); + if (allocate_xla_tensors) { + Tensor output_tensor; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(output_dtype, output_shape, &output_tensor)); + if (output_tensor.TotalBytes() > 0) { + XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); + TF_RET_CHECK(xla_tensor); + xla_tensor->set_shaped_buffer(output.TakeSubTree({output_num})); + if (use_multiple_streams) { + xla_tensor->ResetDefinitionEvent(definition_event, stream); + } + } + return output_tensor; } + + se::DeviceMemoryBase output_buffer = output.buffer({output_num}); + Tensor output_tensor = + MakeTensor(output_dtype, output_shape, output_buffer, output_allocator); + output.set_buffer(se::OwningDeviceMemory(), {output_num}); + return output_tensor; } // Sets output `output_num` for `ctx` provided it is known at a compile time. 
@@ -426,7 +475,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( ShapedBuffer buffer( xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}), xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}), - output.platform(), output.device_ordinal()); + output.device_ordinal()); buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(), /*source_base_index=*/{}, /*target_base_index=*/{0}); @@ -445,19 +494,26 @@ Status XlaComputationLaunchContext::PopulateOutputs( std::vector output_tensor_shapes; output_tensor_shapes.reserve(ctx->num_outputs()); if (output.on_host_shape().is_dynamic()) { - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - xla::TransferManager::GetForPlatform(stream->parent()->platform())); + const se::Platform* platform = nullptr; + if (stream != nullptr) { + platform = stream->parent()->platform(); + } else { + // Stream is not set for the host platform. + TF_ASSIGN_OR_RETURN(platform, + se::MultiPlatformManager::PlatformWithId( + XlaPlatformInfoFromDevice(ctx->device()))); + } + TF_ASSIGN_OR_RETURN(auto transfer_manager, + xla::TransferManager::GetForPlatform(platform)); - xla::Shape output_host_shape = output.on_host_shape(); xla::Shape output_device_shape = output.on_device_shape(); TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes( - stream, &output, &output_host_shape, &output_device_shape)); + stream, &output, &output_device_shape)); - output.set_shapes(output_host_shape, output_device_shape); + output.set_shapes(output_device_shape, output_device_shape); for (int i = 0; i < ctx->num_outputs(); ++i) { const xla::Shape& subshape = - xla::ShapeUtil::GetSubshape(output_host_shape, {i}); + xla::ShapeUtil::GetSubshape(output_device_shape, {i}); TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); output_tensor_shapes.push_back(shape); @@ -491,22 +547,15 @@ Status XlaComputationLaunchContext::PopulateOutputs( << "Invalid input for outputs " << i << ": " << input_index; ctx->set_output(i, 
ctx->input(input_index)); } else { - if (allocate_xla_tensors_) { - Tensor* output_tensor; - TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); - if (output_tensor->TotalBytes() > 0) { - PopulateXlaTensor(output_tensor, &output, output_num, stream, - use_multiple_streams_, definition_event); - } - } else { - se::DeviceMemoryBase buffer = output.buffer({output_num}); - Tensor output_tensor = GetOrCreateTensorForOutput( - output_num, ctx, missing_ctx_input_prefix, input_output_alias, - compilation_result->input_mapping, resource_vars, - ctx->expected_output_dtype(i), shape, buffer, allocator); - ctx->set_output(i, output_tensor); - } - output.set_buffer(se::OwningDeviceMemory(), {output_num}); + TF_ASSIGN_OR_RETURN( + Tensor output_tensor, + GetOrCreateTensorForOutput( + output, output_num, ctx, missing_ctx_input_prefix, + input_output_alias, compilation_result->input_mapping, + resource_vars, ctx->expected_output_dtype(i), shape, allocator, + allocate_xla_tensors_, stream, use_multiple_streams_, + definition_event)); + ctx->set_output(i, output_tensor); ++output_num; } } @@ -537,22 +586,14 @@ Status XlaComputationLaunchContext::PopulateOutputs( return errors::Internal("Mismatched type in variable write"); } - Tensor output_tensor; - if (allocate_xla_tensors_) { - TF_RETURN_IF_ERROR( - ctx->allocate_temp(write.type, write.shape, &output_tensor)); - if (write.shape.num_elements() > 0) { - PopulateXlaTensor(&output_tensor, &output, output_num, stream, - use_multiple_streams_, definition_event); - } - } else { - se::DeviceMemoryBase buffer = output.buffer({output_num}); - output_tensor = GetOrCreateTensorForOutput( - output_num, ctx, missing_ctx_input_prefix, input_output_alias, - compilation_result->input_mapping, resource_vars, write.type, - write.shape, buffer, allocator); - } - output.set_buffer(se::OwningDeviceMemory(), {output_num}); + TF_ASSIGN_OR_RETURN( + Tensor output_tensor, + GetOrCreateTensorForOutput(output, output_num, ctx, + 
missing_ctx_input_prefix, input_output_alias, + compilation_result->input_mapping, + resource_vars, write.type, write.shape, + allocator, allocate_xla_tensors_, stream, + use_multiple_streams_, definition_event)); var->is_initialized |= write.modified; *var->tensor() = output_tensor; ++output_num; @@ -564,11 +605,26 @@ xla::StatusOr> XlaComputationLaunchContext::BuildXlaCompilerArguments( absl::Span must_be_constant_idxs, absl::Span inputs, - absl::Span variable_args) { + absl::Span variable_args, Device* device) { CHECK(absl::c_is_sorted(must_be_constant_idxs)); std::vector out; out.resize(inputs.size()); + // TODO(cheshire): Avoid duplication with framework/op_kernel.h + DeviceContext* device_context = nullptr; + TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context)); + bool using_default_context = false; + auto cleanup = xla::MakeCleanup([&] { + if (device_context != nullptr && !using_default_context) { + device_context->Unref(); + } + }); + if (device_context == nullptr) { + using_default_context = true; + auto* dev_info = device->tensorflow_gpu_device_info(); + if (dev_info) device_context = dev_info->default_context; + } + absl::flat_hash_map variable_info_lookup; for (const VariableInfo& info : variable_args) { CHECK(!info.var() || info.lock_held()) @@ -581,14 +637,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( const Tensor* input = inputs[input_num]; XlaCompiler::Argument& arg = out[input_num]; - if (absl::c_binary_search(must_be_constant_idxs, input_num)) { - // Handles compile-time constants. - TF_RET_CHECK(input->dtype() != DT_RESOURCE); - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input->dtype(); - arg.shape = input->shape(); - arg.constant_value = *input; - } else if (variable_info_lookup.count(input_num)) { + if (variable_info_lookup.count(input_num)) { // Handles resource variables. 
TF_RET_CHECK(input->dtype() == DT_RESOURCE); const VariableInfo& variable = *variable_info_lookup[input_num]; @@ -609,6 +658,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( arg.type = DT_INVALID; arg.shape = TensorShape(); } + + if (absl::c_binary_search(must_be_constant_idxs, input_num)) { + TF_RET_CHECK(variable.var() && variable.var()->is_initialized); + const Tensor* value = variable.var()->tensor(); + Tensor value_on_host(value->dtype(), value->shape()); + if (!device_context) { + value_on_host = *value; + } else { + TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync( + value, "", device, &value_on_host)); + } + arg.kind = XlaCompiler::Argument::kConstantResource; + arg.constant_value = value_on_host; + } + } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input->dtype(); + arg.shape = input->shape(); + arg.constant_value = *input; } else { // Normal inputs. TF_RET_CHECK(input->dtype() != DT_RESOURCE); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index ac085a022c8e02..97b82324a7f16c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -143,7 +143,8 @@ class XlaComputationLaunchContext { static xla::StatusOr> BuildXlaCompilerArguments(absl::Span must_be_constant_idxs, absl::Span inputs, - absl::Span variable_args); + absl::Span variable_args, + Device* device); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. 
@@ -207,7 +208,20 @@ class XlaTensorBuffer : public TensorBuffer { TensorBuffer* root_buffer() override { return this; } void FillAllocationDescription(AllocationDescription* proto) const override { - proto->set_allocated_bytes(actual_size_); + proto->set_requested_bytes(static_cast(expected_size_)); + proto->set_allocator_name(allocator_->Name()); + proto->set_ptr(reinterpret_cast(data())); + if (allocator_->TracksAllocationSizes()) { + auto ab = static_cast(allocator_->AllocatedSize(data())); + proto->set_allocated_bytes(ab); + int64 id = allocator_->AllocationId(data()); + if (id > 0) { + proto->set_allocation_id(id); + } + if (RefCountIsOne()) { + proto->set_has_single_reference(true); + } + } } private: diff --git a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc index 6c6c490e032669..7c4378415a94fc 100644 --- a/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc +++ b/tensorflow/compiler/jit/xla_ops_on_regular_devices.cc @@ -29,6 +29,14 @@ namespace tensorflow { .HostMemory("feature_group_count") \ .Device(DEVICE), \ XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaConvV2") \ + .HostMemory("window_strides") \ + .HostMemory("padding") \ + .HostMemory("lhs_dilation") \ + .HostMemory("rhs_dilation") \ + .HostMemory("feature_group_count") \ + .Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER( \ Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \ XlaCompileOnDemandOp); \ @@ -38,6 +46,8 @@ namespace tensorflow { XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \ XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER(Name("XlaDotV2").Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER( \ Name("XlaDynamicSlice").HostMemory("size_indices").Device(DEVICE), \ XlaCompileOnDemandOp); \ @@ -74,6 +84,9 @@ namespace tensorflow { XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), 
\ XlaCompileOnDemandOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("XlaVariadicSort").HostMemory("dimension").Device(DEVICE), \ + XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \ XlaCompileOnDemandOp); \ REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \ diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index b38bf9282b1023..cfd4de0f32f9e6 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -79,7 +79,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { auto device = static_cast(device_base); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; - se::DeviceMemoryAllocator* custom_allocator = nullptr; + std::shared_ptr custom_allocator; if (device->device_type() == DEVICE_CPU) { platform_id = se::host::kHostPlatformId; @@ -101,37 +101,35 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device_base) { // allocator to allocate real buffers. platform_id = xla_device_metadata->platform()->id(); custom_allocator = - xla_device_metadata->client()->backend().memory_allocator(); + xla_device_metadata->client()->backend().shared_memory_allocator(); } return XlaPlatformInfo(DeviceType(device->device_type()), platform_id, xla_device_metadata, custom_allocator); } -se::DeviceMemoryAllocator* GetAllocator( - absl::optional* tf_allocator_adapter, +std::shared_ptr GetAllocator( DeviceBase* device, se::Stream* stream, const XlaPlatformInfo& platform_info) { if (platform_info.custom_allocator()) { return platform_info.custom_allocator(); } + auto* alloc = device->GetAllocator({}); if (!stream) { // Stream is not set for the host platform. 
se::Platform* platform = se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) .ValueOrDie(); - tf_allocator_adapter->emplace(device->GetAllocator({}), platform); - return &tf_allocator_adapter->value(); + return std::make_shared(alloc, platform); } - tf_allocator_adapter->emplace(device->GetAllocator({}), stream); - return &tf_allocator_adapter->value(); + return std::make_shared(alloc, stream); } XlaCompiler::Options GenerateCompilerOptions( const XlaCompilationCache& cache, const FunctionLibraryRuntime& function_library, DeviceBase* device, - se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, - absl::optional* tf_allocator_adapter) { + se::Stream* stream, const XlaPlatformInfo& platform_info, + bool has_ref_vars) { XlaCompiler::Options options; options.client = static_cast(cache.client()); if (stream != nullptr) { @@ -142,8 +140,7 @@ XlaCompiler::Options GenerateCompilerOptions( options.graph_def_version = function_library.graph_def_version(); options.allow_cpu_custom_calls = (platform_info.platform_id() == se::host::kHostPlatformId); - options.device_allocator = - GetAllocator(tf_allocator_adapter, device, stream, platform_info); + options.device_allocator = GetAllocator(device, stream, platform_info); if (platform_info.xla_device_metadata()) { options.shape_representation_fn = platform_info.xla_device_metadata()->shape_representation_fn(); diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index bfb438cc398281..177503dc6dcd11 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -29,10 +29,10 @@ class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} XlaPlatformInfo(XlaPlatformInfo&&) = default; - explicit XlaPlatformInfo(const DeviceType device_type, - se::Platform::Id platform_id, - const XlaDevice::Metadata* xla_device_metadata, - se::DeviceMemoryAllocator* device_allocator) + explicit 
XlaPlatformInfo( + const DeviceType device_type, se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + std::shared_ptr device_allocator) : device_type_(device_type), platform_id_(platform_id), xla_device_metadata_(xla_device_metadata), @@ -45,7 +45,7 @@ class XlaPlatformInfo { } // Non-null only when run on an XLA device. - se::DeviceMemoryAllocator* custom_allocator() const { + std::shared_ptr custom_allocator() const { return device_allocator_; } @@ -74,7 +74,9 @@ class XlaPlatformInfo { // If the op associated with this XlaPlatformInfo is placed on an XLA device // then device_allocator_ is the xla::Backend's memory allocator. If the op // is placed on a regular CPU or GPU device then device_allocator_ is null. - se::DeviceMemoryAllocator* device_allocator_; + // The allocator is of unknown provenance; keep it in a shared pointer to + // set an artificial refcount of one. + std::shared_ptr device_allocator_; TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); }; @@ -94,8 +96,7 @@ XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); // dummy tensors. // // `stream` parameter is nullable when running on host. 
-se::DeviceMemoryAllocator* GetAllocator( - absl::optional* tf_allocator_adapter, +std::shared_ptr GetAllocator( DeviceBase* device, se::Stream* stream, const XlaPlatformInfo& platform_info); @@ -104,8 +105,8 @@ se::DeviceMemoryAllocator* GetAllocator( XlaCompiler::Options GenerateCompilerOptions( const XlaCompilationCache& cache, const FunctionLibraryRuntime& function_library, DeviceBase* device, - se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, - absl::optional* tf_allocator_adapter); + se::Stream* stream, const XlaPlatformInfo& platform_info, + bool has_ref_vars); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc new file mode 100644 index 00000000000000..4d4b1edd23ab77 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tpu_device.cc @@ -0,0 +1,486 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_tpu_device.h" + +#include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_device_ops.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/tensor_reference.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_defs.h" +#include "tensorflow/core/tpu/tpu_node_device_util.h" +#include "tensorflow/core/tpu/virtual_device.h" +#include "tensorflow/stream_executor/tpu/c_api_conversions.h" +#include "tensorflow/stream_executor/tpu/status_helper.h" +#include "tensorflow/stream_executor/tpu/tpu_node_context.h" +#include "tensorflow/stream_executor/tpu/tpu_platform.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h" +#include "tensorflow/stream_executor/tpu/tpu_stream_interface.h" + +namespace tensorflow { +namespace { + +static bool tpu_autoclustering_flag = false; +static bool tpu_xla_device_failure_closes_chips_flag = true; +static bool tpu_use_substreams_for_cross_tpu_device_transfers_flag = true; + +// Given a tensor of `shape` and `type`, as what shape should it be stored on +// the TPU device? This function tranposes or flattens the excessively-padded +// tensors to rank 1, but leaves other tensor shapes alone. 
+xla::StatusOr TpuShapeRepresentation(const TensorShape& shape, + DataType type, + bool use_fast_memory) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR( + tensorflow::TensorShapeToXLAShape(type, shape, &xla_shape)); + ApiConverter::StackHelper se_shape(xla_shape); + ApiConverter::StackHelper tpu_shape; + StatusHelper status; + tpu::ExecutorApiFn()->XlaShapeToTpuShapeRepresentationFn( + &se_shape.value, type, use_fast_memory, &tpu_shape.value, + status.c_status); + if (!status.status().ok()) { + return status.status(); + } + return tpu_shape.AsCpp(); +} + +// Given a tensor, returns the shape of its representation on device, +// fully padded. Contents of `shape` are undefined on error. +Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return errors::InvalidArgument( + "Expected an XlaTensor when computing padded shape"); + } + + if (!xla_tensor->has_shaped_buffer()) { + return errors::InvalidArgument( + "XlaTensor is expected to have device memory allocated when " + "computing padded shape"); + } + + const xla::Shape& on_device_shape = + xla_tensor->shaped_buffer().on_device_shape(); + + StatusHelper status; + ApiConverter::StackHelper se_shape(on_device_shape); + ApiConverter::StackHelper tpu_shape; + tpu::ExecutorApiFn()->XlaShapeToTpuPaddedShapeFn( + &se_shape.value, &tpu_shape.value, status.c_status); + if (!status.ok()) { + return status.status(); + } + *shape = tpu_shape.AsCpp(); + return Status::OK(); +} + +// Check if TPU has been initialized. TPU initialization is not necessary +// for 1x1. 
+Status CheckIfTPUInitialized() { + auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(); + if (!tpu_platform->Initialized()) { + return errors::FailedPrecondition( + "The TPU system has not been initialized."); + } + return Status::OK(); +} + +// Implementation of TPU->TPU device copies that copies over the dedicated TPU +// interconnects, which is much faster than PCIe or the host network. +// TODO(b/117426293): This implementation is only called for direct interconnect +// transfers between TPU devices attached to the same host. Ideally, we would +// generalize this support to direct interconnect transfers across hosts, but +// currently the CopyTensor infrastructure seems to the network topology is +// strictly hierarchical, that is, transfers between devices on different hosts +// can only take place using the host network. +void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, + DeviceContext* dst_dev_context, Device* src, + Device* dst, AllocatorAttributes src_allocator_attrs, + AllocatorAttributes dst_allocator_attrs, + const Tensor* input, Tensor* output, + int dev_to_dev_stream_index, StatusCallback done) { + XlaDeviceContext* const src_xla_context = + static_cast(src_dev_context); + XlaDeviceContext* const dst_xla_context = + static_cast(dst_dev_context); + static const bool should_use_substream = + tpu_use_substreams_for_cross_tpu_device_transfers_flag; + + auto impl = [&]() -> Status { + if (src->name() != dst->name()) { + Status s = CheckIfTPUInitialized(); + if (!s.ok()) { + done(s); + return Status::OK(); + } + } + if (input->shape().num_elements() == 0) { + // Zero-element tensors have no backing buffers. 
+ done(Status::OK()); + return Status::OK(); + } + + se::Stream* const src_compute_stream = src_xla_context->stream(); + TF_RET_CHECK(src_compute_stream != nullptr); + TF_RET_CHECK(input->dtype() == output->dtype()) + << "input type: " << DataTypeString(input->dtype()) << " output type " + << DataTypeString(output->dtype()); + TF_RET_CHECK(input->shape() == output->shape()); + TF_RET_CHECK(DMAHelper::CanUseDMA(input)); + auto* const src_compute_stream_impl = static_cast( + src_compute_stream->implementation()); + + se::Stream* dst_compute_stream = dst_xla_context->stream(); + auto* const dst_compute_stream_impl = static_cast( + dst_compute_stream->implementation()); + + if (src_compute_stream_impl->IsSameSharedMemoryLocation( + dst_compute_stream_impl)) { + // Surprisingly, this path does get triggered in practice. + *output = *input; + done(Status::OK()); + return Status::OK(); + } + + // To avoid stream exhaustion, we pick a substream from a pool if enabled. + se::Stream* const device_to_device_master_stream = + should_use_substream ? dst_xla_context->device_to_device_stream(0) + : nullptr; + se::Stream* const dst_device_to_device_stream = + should_use_substream + ? 
device_to_device_master_stream->GetOrCreateSubStream() + : dst_xla_context->GetDeviceToDeviceStream(); + TF_RET_CHECK(dst_device_to_device_stream != nullptr); + auto return_substream = gtl::MakeCleanup( + [device_to_device_master_stream, dst_device_to_device_stream] { + if (device_to_device_master_stream) { + device_to_device_master_stream->ReturnSubStream( + dst_device_to_device_stream); + } + }); + + auto* const dst_device_to_device_stream_impl = + static_cast( + dst_device_to_device_stream->implementation()); + + const int dst_device_ordinal = + dst_xla_context->stream()->parent()->device_ordinal(); + + XlaTensor* const xla_input = XlaTensor::FromTensor(input); + TF_RET_CHECK(xla_input != nullptr && xla_input->has_shaped_buffer()); + XlaTensor* const xla_output = XlaTensor::FromTensor(output); + TF_RET_CHECK(xla_output != nullptr && !xla_output->has_shaped_buffer()); + TF_RET_CHECK(input->shape() == output->shape()); + + TF_ASSIGN_OR_RETURN(xla::Shape shape, + dst_xla_context->shape_representation_fn()( + input->shape(), input->dtype(), + /*use_fast_memory=*/false)); + TF_RETURN_IF_ERROR(xla_output->AllocateShapedBuffer( + input->dtype(), shape, dst_xla_context->client(), dst_device_ordinal)); + + VLOG(2) << "TpuDeviceToDeviceCopy: src: " + << src_compute_stream->parent()->device_ordinal() << ", " + << " dst: " << dst_compute_stream->parent()->device_ordinal() + << ", " + << " input buffers: " << xla_input->shaped_buffer().ToString() + << " output buffers: " << xla_output->shaped_buffer().ToString(); + + // Wait for definition event of the source tensor so the input buffers are + // available. + xla_input->WaitForDefinitionEventOnStream(dst_device_to_device_stream); + + // Wait for the destination tensor buffers to be ready, if they are not + // available for an immediate write. 
+ if (!dst_xla_context->transfer_manager()->CanShapedBufferBeAccessedNow( + dst_compute_stream->parent(), xla_output->shaped_buffer())) { + dst_device_to_device_stream->ThenWaitFor(dst_compute_stream); + // If the representation is a tuple, we also must wait for the tuple index + // buffers to be available on the destination host to device transfer + // stream. + if (xla_output->shaped_buffer().on_device_shape().IsTuple()) { + dst_xla_context->host_to_device_stream()->ThenWaitFor( + dst_compute_stream); + } + } + + for (const auto& leaf : xla_input->shaped_buffer().buffers().leaves()) { + const xla::ShapeIndex& index = leaf.first; + const se::DeviceMemoryBase& input_buffer = leaf.second; + const se::DeviceMemoryBase& output_buffer = + xla_output->shaped_buffer().buffer(index); + TF_RET_CHECK(input_buffer.size() == output_buffer.size()) + << "input: " << input_buffer.size() + << " output: " << output_buffer.size(); + TF_RETURN_IF_ERROR( + dst_device_to_device_stream_impl->EnqueueOnTpuDeviceSendRecvLocal( + input_buffer, output_buffer)); + } + + // If the on-device shape is a tuple, write new tuple index buffers. + if (xla_output->shaped_buffer().on_device_shape().IsTuple()) { + TF_RETURN_IF_ERROR( + dst_xla_context->transfer_manager()->WriteTupleIndexTablesAsync( + dst_xla_context->host_to_device_stream(), + xla_output->shaped_buffer())); + + // We need a single definition event for an XlaTensor, so make the + // device to device stream wait for the stream that wrote the tuple index + // tables on the destination device. Should this prove to be a problem, + // we can always extend XlaTensor to take a pair of definition events that + // must all be satisfied, or add an Event::Merge() API that allows us to + // build an event that is triggered when all of its dependencies are + // triggered. 
+ dst_device_to_device_stream->ThenWaitFor( + dst_xla_context->host_to_device_stream()); + } + + auto definition_event = + std::make_shared(dst_xla_context->stream()->parent()); + TF_RET_CHECK(definition_event->Init()) << "Event failed to initialize!"; + dst_device_to_device_stream->ThenRecordEvent(definition_event.get()); + xla_output->ResetDefinitionEvent(std::move(definition_event), + dst_device_to_device_stream); + + // The input must remain alive until the transfer completes, so we keep a + // reference. We also wait until the transfer completes before calling + // done(). + // The latter may be too conservative, but given the host is involved in + // waiting for the transfer to complete anyway there is probably little + // downside. If we were to add the ability for computations to wait directly + // on transfers, then we might want to rethink this property. + // Also ideally this host callback should be on source stream rather than + // destination stream, but when this function returns, the send requests + // might not be enqueued to the stream yet, we put it on destination stream. 
+ TensorReference input_reference(*input); + std::move(return_substream).release(); + dst_device_to_device_stream->ThenDoHostCallback( + [input_reference, done = std::move(done), + device_to_device_master_stream, dst_device_to_device_stream] { + if (device_to_device_master_stream) { + device_to_device_master_stream->ReturnSubStream( + dst_device_to_device_stream); + } + input_reference.Unref(); + done(Status::OK()); + }); + + return Status::OK(); + }; + Status status = impl(); + if (!status.ok()) { + done(status); + } +} + +class TpuNodeDeviceFactory : public DeviceFactory { + public: + Status ListPhysicalDevices(std::vector* devices) override; + Status CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override; +}; + +Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector* devices) { + tpu::TpuPlatformInterface* platform = + tpu::TpuPlatformInterface::GetRegisteredPlatform(); + if (platform == nullptr) { + // If we don't have a platform registered, then we have no devices. + return Status::OK(); + } + + int device_count = platform->VisibleDeviceCount(); + + for (int i = 0; i < device_count; ++i) { + const string device_name = absl::StrCat("/physical_device:TPU:", i); + devices->push_back(device_name); + } + + return Status::OK(); +} + +Status TpuNodeDeviceFactory::CreateDevices( + const SessionOptions& session_options, const string& name_prefix, + std::vector>* devices) { + tpu::TpuPlatformInterface* platform = + tpu::TpuPlatformInterface::GetRegisteredPlatform(); + if (platform == nullptr) { + // If we don't have a platform registered, then we should not create any. + return Status::OK(); + } + + if (platform != nullptr && platform->ShouldRegisterTpuDeviceToDeviceCopy()) { + RegisterTpuDeviceToDeviceCopy(); + } + + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = DEVICE_TPU_XLA_JIT; + registration.autoclustering_policy = + tpu_autoclustering_flag + ? 
XlaOpRegistry::AutoclusteringPolicy::kAlways + : XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested; + + registration.cluster_resource_variable_ops_unsafely = true; + registration.cluster_stack_ops = false; + registration.cluster_tensor_array_ops = true; + registration.cluster_stateful_rng_ops = true; + registration.cluster_control_trigger = true; + registration.elide_assert_and_checknumerics = true; + registration.cluster_variant_ops = true; + registration.cluster_slow_ops = true; + registration.cluster_inaccurate_ops = true; + XlaOpRegistry::RegisterCompilationDevice(DEVICE_TPU_NODE, registration); + + static XlaDeviceOpRegistrations* registrations = + RegisterXlaDeviceKernels(DEVICE_TPU_NODE, DEVICE_TPU_XLA_JIT); + (void)registrations; + + int device_count = platform->VisibleDeviceCount(); + VLOG(1) << "Creating " << device_count << " TPU devices"; + for (int i = 0; i < device_count; ++i) { + TF_RETURN_IF_ERROR(tpu::TpuNodeContext::Initialize(i)); + + XlaDevice::Options options; + options.platform = platform; + options.device_name_prefix = name_prefix; + options.device_name = DEVICE_TPU_NODE; + options.device_ordinal = i; + options.compilation_device_name = DEVICE_TPU_XLA_JIT; + options.use_multiple_streams = true; + options.shape_representation_fn = &TpuShapeRepresentation; + options.padded_shape_fn = &TpuPaddedShapeFn; + auto device = absl::make_unique(session_options, options); + + // The GpuDeviceInfo actually provides information not only for GPU + // devices but also for TPU. The name is a legacy from the pre-TPU + // dark ages. 
+ Status status = device->UseGpuDeviceInfo(); + if (!status.ok()) { + errors::AppendToMessage(&status, "while setting up ", DEVICE_TPU_XLA_JIT, + " device number ", i); + return status; + } + device->SetAllowsSyncOnCompletion(false); + if (tpu_xla_device_failure_closes_chips_flag) { + device->SetHandleDeviceErrorCallback(&tpu::TpuNodeContext::CloseTpuHost); + } + + devices->push_back(std::move(device)); + } + + return Status::OK(); +} + +class TpuSystemDeviceFactory : public DeviceFactory { + public: + Status ListPhysicalDevices(std::vector* devices) override; + Status CreateDevices(const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override; +}; + +Status TpuSystemDeviceFactory::ListPhysicalDevices( + std::vector* devices) { + int device_count = 0; + TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count)); + if (device_count == 0) { + VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device"; + return Status::OK(); + } + + devices->push_back("/physical_device:TPU_SYSTEM:0"); + + return Status::OK(); +} + +Status TpuSystemDeviceFactory::CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) { + int device_count = 0; + TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count)); + if (device_count == 0) { + VLOG(1) << "Host has no TPUs, not creating a TPU_SYSTEM device"; + return Status::OK(); + } + + int64 memory_limit; + TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpuMemoryLimit(&memory_limit)); + + // Creates a device that represents a TPU distributed system. + const DeviceAttributes attrs = Device::BuildDeviceAttributes( + absl::StrCat(name_prefix, "/device:", DEVICE_TPU_SYSTEM, ":", 0), + DeviceType(DEVICE_TPU_SYSTEM), Bytes(memory_limit), DeviceLocality(), + absl::StrCat("device: ", DEVICE_TPU_SYSTEM, " device")); + devices->push_back(absl::make_unique(options.env, attrs)); + VLOG(1) << "Created TPU_SYSTEM device. 
This host has " << device_count + << " TPUs"; + + return Status::OK(); +} + +} // namespace + +void RegisterTpuDeviceToDeviceCopy() { + static auto* const register_tpu_tpu_copy = new CopyTensor::Registration( + DEVICE_TPU_NODE, DEVICE_TPU_NODE, TpuDeviceToDeviceCopy); + (void)register_tpu_tpu_copy; +} + +void RegisterTpuNodeDevice( + bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips, + bool tpu_use_substreams_for_cross_tpu_device_transfers) { + tpu_autoclustering_flag = tpu_autoclustering; + tpu_xla_device_failure_closes_chips_flag = + tpu_xla_device_failure_closes_chips; + tpu_use_substreams_for_cross_tpu_device_transfers_flag = + tpu_use_substreams_for_cross_tpu_device_transfers; + + REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes); + REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes); + REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes); + REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory); +} + +void RegisterTpuSystemDevice() { + REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory); +} + +#if !defined(PLATFORM_GOOGLE) + +// We automatically register this if we are building for open source. For +// Google platforms, we initialize these devices in other places. 
+ +REGISTER_XLA_LAUNCH_KERNEL(DEVICE_TPU_NODE, XlaLocalLaunchOp, kTpuAllTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_TPU_NODE, XlaCompileOp, kTpuAllTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_TPU_NODE, XlaRunOp, kTpuAllTypes); +REGISTER_XLA_DEVICE_KERNELS(DEVICE_TPU_NODE, kTpuAllTypes); +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_NODE, TpuNodeDeviceFactory); +REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory); + +#endif // PLATFORM_GOOGLE + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_tpu_device.h b/tensorflow/compiler/jit/xla_tpu_device.h new file mode 100644 index 00000000000000..bb31c65b575509 --- /dev/null +++ b/tensorflow/compiler/jit/xla_tpu_device.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +void RegisterTpuDeviceToDeviceCopy(); + +void RegisterTpuNodeDevice( + bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips, + bool tpu_use_substreams_for_cross_tpu_device_transfers); + +void RegisterTpuSystemDevice(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_ diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 18d05bdaace668..340b5ba1efdc83 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -3,7 +3,11 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_binary", + "tf_cc_test", +) package( default_visibility = [ @@ -75,6 +79,7 @@ cc_library( "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "//tensorflow/core:lib", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", @@ -108,6 +113,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_passes", + "//tensorflow/compiler/mlir/tosa:tf_passes", + "//tensorflow/compiler/mlir/tosa:tfl_passes", ], ) @@ -126,12 +133,14 @@ cc_library( srcs = 
["mlir_graph_optimization_pass.cc"], hdrs = ["mlir_graph_optimization_pass.h"], deps = [ + "//tensorflow/compiler/mlir:mlir_bridge_rollout_policy", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:device_util", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -191,6 +200,30 @@ tf_cc_binary( ], ) +cc_library( + name = "mlir_bridge_rollout_policy", + srcs = ["mlir_bridge_rollout_policy.cc"], + hdrs = ["mlir_bridge_rollout_policy.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/jit:flags", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "mlir_graph_optimization_pass_test", + srcs = ["mlir_graph_optimization_pass_test.cc"], + deps = [ + ":mlir_graph_optimization_pass", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm-project//mlir:IR", + ], +) + filegroup( name = "litfiles", srcs = glob(["runlit*py"]), diff --git a/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md new file mode 100644 index 00000000000000..a0623e05ad6dd9 --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/includes/tf_passes.md @@ -0,0 +1,691 @@ + +### `-cluster-ops-by-policy`: Clusters ops according to specified policy. +This pass clusters ops according to the policy specified by the pass options. +Clustered ops are moved to a tf_device::clusterOp region. + +First you need to specify the 'oplist=' option. This option +specifies the names of the ops that should be clustered together. Then you need +to specify the algorithm for forming a cluster with a `mode=` option: + +1. 
`use-def` (default): cluster ops together if they form a single use def-use + chain, that is, the next op in the list uses the result of the previous op + and is the only user of that result. +2. `union-find`: cluster ops together that are connected to each other with + potentially different use def chains using union-find algorithm. + +For both algorithms the ops should be located in the same block, be assigned to +the same device and have no side effects. + +For example, running this pass with options: + "oplist=tf.Cast,tf.Add algorithm=use-def" + +```mlir +func @cluster_oplist(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = "tf.Cast"(%arg0) : (tensor) -> tensor + %1 = "SomeOp" (%arg1) : (tensor) -> tensor + %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> tensor + return %2 : tensor +} +``` + +will produce tf_device::opCluster enclosing tf.Add and tf.Neg: + +```mlir +func @cluster_oplist(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "SomeOp"(%arg1) : (tensor) -> tensor + %1 = "tf_device.cluster"() ( { + %2 = "tf.Cast"(%arg0) : (tensor) -> tensor + %3 = "tf.Add"(%2, %0) : (tensor, tensor) -> tensor + tf_device.return %3 : tensor + }) : () -> tensor + return %1 : tensor +} +``` + +Running with `union-find` algorithm allows to cluster together operations that +do not form a single use-def chain: + "oplist=tf.Add,tf.Sub algorithm=union-find" + +```mlir +func @cluster_oplist(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + %2 = "tf.Add"(%0, %1) : (tensor, tensor) -> tensor + return %2 : tensor +} +``` + +will produce tf_device::opCluster enclosing tf.Add and tf.Sub: + +```mlir +func @cluster_oplist(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.Add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %2 = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%1, %2) : (tensor, tensor) -> tensor + 
tf_device.return %3 : tensor + }) : () -> tensor + return %0 : tensor +} + +#### Options +``` +-policy-name : Adds a policy string attribute to all extracted clusters. This attribute allows to distinguish clusters formed by different policies or maybe other clustering algorithms. +-min-cluster-size : Do not form clusters smaller of the given size. +-algorithm : Clustering algorithm type: `use-def` or `union-find` +-oplist : Cluster listed ops when they form a single use def-use chain, such that each op's single user is the next op in the list. +``` +### `-prepare-tpu-computation-for-tf-export`: Prepare TPU computation to be legal for export to TensorFlow +Prepares TPU computation module attached to _TPUCompileMlir op for +TensorFlow graph export by making transformation such as replacing or +removing MLIR or XLA specific attributes that are not legal in TensorFlow +graph. +### `-tf-device-attribute-to-launch`: Wraps each TF op which has a non-empty device attribute in a tf_device.launch. +This pass wraps TF ops which have a non-empty device attribute in a tf_device.lauch with +the same device attribute. + +For example, the following: + +```mlir +func @single_op_launch() { + %a = "tf.opA"() {device = "CPU:0"} : () -> tensor + return %a +} +``` + +will be transformed into: + +```mlir +func @single_op_launch() { + %1 = tf_device.launch() ( { + %a = "tf.opA"() : () -> tensor + tf_device.return %a + }) {device = "CPU:0"} : () -> tensor + return %1 +} +``` +### `-tf-device-cluster-outlining`: Outlines regions of tf_device.cluster operations +This pass outlines the body of a `tf_device.cluster` into a function and +replaces the `tf_device.cluster` op with an equivalent `tf_device.cluster_func` +op. Implicit operands will be captured and materialized as explicit arguments to +the newly created functions and associated `tf_device.cluster_func` ops. 
+ +For example, the following: + +```mlir +func @computation(%arg0: tensor) -> tensor { + %cluster = "tf_device.cluster"() ( { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + tf_device.return %identity : tensor + }) : () -> (tensor) + return %cluster : tensor +} +``` + +will be transformed into: + +```mlir +func @computation(%arg0: tensor) -> tensor { + %cluster = "tf_device.cluster_func"(%arg0) {func = @_func} : (tensor) -> tensor + return %cluster : tensor +} + +func @_func(%arg0: tensor) -> tensor { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + return %identity : tensor +} +``` +### `-tf-device-constant-sinking`: Sinks constants implicitly captured in a tf_device.cluster region. +This pass sinks implicitly captured constants (`tf.Const` ops) used by and into +a `tf_device.cluster` region. Performing this prior to outlining will reduce the +number of arguments of the outlined function. + +For example, the following: + +```mlir +func @cluster() -> tensor { + %const = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cluster = "tf_device.cluster"() ( { + %identity = "tf.Identity"(%const) : (tensor) -> tensor + tf_device.return %identity : tensor + }) : () -> (tensor) + return %cluster : tensor +} +``` + +will be transformed into: + +```mlir +func @cluster() -> tensor { + %cluster = "tf_device.cluster"() ( { + %const = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %identity = "tf.Identity"(%const) : (tensor) -> tensor + tf_device.return %identity : tensor + }) : () -> (tensor) + return %cluster : tensor +} +``` +### `-tf-executor-graph-pruning`: Prunes unreachable ops in a tf_executor.graph +This pass removes ops from a `tf_executor.graph` that are not transitively, via +data or control dependencies, connected to the associated `tf_executor.fetch` +op. The order of ops will be preserved. 
Functions named `main` with no +`tf.entry_function` attribute will not be pruned, as such graphs/functions may +have been imported from a V1 TensorFlow graph, where feeds/fetches/targets are +not provided at certain stages of IR transformation (e.g. pre-placement). + +Option `ops-to-preserve` allows to specify ops that should not be pruned, +regardless of their reachability. + +For example, the following: + +```mlir +func @graph(%arg0: tensor, %arg1: tensor) -> tensor { + %graph = tf_executor.graph { + %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor) -> tensor + %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor) -> tensor + %unreachable_data:2 = tf_executor.island wraps "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> () + %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> () + %unreachable_control = tf_executor.island wraps "tf.NoOp"() : () -> tensor + tf_executor.fetch %reachable_data#0, %reachable_control : tensor, !tf_executor.control + } + return %graph : tensor +} +``` + +will be transformed into: + +```mlir +func @graph(%arg0: tensor, %arg1: tensor) -> tensor { + %graph = tf_executor.graph { + %transitive_reachable_data:2 = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor) -> tensor + %reachable_data:2 = tf_executor.island wraps "tf.Identity"(%transitive_reachable_data#0) : (tensor) -> tensor + %transitive_reachable_control = tf_executor.island wraps "tf.NoOp"() : () -> () + %reachable_control = tf_executor.island(%transitive_reachable_control) wraps "tf.NoOp"() : () -> () + tf_executor.fetch %reachable_data#0, %reachable_control : tensor, !tf_executor.control + } + return %graph : tensor +} +``` + +#### Options +``` +-ops-to-preserve : Comma separated list of ops that should not be pruned regardless of reachability +``` +### 
`-tf-executor-to-functional-conversion`: Lifts tf_executor.island inner ops from a tf_executor.graph +This pass converts tf_executor.graphs consisting of only tf_executor.islands and +a tf_executor.fetch into a sea of nodes consisting of TensorFlow Dialect ops by +lifting such ops out of a tf_executor.graph's tf_executor.islands. If V1 control +flow ops are present in a tf_executor.graph, an error will be returned. + +For example, the following: + +```mlir +func @my_fn(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %graph_results:2 = tf_executor.graph { + %island_0_result, %island_0_control = tf_executor.island { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + tf_executor.yield %identity : tensor + } + %island_1_result, %island_1_control = tf_executor.island { + %identity_n:2 = "tf.IdentityN"(%arg1, %island_0_result) : (tensor, tensor) -> (tensor, tensor) + tf_executor.yield %identity_n#0 + } + tf_executor.fetch %island_0_result, %island_1_result : tensor, tensor + } + return %graph_results#0, %graph_results#1 : tensor, tensor +} +``` + +will be transformed into: + +```mlir +func @my_fn(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + %identity_n:2 = "tf.IdentityN"(%arg1, %identity) : (tensor, tensor) -> (tensor, tensor) + return %identity, %identity_n#0 : tensor, tensor +} +``` +### `-tf-functional-control-flow-to-regions`: Transforms functional control flow operations to their region-based counterparts +This pass transforms functional control flow operations in the TensorFlow +dialect to their region-based counterparts, i.e., `tf.If` is transformed to +`tf.IfRegion` and `tf.While` is transformed to `tf.WhileRegion`. 
+ +For example, this functional operation + +```mlir + %0 = "tf.If"(%arg0, %arg1) { + then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false + } : (tensor, tensor<*xf32>) -> tensor<*xf32> +``` + +will be transformed into this region-based operation + +```mlir + %0 = "tf.IfRegion"(%arg0) ( { + %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }, { + %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }) {is_stateless = false} : (tensor) -> tensor<*xf32> +``` +### `-tf-mark-ops-for-outside-compilation`: Marks ops in device cluster for outside compilation if they are unsupported on device. +This pass marks unsupported ops in a device cluster with +`_xla_outside_compilation` attribute so the operations will run on the host +instead of the device. Unsupported ops are ops that can not be code +generated to run on the device for the cluster including: + +1. String operations on TPUs. +2. Operations that don't have a kernel defined for the device. + +This pass is conservative in that it will mark all ops for outside compilation +that can not be compiled for the device. Exceptions for this are added for ops +that will be rewritten or decomposed before compiling on device. 
+ + +For example, tf_device.cluster op with an unsupported op, tf.UnsupportedOp: + +```mlir +func @unsupported_op() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.UnsupportedOp"() : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} +``` + +will mark tf.UnsupportedOp with `_xla_outside_compilation` attribute: + +```mlir +func @unsupported_op() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.UnsupportedOp"() {_xla_outside_compilation = "auto0"} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %2 : tensor + }) {allow_soft_placement = true, device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor + return %0 : tensor +} +``` +### `-tf-region-control-flow-to-functional`: Transforms region-based control flow operations to their functional counterparts +This pass transforms region-based control flow operations in the TensorFlow +dialect to their functional counterparts, i.e., `tf.IfRegion` is transformed to +`tf.If` and `tf.WhileRegion` is transformed to `tf.While`. 
+ +For example, this region-based operation + +```mlir + %0 = "tf.IfRegion"(%arg0) ( { + %1 = call @then_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }, { + %1 = call @else_branch_func(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }) {is_stateless = false} : (tensor) -> tensor<*xf32> +``` + +will be transformed into this functional operation + +```mlir + %0 = "tf.If"(%arg0, %arg1) { + then_branch = @then_branch_func, else_branch = @else_branch_func, is_stateless = false + } : (tensor, tensor<*xf32>) -> tensor<*xf32> +``` +### `-tf-shape-inference`: Simple Shape Inference on TensorFlow Dialect + +#### Options +``` +-max-iterations : Maximum shape inference iterations +``` +### `-tf-tpu-cluster-formation`: Forms clusters from operations assigned to the same TPU computation +TPU computations from the frontend are composed of a `tf.TPUReplicateMetadata` +op, a subgraph of ops (TensorFlow Dialect) each with a matching `_tpu_replicate` +attribute relative to the associated `tf.TPUReplicateMetadata` op, and +optionally `tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops feeding in +inputs and outputs to and from a replicated TPU computation. The number of times +a TPU computation is replicated is defined in the `tf.TPUReplicateMetadata` op +(`num_replicas` attribute) and operand and result sizes of +`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` respectively must match, +excluding packed tensors. It is also assumed ops of the same TPU computation do +not have ops outside of the TPU computation that are both inputs and outputs to +the same TPU computation. + +This pass takes the TPU computation subgraph, moves them into a +`tf_device.cluster`, and copies over attributes from the associated +`tf.TPUReplicateMetadata` op to the newly created `tf_device.cluster`. 
If the +computation is replicated (`num_replicas` > 1), the `num_replicas` attribute is +not copied over but instead the `tf_device.cluster` is further wrapped with a +`tf_device.replicate`, and associated `tf.TPUReplicatedInput` and +`tf.TPUReplicatedOutput` ops are replaced as the `tf_device.replicate` operands +and results. Otherwise, the single operands and results of the associated +`tf.TPUReplicatedInput` and `tf.TPUReplicatedOutput` ops are simply forwarded to +the `tf_device.cluster`. + +For example, the following non replicated computation: + +```mlir +func @tpu_computation(%arg0: tensor) -> tensor { + // Metadata op for cluster `cluster` with 1 replica, 1 core per replica and + // with topology ``. + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 1, num_cores_per_replica = 1, topology = "", device_assignment = [], padding_map = []} : () -> () + %replicated_input = "tf.TPUReplicatedInput"(%arg0) : (tensor) -> tensor + %identity = "tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %replicated_output = "tf.TPUReplicatedOutput(%identity) : (tensor) -> tensor + return %replicated_output : tensor +} +``` + +will be transformed into: + +```mlir +func @tpu_computation(%arg0: tensor) -> tensor { + %cluster = "tf_device.cluster"() ( { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + tf_device.return %identity : tensor + }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor) + return %cluster : tensor +} +``` + +The following replicated computation: + +```mlir +func @tpu_computation(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", num_relicas = 2, num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> () + %replicated_input = "tf.TPUReplicatedInput"(%arg0, %arg1) : (tensor, tensor) -> tensor + %identity = 
"tf.Identity"(%replicated_input) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %replicated_output:2 = "tf.TPUReplicatedOutput(%identity) : (tensor) -> (tensor, tensor) + return %replicated_output#0, %replicated_output#1 : tensor, tensor +} +``` + +will be transformed into: + +```mlir +func @tpu_computation(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %replicate:2 = tf_device.replicate([%arg0, %arg1] as %replicated_input) {n = 2 : i32} { + %cluster = "tf_device.cluster"() ( { + %identity = "tf.Identity"(%replicated_input) : (tensor) -> tensor + tf_device.return %identity : tensor + }) {_tpu_replicate = "cluster", num_cores_per_replica = 1, topology = "topology", device_assignment = [], padding_map = []} : () -> (tensor) + tf_device.return %cluster : tensor + } + return %replicate#0, %replicate#1 : tensor, tensor +} +``` +### `-tf-tpu-extract-outside-compilation`: Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region. +This pass extracts a CPU computation cluster with `_xla_outside_compilation` +annotation, which denotes ops that should be run on CPU/host, from a TPU cluster. +Each outside compilation cluster is moved to +a tf_device.parallel_execute region. The TPU cluster is also moved to a +tf_device.parallel_execute region. Communication ops between device and host are +added to pass inputs/outputs to/from the outside compiled region. 
+ +For example, the following tf_device.cluster with an op marked for `xla_outside_compilation`: + +```mlir +func @outside_compilation() -> tensor { + %0 = "tf_device.cluster"() ( { + %1 = "tf.Const"() {_xla_outside_compilation = "0", value = dense<1.0> : tensor} : () -> (tensor) + %2 = "tf.Identity"(%1) {_xla_outside_compilation = "0"} : (tensor) -> (tensor) + %3 = "tf.AddV2"(%1, %2) : (tensor, tensor) -> (tensor) + tf_device.return %3 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} +``` + +will become a tf_device.parallel_execute op with a CPU/host region and +a tf_device.cluster with communication ops to send data to/from device/host: + +```mlir +func @outside_compilation() -> tensor { + %0 = "tf_device.parallel_execute"() ( { + "tf_device.launch"() ( { + %1 = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string> + %2 = "tf._XlaRecvAtHost"(%1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_args"} : (tensor<3x!tf.string>) -> tensor + %3 = "tf.Identity"(%2) : (tensor) -> tensor + "tf._XlaSendFromHost"(%3, %1) {device_ordinal = 0 : i64, key = "host_compute_channel_0_0_retvals"} : (tensor, tensor<3x!tf.string>) -> () + tf_device.return + }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () + tf_device.return + }, { + %1 = "tf_device.cluster"() ( { + %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %3 = "tf._XlaHostComputeMlir"(%2) {recv_key = "host_compute_channel_0_0_retvals", send_key = "host_compute_channel_0_0_args", tpu_core = 0 : i64} : (tensor) -> tensor + %4 = "tf.AddV2"(%2, %3) : (tensor, tensor) -> tensor + tf_device.return %4 : tensor + }) {device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""} : () -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + return %0 : tensor +} +``` +### `-tf-tpu-reorder-replicate-partitioned-inputs`: Reorder replicated and partitioned input ops. 
+This pass rewrites how data parallelism and model parallelism is expressed for +inputs. It reorders `tf.TPUPartitionedInput` (model parallelism) and +`tf.TPUReplicatedInput` (data parallelism) ops. It transforms a DAG where +multiple `tf.TPUPartitionedInput` ops are feeding into a single +`tf.TPUReplicatedInput` into a DAG where multiple `tf.TPUReplicatedInput` ops +are feeding into a single `tf.TPUPartitionedInput`. Transforming the IR in such +a manner will allow subsequent cluster formation pass to handle IR with both +data and model parallelism in an easier manner. + +For example, the following: + +```mlir +!rtype = type tensor>> +func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype { + %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype + %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype + %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1) : (!rtype, !rtype) -> !rtype + return %ri : !rtype +} +``` + +will be transformed into: + +```mlir +!rtype = type tensor>> +func @data_and_model_parallelism(%arg0: !rtype, %arg1: !rtype, %arg2: !rtype, %arg3: !rtype) -> !rtype { + %ri_0 = "tf.TPUReplicatedInput"(%arg0, %arg2) : (!rtype, !rtype) -> !rtype + %ri_1 = "tf.TPUReplicatedInput"(%arg1, %arg3) : (!rtype, !rtype) -> !rtype + %pi = "tf.TPUPartitionedInput"(%ri_0, %ri_1) {_XlaSharding = "", device = "", partition_dim = -1 : i64} : (!rtype, !rtype) -> !rtype + return %pi : !rtype +} +``` +### `-tf-tpu-resource-partition`: Partitions unpartitioned resource read/write to partitioned resource variables. +This pass creates individual resource reads/writes from the unpartitioned +resource variable (from `tf.TPUPartitionedInput`) to individual partitioned +resource variables (`tf.TPUPartitionedInput` operands). 
As resource op +decomposition/lifting occurs with the unpartitioned resource variables, +transforming the IR in such a manner will allow for subsequent passes to operate +on individual resource variable handles per core/device. + +For example, the following: + +```mlir +func @cluster(%arg0: tensor>>, %arg1: tensor>>) { + %partitioned_variable = "tf.TPUPartitionedInput"(%arg0, %arg1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor>>, tensor>>) -> tensor>> + %read = "tf.ReadVariableOp"(%partitioned_variable) : (tensor>>) -> tensor + %computation = "tf_device.cluster_func"(%read) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor + "tf.AssignVariableOp"(%partitioned_variable, %computation) : (tensor>>, tensor) -> () + return +} + +func @computation(%arg0: tensor) -> tensor { + return %arg0: tensor +} +``` + +will be transformed into: + +```mlir +func @cluster(%arg0: tensor>>, %arg1: tensor>>) { + %read0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + %read1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + %partitioned_input = "tf.TPUPartitionedInput"(%read0, %read1) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor, tensor) -> tensor + %computation = "tf_device.cluster_func"(%partitioned_input) {func = @computation, use_spmd_for_xla_partitioning = true} : (tensor) -> tensor + %partitioned_output:2 = "tf.TPUPartitionedOutput"(%computation) {N = 2 : i64, _XlaSharding = "", partition_dim = -1 : i64} : (tensor) -> (tensor, tensor) + "tf.AssignVariableOp"(%arg0, %partitioned_output#0) : (tensor>>, tensor) -> () + "tf.AssignVariableOp"(%arg1, %partitioned_output#1) : (tensor>>, tensor) -> () + return +} + +func @computation(%arg0: tensor) -> tensor { + return %arg0: tensor +} +``` +### `-tf-tpu-resource-read-for-write`: Inserts tf.ReadVariableOp inputs to a TPU cluster for resource writes with no reads +This pass materializes `tf.ReadVariableOp` inputs to an outlined TPU computation +for 
resource variables where only writes are present so later in the pipeline +such resource variables can be fused with generated `tf.TPUExecute` ops, which +only supports resource variable read or read + write. For all TPU computations, +resource variables are required to be initialized prior to execution. Write only +resource variable uses can be generated currently via packed tensor uses. + +For example, the following: + +```mlir +func @write_only_resource(%value: tensor, %resource: tensor<*x!tf.resource>>) { + %0 = "tf_device.cluster_func"(%value) {func = @cluster} : (tensor) -> tensor + "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +func @cluster(%arg0: tensor) -> tensor { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + return %identity : tensor +} +``` + +will be transformed into: + +```mlir +func @write_only_resource(%value: tensor, %resource: tensor<*x!tf.resource>>) { + %resource_read = "tf.ReadVariableOp"(%resource) : (tensor<*x!tf.resource>>) -> tensor + %0 = "tf_device.cluster_func"(%value, %resource_read) {func = @cluster} : (tensor, tensor) -> tensor + "tf.AssignVariableOp"(%resource, %0) : (tensor<*x!tf.resource>>, tensor) -> () + return +} + +func @cluster(%arg0: tensor, %arg1: tensor) -> tensor { + %identity = "tf.Identity"(%arg0) : (tensor) -> tensor + return %identity : tensor +} +``` +### `-tf-tpu-rewrite`: Rewrites a `tf_device.cluster_func` on TPUs into TPU runtime operations. +This pass rewrites a `tf_device.cluster_func` operation into a sequence of `tf._TPUCompileMlir` +and `tf.TPUExecute` operations. `tf._TPUCompileMlir` contains a MLIR module that is +functionally equivalent to the function referenced by `tf_device.cluster_func`. +This makes the module to be jit-compiled and executed on TPU. +If it is not possible to rewrite the operation or device assignment fails, +a failure will be returned. 
+ +Note, many parameters to the `tf_device.cluster_func` are ommited in this +and following examples. +For example, a non replicated `tf_device.cluster_func`: + +```mlir +func @tf_tpu_rewrite(%arg0: tensor) { + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func} : (tensor) -> tensor + return +} +``` + +will be rewritten as: + +```mlir +func @tf_tpu_rewrite(%arg0: tensor) { + %0:2 = "tf_device.launch"() ( { + %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = ""} : () -> (tensor, tensor<3x!tf.string>) + tf_device.return %compilation_status, %program : tensor, tensor<3x!tf.string> + }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%0#0) : (tensor) -> () + tf_device.return + }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () + %1 = "tf_device.launch"() ( { + %2 = "tf.TPUExecute"(%arg0, %0#1) : (tensor, tensor<3x!tf.string>) -> tensor + tf_device.return %2 : tensor + }) {device = "/job:worker/replica:0/task:0/device:TPU:0"} : () -> tensor + return +} +``` + +A replicated `tf_device.cluster_func`: + +```mlir +func @tf_tpu_rewrite(%arg0: tensor, %arg1: tensor) { + %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor) {n = 2 : i32} { + %1 = "tf_device.cluster_func"(%ri) {_tpu_replicate = "cluster0", func = @func} : (tensor) -> tensor + tf_device.return %1 : tensor + } + return +} +``` + +will be rewritten as: + +```mlir +func @tf_tpu_rewrite(%arg0: tensor, %arg1: tensor) { + %0:2 = tf_device.replicate([%arg0, %arg1] as %arg2: tensor) {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}, n = 2 : i32} { + %1:2 = "tf_device.launch"() ( { + %compilation_status, %program = "tf._TPUCompileMlir"() {mlir_module = ""} : 
() -> (tensor, tensor<3x!tf.string>) + tf_device.return %compilation_status, %program : tensor, tensor<3x!tf.string> + }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%1#0) : (tensor) -> () + tf_device.return + }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> () + %2 = "tf_device.launch"() ( { + %3 = "tf.TPUExecute"(%arg2, %1#1) : (tensor, tensor<3x!tf.string>) -> tensor + tf_device.return %3 : tensor + }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor + tf_device.return %2 : tensor + } + return +} + +A non replicated `tf_device.cluster_func` with the model parallelism: + +```mlir +func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> { + %0 = "tf_device.cluster_func"(%arg0) {_tpu_replicate = "cluster0", func = @func, num_cores_per_replica = 2, input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<8xi32>) -> tensor<8xi32> + return %0 : tensor<8xi32> +} +``` + +will be rewritten as: + +```mlir +func @tf_tpu_rewrite(%arg0: tensor<8xi32>) -> tensor<8xi32> { + %0:3 = "tf_device.launch"() ( { + %compilation_status, %program:2 = "tf._TPUCompileMlir"() {mlir_module = ""} : () -> (tensor, tensor<3x!tf.string>, tensor<3x!tf.string>) + tf_device.return %compilation_status, %program#0, %program#1 : tensor, tensor<3x!tf.string>, tensor<3x!tf.string> + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> (tensor, tensor<3x!tf.string>, tensor<3x!tf.string>) + "tf_device.launch"() ( { + "tf.TPUCompileSucceededAssert"(%0#0) : (tensor) -> () + tf_device.return + }) {device = "/job:localhost/replica:0/task:0/device:CPU:0"} : () -> () + %1 = "tf_device.parallel_execute"() ( { + %2 = "tf_device.launch"() ( { + %3 = "tf.TPUExecute"(%arg0, %0#1) : (tensor<8xi32>, tensor<3x!tf.string>) -> tensor<8xi32> + tf_device.return %3 : tensor<8xi32> + }) {device = 
"/job:localhost/replica:0/task:0/device:TPU:0"} : () -> tensor<8xi32> + tf_device.return %2 : tensor<8xi32> + }, { + "tf_device.launch"() ( { + "tf.TPUExecute"(%0#2) : (tensor<3x!tf.string>) -> () + tf_device.return + }) {device = "/job:localhost/replica:0/task:0/device:TPU:1"} : () -> () + tf_device.return + }) : () -> tensor<8xi32> + return %1 : tensor<8xi32> +} +``` +### `-tf-verify-for-export`: Verify module is suitable for export back to TF Graph +Verifies whether all functions in module are of single tf_executor.graph and +each tf_executor.island in tf_executor.graph only has a single op. diff --git a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md index 8e7e605fc4c10c..1130199fbae7eb 100644 --- a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md +++ b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md @@ -141,7 +141,7 @@ Conclusions: * ElementalIrEmitter ops go for (4), but not incrementally. There is no way to do it op by op, because all elementally-emitted ops are connected into the same graph. This work can also serve as a unification point of several - on-going forces (xla/service/mlir\_gpu, the kernel generator, Linalg). + on-going forces (the kernel generator, Linalg). * All other ops go for (1). As a stretch goal, they might be migrated to (3) or (4). diff --git a/tensorflow/compiler/mlir/hlo/.bazelrc b/tensorflow/compiler/mlir/hlo/.bazelrc new file mode 100644 index 00000000000000..840949acaef93c --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/.bazelrc @@ -0,0 +1,15 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +build --cxxopt=-std=c++14 +build --host_cxxopt=-std=c++14 diff --git a/tensorflow/compiler/mlir/hlo/.gitignore b/tensorflow/compiler/mlir/hlo/.gitignore index cc1696bf575e2c..53e833597c18de 100644 --- a/tensorflow/compiler/mlir/hlo/.gitignore +++ b/tensorflow/compiler/mlir/hlo/.gitignore @@ -1,4 +1,4 @@ build llvm-project llvm-build - +bazel-* diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 1636bbb89ee550..465304d08f6898 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -1,11 +1,8 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "filegroup") - # buildifier: disable=same-origin-load load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") -load("//third_party/mlir:tblgen.bzl", "gentbl") +load("//third_party/mlir:tblgen.bzl", "gentbl", "td_library") # TODO(b/160617323): Decouple MLIR HLO from TensorFlow/XLA load("//tensorflow:tensorflow.bzl", "tf_cc_test") @@ -35,34 +32,28 @@ package_group( ], ) -exports_files(["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"]) - -exports_files(["include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td"]) +exports_files([ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", +]) -filegroup( +td_library( name = "hlo_ops_td_files", - srcs = [ - "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", - 
"include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td", - "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", - "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", - "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", - "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td", - "@llvm-project//mlir:OpBaseTdFiles", + srcs = glob(["include/mlir-hlo/Dialect/mhlo/IR/*.td"]) + [ + # TODO(gcmn): These should be encapsulate in a td_library. "@llvm-project//mlir:include/mlir/Interfaces/CopyOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", - "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", "@llvm-project//mlir:include/mlir/Interfaces/ViewLikeInterface.td", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td", ], -) - -filegroup( - name = "hlo_ops_base_td", - srcs = [ - "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + compatible_with = get_compatible_with_cloud(), + includes = ["include"], + deps = [ + "@llvm-project//mlir:MemRefOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectTdFiles", ], ) @@ -78,7 +69,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td", - td_srcs = [ + deps = [ "@llvm-project//mlir:PassBaseTdFiles", ], ) @@ -95,7 +86,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td", - td_srcs = [ + deps = [ "@llvm-project//mlir:PassBaseTdFiles", ], ) @@ -110,10 +101,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td", - td_relative_includes = [ - "include", - ], - td_srcs = [":hlo_ops_td_files"], + deps = [":hlo_ops_td_files"], ) gentbl( @@ -126,17 +114,7 @@ gentbl( ], tblgen = 
"@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", - td_includes = [ - "include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td", - ], - td_relative_includes = [ - "include", - ], - td_srcs = [ - ":hlo_ops_base_td", - ":hlo_ops_td_files", - "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", - ], + deps = [":hlo_ops_td_files"], ) gentbl( @@ -149,10 +127,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", - td_relative_includes = [ - "include", - ], - td_srcs = [":hlo_ops_td_files"], + deps = [":hlo_ops_td_files"], ) gentbl( @@ -164,10 +139,19 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", - td_relative_includes = [ - "include", + deps = [":hlo_ops_td_files"], +) + +gentbl( + name = "hlo_ops_base_enums_inc_gen", + compatible_with = get_compatible_with_cloud(), + tbl_outs = [ + ("-gen-enum-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"), + ("-gen-enum-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.cc.inc"), ], - td_srcs = [":hlo_ops_td_files"], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td", + deps = [":hlo_ops_td_files"], ) gentbl( @@ -182,17 +166,26 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/IR/hlo_patterns.td", - td_relative_includes = [ - "include", - ], - td_srcs = [ + deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:StdOpsTdFiles", - "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td", - "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td", + "@llvm-project//mlir:TensorOpsTdFiles", ], ) +gentbl( + name = "lhlo_ops_structs_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-struct-attr-decls", 
"include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td", + deps = [":hlo_ops_td_files"], +) + gentbl( name = "lhlo_ops_inc_gen", compatible_with = get_compatible_with_cloud(), @@ -200,15 +193,10 @@ gentbl( tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", - td_relative_includes = [ - "include", - ], - td_srcs = [":hlo_ops_td_files"], + deps = [":hlo_ops_td_files"], ) gentbl( @@ -221,12 +209,30 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", - td_relative_includes = [ - "include", + deps = [":hlo_ops_td_files"], +) + +gentbl( + name = "lhlo_gpu_ops_enums_inc_gen", + compatible_with = get_compatible_with_cloud(), + strip_include_prefix = "include", + tbl_outs = [ + ("-gen-enum-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc"), + ("-gen-enum-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc"), ], - td_srcs = [ - ":hlo_ops_td_files", - "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td", + deps = [":hlo_ops_td_files"], +) + +cc_library( + name = "hlo_ops_common", + srcs = ["lib/Dialect/mhlo/IR/hlo_ops_common.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h"], + includes = ["include"], + deps = [ + "@llvm-project//mlir:IR", + 
"@llvm-project//mlir:Support", ], ) @@ -248,6 +254,23 @@ cc_library( ], ) +cc_library( + name = "lhlo_gpu_ops_enums", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc", + "lib/Dialect/mhlo/IR/lhlo_gpu_ops_enums.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h", + ], + includes = ["include"], + deps = [ + ":lhlo_gpu_ops_enums_inc_gen", + "@llvm-project//llvm:Support", + ], +) + gentbl( name = "lhlo_gpu_ops_inc_gen", compatible_with = get_compatible_with_cloud(), @@ -258,14 +281,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td", - td_relative_includes = [ - "include", - ], - td_srcs = [ - ":hlo_ops_td_files", - "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td", - "include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td", - ], + deps = [":hlo_ops_td_files"], ) #TODO(aminim): revisit the naming and grouping of these rules post-move. 
@@ -278,10 +294,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/IR/mhlo_canonicalize.td", - td_relative_includes = [ - "include", - ], - td_srcs = [":hlo_ops_td_files"], + deps = [":hlo_ops_td_files"], ) gentbl( @@ -299,12 +312,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td", - td_relative_includes = [ - "include", - ], - td_srcs = [ - ":hlo_ops_td_files", - ], + deps = [":hlo_ops_td_files"], ) cc_library( @@ -342,6 +350,22 @@ cc_library( ], ) +cc_library( + name = "hlo_ops_base_enums", + srcs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc", + "lib/Dialect/mhlo/IR/hlo_ops_base_enums.cc", + ], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h", + ], + includes = ["include"], + deps = [ + ":hlo_ops_base_enums_inc_gen", + "@llvm-project//llvm:Support", + ], +) + cc_library( name = "convert_op_folder", srcs = ["lib/utils/convert_op_folder.cc"], @@ -370,13 +394,15 @@ cc_library( ], includes = ["include"], deps = [ - "hlo_ops_pattern_gen", ":canonicalize_inc_gen", ":chlo_ops_inc_gen", ":convert_op_folder", + ":hlo_ops_base_enums", ":hlo_ops_base_inc_gen", ":hlo_ops_base_structs", + ":hlo_ops_common", ":hlo_ops_inc_gen", + ":hlo_ops_pattern_gen", ":infer_fusibility_op_interface", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -387,6 +413,7 @@ cc_library( "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], @@ -398,20 +425,28 @@ cc_library( srcs = [ "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h", + "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc", 
"lib/Dialect/mhlo/IR/lhlo_ops.cc", + "lib/Dialect/mhlo/IR/lhlo_ops_structs.cc", ], hdrs = [ "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h", ], includes = ["include"], deps = [ + ":hlo_ops_base_enums", ":hlo_ops_base_inc_gen", ":hlo_ops_base_structs", + ":hlo_ops_common", ":lhlo_ops_inc_gen", + ":lhlo_ops_structs_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:CopyOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:StandardOps", @@ -436,8 +471,10 @@ cc_library( includes = ["include"], deps = [ ":hlo", + ":hlo_ops_base_enums", ":hlo_ops_base_structs", ":infer_fusibility_op_interface", + ":lhlo_gpu_ops_enums", ":lhlo_gpu_ops_inc_gen", ":lhlo_gpu_ops_structs", "@llvm-project//llvm:Support", @@ -500,7 +537,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:TensorDialect", ], ) @@ -512,11 +549,23 @@ cc_library( ":lhlo", ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", ], ) +cc_library( + name = "map_chlo_to_hlo_op", + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"], + deps = [ + ":hlo", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "map_hlo_to_lhlo_op", hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"], @@ -538,6 +587,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -550,6 +600,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:MemRefDialect", 
"@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", @@ -559,57 +610,71 @@ cc_library( ) cc_library( - name = "lhlo_legalize_to_llvm", - srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"], - hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], + name = "legalize_to_linalg", + srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", + "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + ], deps = [ + ":hlo", ":lhlo", + ":map_lmhlo_to_scalar_op", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) cc_library( - name = "legalize_to_linalg", - srcs = ["lib/Dialect/mhlo/transforms/legalize_to_linalg.cc"], + name = "transform_unranked_hlo", + srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"], hdrs = [ "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", ], deps = [ ":hlo", - ":lhlo", - ":map_lmhlo_to_scalar_op", + ":map_chlo_to_hlo_op", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, ) cc_library( - name = "transform_unranked_hlo", - srcs = ["lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc"], + name = "move_up_dynamic_broadcasts_for_fusion", + srcs = 
["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"], hdrs = [ "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", ], deps = [ ":hlo", + ":map_chlo_to_hlo_op", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, @@ -645,10 +710,12 @@ cc_library( "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:ViewLikeInterface", ], @@ -672,7 +739,9 @@ cc_library( "@llvm-project//mlir:Shape", "@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:StandardOpsTransforms", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, @@ -726,10 +795,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td", - td_relative_includes = [ - "include", - ], - td_srcs = [ + deps = [ ":hlo_ops_td_files", "@llvm-project//mlir:StdOpsTdFiles", ], @@ -742,11 +808,11 @@ cc_library( deps = [ ":hlo", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", ], alwayslink = 1, ) @@ -764,6 +830,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", + 
"@llvm-project//mlir:TransformUtils", ], alwayslink = 1, ) @@ -798,6 +865,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", @@ -815,12 +883,8 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/transforms/lower_complex_patterns.td", - td_relative_includes = [ - "include", - ], - td_srcs = [ + deps = [ ":hlo_ops_td_files", - "@llvm-project//llvm:Support", "@llvm-project//mlir:StdOpsTdFiles", ], ) @@ -879,7 +943,9 @@ cc_library( ":hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], ) @@ -891,10 +957,13 @@ cc_library( deps = [ ":chlo_legalize_to_hlo_inc_gen", ":hlo", + ":map_chlo_to_hlo_op", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], ) @@ -911,12 +980,7 @@ gentbl( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td", - td_relative_includes = [ - "include", - ], - td_srcs = [ - ":hlo_ops_td_files", - ], + deps = [":hlo_ops_td_files"], ) cc_library( @@ -938,7 +1002,6 @@ cc_library( srcs = [ "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc", - "lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc", "lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc", "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc", "lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc", @@ -948,7 +1011,6 @@ cc_library( ":chlo_legalize_to_hlo", # build-cleaner: keep ":hlo", ":lhlo", - ":lhlo_legalize_to_llvm", # 
build-cleaner: keep ":materialize_broadcasts", # build-cleaner: keep ":pass_details", ":unfuse_batch_norm", # build-cleaner: keep @@ -960,6 +1022,7 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, @@ -991,6 +1054,7 @@ cc_library( ":mhlo_control_flow_to_scf", ":mhlo_fusion", ":mhlo_to_mhlo_lowering_patterns", + ":move_up_dynamic_broadcasts_for_fusion", ":sink_constants_to_control_flow", ":test_passes", ":transform_unranked_hlo", diff --git a/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/CMakeLists.txt index c4e2ea123df839..8bfc0d2d01e878 100644 --- a/tensorflow/compiler/mlir/hlo/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/CMakeLists.txt @@ -41,27 +41,22 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") # Options and settings #------------------------------------------------------------------------------- -#------------------------------------------------------------------------------- -# MSVC defaults -#------------------------------------------------------------------------------- - -if(MSVC) - add_compile_options( - $<$:/MD> - $<$:/MD> - $<$:/MD> - ) -endif() +option(MHLO_BUILD_EMBEDDED "Build MHLO as part of another project" OFF) #------------------------------------------------------------------------------- # MLIR/LLVM Configuration #------------------------------------------------------------------------------- -find_package(MLIR REQUIRED CONFIG) -message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") -message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") -list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +# Find MLIR to install if we are building standalone. If building as part of +# another project, let it handle the MLIR dependency. 
The dependent project +# might use a bundled version of MLIR instead of installing, for instance. +if(NOT MHLO_BUILD_EMBEDDED) + find_package(MLIR REQUIRED CONFIG) + message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") + message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") + list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +endif() if(LLVM_ENABLE_ZLIB) find_package(ZLIB) diff --git a/tensorflow/compiler/mlir/hlo/README.md b/tensorflow/compiler/mlir/hlo/README.md index 61517cd9fca217..05aabe3f67e165 100644 --- a/tensorflow/compiler/mlir/hlo/README.md +++ b/tensorflow/compiler/mlir/hlo/README.md @@ -22,7 +22,7 @@ upstream. ## QuickStart: building and testing -These instructions work on Linux, you may have to adjust for your plaform. +These instructions work on Linux, you may have to adjust for your platform. To build the code in this repository, you need a clone of the LLVM/MLIR git repository: diff --git a/tensorflow/compiler/mlir/hlo/WORKSPACE b/tensorflow/compiler/mlir/hlo/WORKSPACE new file mode 100644 index 00000000000000..563df212e958ec --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/WORKSPACE @@ -0,0 +1,57 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Workspace for MLIR HLO.""" + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +LLVM_COMMIT = "" + +LLVM_SHA256 = "" + +LLVM_BAZEL_TAG = "llvm-project-{commit}".format(commit = LLVM_COMMIT) + +http_archive( + name = "llvm-bazel", + strip_prefix = "llvm-bazel-{tag}/llvm-bazel".format(tag = LLVM_BAZEL_TAG), + url = "https://github.com/google/llvm-bazel/archive/{tag}.tar.gz".format(tag = LLVM_BAZEL_TAG), +) + +load("@llvm-bazel//:terminfo.bzl", "llvm_terminfo_disable") +load("@llvm-bazel//:zlib.bzl", "llvm_zlib_disable") +load("@llvm-bazel//:configure.bzl", "llvm_configure") + +http_archive( + name = "llvm-project-raw", + build_file_content = "#empty", + sha256 = LLVM_SHA256, + strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT), + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), + "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), + ], +) + +llvm_terminfo_disable( + name = "llvm_terminfo", +) + +llvm_zlib_disable( + name = "llvm_zlib", +) + +llvm_configure( + name = "llvm-project", + src_path = ".", + src_workspace = "@llvm-project-raw//:WORKSPACE", +) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt index 3fa2b908d9cf4a..8b50b5894ab715 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -14,7 +14,7 @@ # limitations under the License. 
# # Need a separate function because of the .cc vs .cpp used in the one provided by MLIR -function(add_mlir_hlo_dialect dialect dialect_namespace) +function(add_mlir_hlo_dialect dialect) set(LLVM_TARGET_DEFINITIONS ${dialect}.td) mlir_tablegen(${dialect}.h.inc -gen-op-decls) mlir_tablegen(${dialect}.cc.inc -gen-op-defs) @@ -24,23 +24,34 @@ function(add_mlir_hlo_dialect dialect dialect_namespace) add_dependencies(mlir-headers MLIR${dialect}IncGen) endfunction() -add_mlir_hlo_dialect(chlo_ops chlo) -add_mlir_hlo_dialect(lhlo_ops lmhlo) +add_mlir_hlo_dialect(chlo_ops) set(LLVM_TARGET_DEFINITIONS hlo_ops.td) mlir_tablegen(hlo_ops.h.inc -gen-op-decls) mlir_tablegen(hlo_ops.cc.inc -gen-op-defs) mlir_tablegen(hlo_ops_base_structs.h.inc -gen-struct-attr-decls) mlir_tablegen(hlo_ops_base_structs.cc.inc -gen-struct-attr-defs) +mlir_tablegen(hlo_ops_base_enums.h.inc -gen-enum-decls) +mlir_tablegen(hlo_ops_base_enums.cc.inc -gen-enum-defs) add_public_tablegen_target(MLIRhlo_opsIncGen) -set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops.td) -mlir_tablegen(lhlo_gpu_ops.h.inc -gen-op-decls) -mlir_tablegen(lhlo_gpu_ops.cc.inc -gen-op-defs) -set(LLVM_TARGET_DEFINITIONS lhlo_gpu_ops_structs.td) -mlir_tablegen(lhlo_gpu_ops_structs.h.inc -gen-struct-attr-decls) -mlir_tablegen(lhlo_gpu_ops_structs.cc.inc -gen-struct-attr-defs) -add_public_tablegen_target(MLIRlhlo_gpu_opsIncGen) -add_dependencies(mlir-headers MLIRlhlo_gpu_opsIncGen) +function(add_mlir_hlo_dialect_separate_files dialect has_enums) + set(LLVM_TARGET_DEFINITIONS ${dialect}.td) + mlir_tablegen(${dialect}.h.inc -gen-op-decls) + mlir_tablegen(${dialect}.cc.inc -gen-op-defs) + set(LLVM_TARGET_DEFINITIONS ${dialect}_structs.td) + mlir_tablegen(${dialect}_structs.h.inc -gen-struct-attr-decls) + mlir_tablegen(${dialect}_structs.cc.inc -gen-struct-attr-defs) + if(${has_enums}) + set(LLVM_TARGET_DEFINITIONS ${dialect}_enums.td) + mlir_tablegen(${dialect}_enums.h.inc -gen-enum-decls) + mlir_tablegen(${dialect}_enums.cc.inc -gen-enum-defs) + 
endif() + add_public_tablegen_target(MLIR${dialect}IncGen) + add_dependencies(mlir-headers MLIR${dialect}IncGen) +endfunction() + +add_mlir_hlo_dialect_separate_files(lhlo_ops NO) +add_mlir_hlo_dialect_separate_files(lhlo_gpu_ops YES) add_mlir_interface(infer_fusibility_op_interface) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 05b22770401ca6..b1795315bab33b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -18,12 +18,12 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -36,7 +36,7 @@ class HloClientDialect : public Dialect { void initialize(); public: - explicit HloClientDialect(MLIRContext *context) + explicit HloClientDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context, TypeID::get()) { initialize(); @@ -66,6 +66,16 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant, return b.create(loc, getAttr(), val); } +Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant, + Value val); + +Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val); + +Value getConstantLikeInfValue(OpBuilder& b, Location loc, Value val, + bool negative); + +Value getConstantLikeSmallestFiniteValue(OpBuilder& b, Location loc, Value val); + } // namespace chlo } // namespace mlir diff --git 
a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 13d5f02368b164..b3f81a029e094c 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -66,7 +66,7 @@ class HLOClient_Op traits> : // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // shape broadcasting. // -// These correspond to operations in the mhlo dialect without the +// These correspond to operations in the chlo and mhlo dialects without the // "broadcast_" prefix, except that those ops require same-shaped operands and // results. // @@ -89,10 +89,9 @@ class HLOClient_BroadcastBinaryElementwiseOp< OptionalAttr:$broadcast_dimensions ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value left, Value right, " - "DenseIntElementsAttr broadcast_dimensions" - >]; + let builders = [ + OpBuilder<(ins "Value":$left, "Value":$right, + "DenseIntElementsAttr":$broadcast_dimensions)>]; let results = (outs HLO_Tensor); @@ -179,6 +178,15 @@ def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp< }]; } +def HLOClient_BroadcastPolygammaOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_polygamma", [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Polygamma function (with optional broadcasting)"; + + let description = [{ + Returns `Polygamma(operand, operand)` element-wise. 
+ }]; +} + def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp< "broadcast_power", [NoSideEffect, SameOperandsAndResultElementType]> { @@ -257,8 +265,31 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp< }]; } +def HLOClient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_zeta", + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Hurwitz zeta function"; + + let description = [{ + Returns `Zeta(operand, operand)` element-wise. + + $$ + \(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\) + $$ + }]; + + let arguments = (ins + HLO_FpTensor:$lhs, + HLO_FpTensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); + let results = (outs HLO_FpTensor); +} + //===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. +// XLA binary logical elementwise op definitions. // The same description as the arithmetic binary elementwise ops applies. //===----------------------------------------------------------------------===// @@ -310,6 +341,47 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp< }]; } +//===----------------------------------------------------------------------===// +// XLA non-broadcasting binary operations. +// +// These are operations that are supported by the XLA Builder API but that are +// not part of the HLO compiler instructions as modelled by the MHLO dialect. +//===----------------------------------------------------------------------===// + +def HLOClient_ZetaOp : HLOClient_Op<"zeta", [NoSideEffect, + SameOperandsAndResultType]> { + let summary = "Hurwitz zeta function"; + let description = [{ + Returns `Zeta(operand, operand)` element-wise. 
+ + $$ + \(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\) + $$ + }]; + + let arguments = (ins HLO_FpTensor:$x, HLO_FpTensor:$q); + let results = (outs HLO_FpTensor:$result); + + let assemblyFormat = [{ + $x `,` $q attr-dict `:` type($x) `,` type($q) `->` type(results) + }]; +} + +def HLOClient_PolygammaOp : HLOClient_Op<"polygamma", [NoSideEffect, + SameOperandsAndResultType]> { + let summary = "Polygamma function"; + let description = [{ + Returns `Polygamma(operand, operand)` element-wise. + }]; + + let arguments = (ins HLO_FpTensor:$n, HLO_FpTensor:$x); + let results = (outs HLO_FpTensor:$result); + + let assemblyFormat = [{ + $n `,` $x attr-dict `:` type($n) `,` type($x) `->` type(results) + }]; +} + //===----------------------------------------------------------------------===// // Broadcasting complex op //===----------------------------------------------------------------------===// @@ -338,16 +410,19 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< //===----------------------------------------------------------------------===// class HLOClient_UnaryElementwiseOp traits, - Type TensorType> : HLOClient_Op { - let arguments = (ins TensorType:$operand); - let results = (outs TensorType:$result); + Type ArgTensorType, Type ResultTensorType> : HLOClient_Op { + let arguments = (ins ArgTensorType:$operand); + let results = (outs ResultTensorType:$result); - let assemblyFormat = "$operand attr-dict `:` type($operand)"; + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; } -def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], - HLO_FpOrComplexTensor> { +def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { let summary = "Acos operator"; let description = [{ @@ -360,8 +435,48 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], }]; } -def HLOClient_AtanOp : 
HLOClient_UnaryElementwiseOp<"atan", [], - HLO_FpOrComplexTensor> { +def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { + let summary = "Acosh operation"; + + let description = [{ + Returns `Acosh(operand)` element-wise. + + $$ + \acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 + \acosh(x) = nan if x < -1 + $$ + }]; +} + +def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { + let summary = "Asin operator"; + + let description = [{ + Returns `Asin(operand)` element-wise. + + $$ + \asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) + $$ + }]; +} + +def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { + let summary = "Asinh operation"; + + let description = [{ + Returns `Asinh(operand)` element-wise. + + $$ + \asinh(x) = log(x + sqrt(x^2 + 1)) + $$ + }]; +} + +def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { let summary = "Atan operator"; let description = [{ @@ -373,8 +488,48 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], }]; } -def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], - HLO_FpOrComplexTensor> { +def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { + let summary = "Atanh operator"; + + let description = [{ + Returns `Atanh(operand)` element-wise. + + $$ + \atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 + = nan otherwise + $$ + }]; +} + +def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { + let summary = "Conj operator"; + + let description = [{ + Returns `Conj(operand)` element-wise. 
+ + $$ + \conj(x) = (\real(x), \neg(\imag(x))) + $$ + }]; +} + +def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { + let summary = "Cosh operator"; + + let description = [{ + Returns `Cosh(operand)` element-wise. + + $$ + \cosh(x) = (e^x + e^-x) / 2 + $$ + }]; +} + +def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { let summary = "Sinh operation"; let description = [{ @@ -387,8 +542,8 @@ def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], }]; } -def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", [], - HLO_FpOrComplexTensor> { +def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan", + [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> { let summary = "Tan operation"; let description = [{ @@ -418,6 +573,78 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", let hasCanonicalizer = 1; } +def HLOClient_DigammaOp : HLOClient_UnaryElementwiseOp<"digamma", + [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { + let summary = "Digamma function"; + + let description = [{ + Returns `Digamma(operand)` element-wise. + }]; +} + +def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf", + [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { + let summary = "Erfc operator"; + + let description = [{ + Computes the Gauss error function of `x` element-wise. + + erf(x) = erf_impl(x) if |x| < 1 + = 1 - erfc_impl(x) otherwise + }]; +} + +def HLOClient_ErfcOp : HLOClient_UnaryElementwiseOp<"erfc", + [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { + let summary = "Erfc operator"; + + let description = [{ + Computes an approximation of the error function complement (1 - erf(x)). 
+ + erfc(x) = erfc_impl(x) if |x| > 1 + = 1 - erf_impl(x) otherwise + }]; +} + +def HLOClient_IsInfOp : HLOClient_UnaryElementwiseOp<"is_inf", + [DeclareOpInterfaceMethods], HLO_FpTensor, + HLO_PredTensor> { + let summary = "IsInf predicate"; + + let description = [{ + Returns if a value is +/-inf element-wise. + }]; +} + +def HLOClient_IsNegInfOp : HLOClient_UnaryElementwiseOp<"is_neg_inf", + [DeclareOpInterfaceMethods], HLO_FpTensor, + HLO_PredTensor> { + let summary = "IsNegInf predicate"; + + let description = [{ + Returns if a value is -inf element-wise. + }]; +} + +def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", + [DeclareOpInterfaceMethods], HLO_FpTensor, + HLO_PredTensor> { + let summary = "IsPosInf predicate"; + + let description = [{ + Returns if a value is +inf element-wise. + }]; +} + +def HLOClient_LgammaOp : HLOClient_UnaryElementwiseOp<"lgamma", + [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { + let summary = "Lgamma function"; + + let description = [{ + Returns `Lgamma(operand)` element-wise. + }]; +} + //===----------------------------------------------------------------------===// // Broadcasting compare op //===----------------------------------------------------------------------===// @@ -427,7 +654,10 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< string summary = "Compare operator (with optional broadcasting)"; string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. See https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
@@ -437,14 +667,90 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< HLO_Tensor:$lhs, HLO_Tensor:$rhs, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); let results = (outs HLO_PredTensor); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" - >]; + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs, + "DenseIntElementsAttr":$broadcast_dimensions, + "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>]; +} + +//===----------------------------------------------------------------------===// +// Broadcasting select op +//===----------------------------------------------------------------------===// + +def HLOClient_BroadcastSelectOp : HLOClient_Op< + "broadcast_select", + [NoSideEffect, DeclareOpInterfaceMethods]> { + string summary = "Select operator (with optional numpy-style broadcasting)"; + + string description = [{ + Constructs an output array from elements of two input arrays, based on the + values of a predicate array. 
+ + See https://www.tensorflow.org/xla/operation_semantics#select + }]; + + let arguments = (ins + HLO_PredTensor:$pred, + HLO_Tensor:$on_true, + HLO_Tensor:$on_false + ); + + let results = (outs HLO_Tensor); + + let assemblyFormat = [{ + $pred `,` $on_true `,` $on_false attr-dict `:` + `(` type($pred) `,` type($on_true) `,` type($on_false) `)` `->` type(results) + }]; +} + +//===----------------------------------------------------------------------===// +// Helper ops +//===----------------------------------------------------------------------===// + +def HLOClient_MinimumBroadcastShapesOp : + HLOClient_Op<"minimum_broadcast_shapes", [NoSideEffect]> { + string summary = "Minimizes the rank of two or more shapes to be broadcasted"; + + string description = [{ + Given two or more 1D tensors representing shapes, returns one 1D tensor for + each operand, where operand `i` corresponds to output `i`. + + The returned tensors have the property that they specify a shape which is a + reshape of the corresponding input shape, and the broadcasted output shape + (using shape::BroadcastOp) of the returned shapes is a reshape of the + broadcasted output shape of the input shapes. Among all possibilities with + this property, the one is chosen which minimizes the rank of each returned + shape. + + The general idea of this op is that it can be used for ops which have a + broadcasting semantic to operate on shapes with a possibly smaller rank + while preserving equivalence of the computed values. After computing the + result of the op using reshaped operands, the result can be reshaped to the + result that would have been originally computed. + + Here is an example with two input shapes: + + ```mlir + chlo.minimum_broadcast_shapes [1, 2, 3, 1, 2, 1], + [1, 1, 1, 2, 3] -> [6, 2, 1], [2, 3] + ``` + + The broadcasted output shape of the operands is [1, 2, 3, 1, 2, 3], the + broadcasted output shape of the outputs is [6, 2, 3]. 
These two shapes are + reshapes of each other, and also each output is a reshape of the + corresponding input. + }]; + + let arguments = (ins Variadic<1DTensorOf<[Index]>>:$shapes); + let results = (outs Variadic<1DTensorOf<[Index]>>:$results); + + let assemblyFormat = "$shapes attr-dict `:` type($shapes) `->` type($results)"; + } #endif // CHLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index b354189c12a612..21e9c9f07ddd3f 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -21,19 +21,21 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" // clang-format off #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" // clang-format on @@ -80,6 +82,9 @@ LogicalResult deriveShapeFromFirstOperand( OpBuilder *builder, Operation *op, SmallVectorImpl *reifiedReturnShapes); +// Type derivation function that returns a tensor type with a new element type. 
+TensorType getSameShapeTensorType(TensorType tensor_type, Type element_type); + } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 579e89ca1375c2..52a5a495795196 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -41,7 +41,9 @@ def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">; def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">; def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [ HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION -]>; +]> { + let cppNamespace = "::mlir::mhlo"; +} //===----------------------------------------------------------------------===// // MHLO nullary op definitions. @@ -58,9 +60,8 @@ def HLO_ConstOp : HLO_Op<"constant", HLO_StaticShapeTensor:$output ); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Attribute value" - >]; + let builders = [ + OpBuilder<(ins "Attribute":$value)>]; let assemblyFormat = "attr-dict $value"; @@ -118,38 +119,37 @@ def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions class HLO_UnaryElementwiseOp traits, - Type TensorType>: HLO_Op { - let arguments = (ins TensorType:$operand); - let results = (outs TensorType); - let extraClassDeclaration = [{ - static LogicalResult inferReturnTypeComponents( - MLIRContext* context, Optional location, - ValueRange operands, DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl& inferedReturnShapes) { - return failure(); - } - LogicalResult reifyReturnTypeShapes( - OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { - return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), - &reifiedReturnShapes); - } - 
bool inferInputOutputShapeEquality(int input, int output) { - return true; - } - llvm::Optional inferEffectiveWorkloadShape() { - return getOperation()->getResult(0); - } - }]; + Type TensorType> : HLO_Op { + let arguments = (ins TensorType:$operand); + let results = (outs TensorType); + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypeComponents( + MLIRContext* context, Optional location, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + return failure(); + } + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); + } + bool inferInputOutputShapeEquality(int input, int output) { + return true; + } + llvm::Optional inferEffectiveWorkloadShape() { + return getOperation()->getResult(0); + } + }]; } // Abs supports complex to real, so element type is not guaranteed to match. 
def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", - [NoSideEffect, SameOperandsAndResultShape], + [NoSideEffect, + DeclareOpInterfaceMethods], TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { - let builders = [OpBuilder< - "Value operand" - >]; } def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", @@ -158,13 +158,11 @@ def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp; -def HLO_ConvertOp : HLO_UnaryElementwiseOp< - "convert", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, +def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert", + [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, BASE_HLO_ConvertOp { - - let builders = [OpBuilder< - "Value operand, Type result_element_ty" - >]; + let builders = [ + OpBuilder<(ins "Value":$operand, "Type":$result_element_ty)>]; let hasFolder = 1; @@ -191,15 +189,14 @@ def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", - [NoSideEffect, SameOperandsAndResultShape, - DeclareOpInterfaceMethods], + [NoSideEffect, DeclareOpInterfaceMethods], HLO_ComplexTensor>, BASE_HLO_ImagOp { let results = (outs HLO_FpTensor); let hasFolder = 1; } -def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", - [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, +def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect, + DeclareOpInterfaceMethods], HLO_Tensor>, BASE_HLO_IsFiniteOp { let arguments = (ins HLO_FpTensor:$x); let results = (outs HLO_PredTensor:$y); @@ -220,6 +217,7 @@ def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", def HLO_NotOp: HLO_UnaryElementwiseOp<"not", [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, BASE_HLO_NotOp { + let hasFolder = 1; } def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", @@ -233,8 +231,7 @@ def HLO_PopulationCountOp: 
HLO_UnaryElementwiseOp<"popcnt", BASE_HLO_PopulationCountOp; def HLO_RealOp: HLO_UnaryElementwiseOp<"real", - [NoSideEffect, SameOperandsAndResultShape, - DeclareOpInterfaceMethods], + [NoSideEffect, DeclareOpInterfaceMethods], HLO_ComplexTensor>, BASE_HLO_RealOp { let results = (outs HLO_FpTensor); let hasFolder = 1; @@ -274,7 +271,8 @@ def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations class HLO_BinaryElementwiseOp traits> : - HLO_Op { + HLO_Op { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs @@ -317,8 +315,7 @@ def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex", - [NoSideEffect, SameOperandsAndResultShape, - DeclareOpInterfaceMethods]>, + [NoSideEffect, DeclareOpInterfaceMethods]>, BASE_HLO_ComplexOp { let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let results = (outs HLO_ComplexTensor); @@ -403,12 +400,18 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> { of the data. Multiple Infeed operations are allowed in a computation, but there must be a total order among the Infeed operations. + Attributes: + layout: Array attribute. Same shape as the output of the infeed, except + that every tensor is replaced by a minor_to_major array for the + tensor's layout. + See https://www.tensorflow.org/xla/operation_semantics#infeed. }]; let arguments = (ins HLO_Token:$token, - DefaultValuedAttr:$infeed_config + DefaultValuedAttr:$infeed_config, + OptionalAttr:$layout ); let results = (outs HLO_Tuple); let hasCustomHLOConverter = 1; @@ -491,7 +494,8 @@ def HLO_RecvOp : HLO_Op<"recv", []> { // MHLO parallelism related op definitions. 
//===----------------------------------------------------------------------===// -def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>, +def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect, + DeclareOpInterfaceMethods]>, BASE_HLO_ReplicaIdOp { let results = (outs TensorOf<[UI32]>); } @@ -618,10 +622,9 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ let results = (outs Variadic); - let builders = [OpBuilder< - "OpBuilder &, OperationState &state, ValueRange operands, " - "ValueRange init_values, DenseIntElementsAttr dimensions" - >]; + let builders = [ + OpBuilder<(ins "ValueRange":$operands, "ValueRange":$init_values, + "DenseIntElementsAttr":$dimensions)>]; let extraClassDeclaration = [{ bool isFusibleWithConsumer() { @@ -657,18 +660,16 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &results, " - "Value value, int32_t index">]; + let builders = [ + OpBuilder<(ins "Value":$value, "int32_t":$index)>]; } def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { let arguments = (ins Variadic:$val); let results = (outs HLO_Tuple); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &results, " - "ValueRange values">]; + let builders = [ + OpBuilder<(ins "ValueRange":$values)>]; let hasCanonicalizer = 1; } @@ -680,16 +681,19 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); let results = (outs HLO_PredTensor); let hasFolder = 1; - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " - "StringAttr comparison_direction" - >]; + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs, + "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>, 
+ ]; + + let hasCustomHLOConverter = 1; } //===----------------------------------------------------------------------===// @@ -699,7 +703,8 @@ def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, def HLO_SliceOp: HLO_Op< "slice", [NoSideEffect, SameOperandsAndResultElementType, - AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { + AllTypesMatch<["start_indices", "limit_indices", "strides"]>, + DeclareOpInterfaceMethods]> { let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$start_indices, @@ -711,25 +716,10 @@ def HLO_SliceOp: HLO_Op< let hasCanonicalizer = 1; let hasFolder = 1; - - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value operand, " - "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, " - "DenseIntElementsAttr strides" - >]; - - let extraClassDeclaration = [{ - // Infers output type for given operand and attributes. Result type is - // unranked if any of the attributes is illegal. - static Type InferOutputTypes(Builder *builder, Value operand, - DenseIntElementsAttr start_indices, - DenseIntElementsAttr limit_indices, - DenseIntElementsAttr strides); - }]; } def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", - [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> { + [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]>, BASE_HLO_DynamicSliceOp { let arguments = (ins HLO_Tensor:$operand, Variadic:$start_indices, @@ -742,7 +732,7 @@ def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>, - AllShapesMatch<["operand", "result"]>]> { + AllShapesMatch<["operand", "result"]>]>, BASE_HLO_DynamicUpdateSliceOp { let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$update, @@ -835,8 +825,9 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", let hasCustomHLOConverter = 1; } -def HLO_DynamicBroadcastInDimOp : 
HLO_Op<"dynamic_broadcast_in_dim", - [NoSideEffect]> { +def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [ + NoSideEffect, DeclareOpInterfaceMethods]> { string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; string description = [{ This is a generalization of the BroadcastInDimOp which accepts its output @@ -884,7 +875,8 @@ def HLO_ClampOp : HLO_Op<"clamp", } def HLO_ConcatenateOp : HLO_Op<"concatenate", - [NoSideEffect, SameOperandsAndResultElementType, DeclareOpInterfaceMethods]>, BASE_HLO_ConcatenateOp { + [NoSideEffect, SameOperandsAndResultElementType, + DeclareOpInterfaceMethods]>, BASE_HLO_ConcatenateOp { let arguments = (ins Variadic:$val, @@ -896,6 +888,11 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate", let hasCanonicalizer = 1; let hasFolder = 1; + let extraClassDeclaration = [{ + static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { + return succeeded(mlir::verifyCompatibleShapes(l, r)); + } + }]; } def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", @@ -913,12 +910,14 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs), - ConvolutionAttributes.attributes); + ConvolutionAttributes.attributes); let results = (outs HLO_Tensor); + let hasCustomHLOConverter = 1; } -def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp { +def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, + BASE_HLO_CopyOp { let arguments = (ins HLO_Tensor); let results = (outs HLO_Tensor); let hasFolder = 1; @@ -942,7 +941,7 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp { DefaultValuedAttr:$has_side_effect, DefaultValuedAttr:$backend_config ); - let results = (outs HLO_Tensor); + let results = (outs Variadic); let hasCustomHLOConverter = 1; } @@ -955,7 +954,8 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let results = (outs HLO_Tensor); } -def 
HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp { +def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, + BASE_HLO_DotGeneralOp { let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -965,6 +965,9 @@ def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneral let results = (outs HLO_Tensor); let verifier = [{ return Verify(*this); }]; + // DotGeneral op required custom exporter to pass the preferred element type + // to Xla builder. + let hasCustomHLOConverter = 1; } // Define Base Einsum op within the HLO dialect as these are client ops and @@ -1034,7 +1037,7 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, BASE_HLO_GetDimensionSizeOp { let arguments = (ins HLO_Tensor:$operand, - I32Attr:$dimension + I64Attr:$dimension ); // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the // XLA semantics is available. This limitation is because of the current XLA @@ -1063,6 +1066,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape", let results = (outs HLO_StaticShapeTensor); let hasFolder = 1; + let hasCanonicalizer = 1; let hasCustomHLOConverter = 1; } @@ -1146,12 +1150,15 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, let arguments = (ins HLO_Tensor:$operand, I32Tensor:$size, - I32Attr:$dimension + I64Attr:$dimension ); let results = (outs HLO_Tensor); + + let hasFolder = 1; } -def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp { +def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, + SameOperandsAndResultShape]>, BASE_HLO_SortOp { let arguments = (ins Variadic:$operands, DefaultValuedAttr:$dimension, @@ -1162,10 +1169,9 @@ def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShap let regions = (region SizedRegion<1>:$comparator); - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &state, ValueRange operands, " - "int64_t dimension = -1, bool is_stable = 
false" - >]; + let builders = [ + OpBuilder<(ins "ValueRange":$operands, CArg<"int64_t", "-1">:$dimension, + CArg<"bool", "false">:$is_stable)>]; // TODO(b/129422361): SortOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -1328,7 +1334,8 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { let hasCustomHLOConverter = 1; } -def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp { +def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, + BASE_HLO_RngBitGeneratorOp { let arguments = (ins // TODO(jpienaar): This could be an enum instead. I32Attr:$rng_algorithm, @@ -1391,8 +1398,9 @@ def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp { let hasCustomHLOConverter = 1; } -def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>, - BASE_HLO_ReducePrecisionOp { +def HLO_ReducePrecisionOp : + HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>, + BASE_HLO_ReducePrecisionOp { let arguments = (ins HLO_FpTensor:$operand, I32Attr:$exponent_bits, diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index da8c921a47bfa9..896fe0fff05285 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -23,6 +23,7 @@ def HLO_Dialect : Dialect { let cppNamespace = "::mlir::mhlo"; } +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td" def HLO_Pred : TypeAlias; @@ -98,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>; // Any pred, int or floating-point tensor types def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>; +// A layout attribute (1D tensor of index type) +def HLO_LayoutAttr : 
Attr< + And<[IndexElementsAttr.predicate, + CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank() + == 1}]>]>, + "A 1D tensor of index type (layout)"> { + let storageType = IndexElementsAttr.storageType; + let returnType = IndexElementsAttr.returnType; + let convertFromStorage = IndexElementsAttr.convertFromStorage; +} + //===----------------------------------------------------------------------===// // MHLO nullary op definitions. //===----------------------------------------------------------------------===// @@ -636,6 +648,23 @@ class BASE_HLO_ReplicaIdOp { }]; } +class BASE_HLO_PartitionIdOp { + string summary = "PartitionId operator"; + + string description = [{ + Returns the unique ID (int32 scalar) of the partition. + }]; +} + +class BASE_HLO_AllGatherOp { + string summary = "AllGather operator"; + + string description = [{ + Performs concatenation across replicas. + + See https://www.tensorflow.org/xla/operation_semantics#allgather + }]; +} class BASE_HLO_AllReduceOp { string summary = "AllReduce operator"; @@ -692,68 +721,17 @@ class BASE_HLO_TupleOp { }]; } -//===----------------------------------------------------------------------===// -// Precision Config enum definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA PrecisionConfig proto enum. -def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">; -def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">; -def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">; - -def HLO_PrecisionAttr : StrEnumAttr<"Precision", - "XLA precision for an operand. Has backend specific meaning.", - [HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>; - -// TODO(b/129153247) See if it's possible to also validate the size. -def HLO_PrecisionConfigAttr: - OptionalAttr< - TypedArrayAttrBase>; - -//===----------------------------------------------------------------------===// -// Fast Fourier Transform Type enum definitions. 
-//===----------------------------------------------------------------------===// - -// These mirror the XLA FftType proto enum. -def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">; -def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">; -def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">; -def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">; -def HLO_FftTypeAttr : StrEnumAttr<"FftType", - "XLA fast fourier transform type.", - [HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT, - HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]>; - -//===----------------------------------------------------------------------===// -// Comparison op definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA ComparisonDirection enum. -def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">; -def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">; -def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">; -def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">; -def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">; -def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">; - -def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection", - "Which comparison operation to perform.", - [ - HLO_COMPARISON_DIRECTION_EQ, - HLO_COMPARISON_DIRECTION_NE, - HLO_COMPARISON_DIRECTION_GE, - HLO_COMPARISON_DIRECTION_GT, - HLO_COMPARISON_DIRECTION_LE, - HLO_COMPARISON_DIRECTION_LT - ]>; class BASE_HLO_CompareOp { string summary = "Comparison operator"; string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. See https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. @@ -764,13 +742,6 @@ class BASE_HLO_CompareOp { // Quantize op definitions. 
//===----------------------------------------------------------------------===// -// These mirror the XLA ComparisonDirection enum. -def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">; - -def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode", - "Dequantization mode. Only MIN_COMBINED is supported.", - [HLO_MIN_COMBINED]>; - class BASE_HLO_DequantizeOp { string summary = "Dequantize operator"; @@ -1010,7 +981,23 @@ class BASE_HLO_ConcatenateOp { // Common convolution attributes //===----------------------------------------------------------------------===// -class ConvolutionAttributes { +// TODO(b/129153247) See if it's possible to also validate the size. +def HLO_PrecisionConfigAttr: + OptionalAttr< + TypedArrayAttrBase>; + +def BoolElementsAttr : + ElementsAttrBase< + And<[CPred<"$_self.isa<::mlir::DenseIntOrFPElementsAttr>()">, + CPred<"$_self.cast<::mlir::DenseIntOrFPElementsAttr>().getType().getElementType().isInteger(1)">]>, + "constant boolean vector/tensor attribute"> { + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +def ConvolutionAttributes { dag attributes = (ins // Default value: one for each of the spatial dimension. OptionalAttr:$window_strides, @@ -1020,6 +1007,8 @@ class ConvolutionAttributes { OptionalAttr:$lhs_dilation, // Default value: one for each of the spatial dimension. OptionalAttr:$rhs_dilation, + // Default value: one for each of the spatial dimension. + OptionalAttr:$window_reversal, ConvDimensionNumbers:$dimension_numbers, I64Attr:$feature_group_count, I64Attr:$batch_group_count, @@ -1035,6 +1024,14 @@ class BASE_HLO_ConvOp { See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. 
}]; + + code extraClassDeclaration = [{ + bool hasWindowReversal() { + auto reversal = window_reversalAttr(); + return reversal && llvm::any_of(reversal.getBoolValues(), + [](bool v) { return v; }); + } + }]; } class BASE_HLO_CopyOp { @@ -1251,21 +1248,6 @@ class BASE_HLO_TransposeOp { }]; } -// These mirror the XLA Transpose enum in Triangular Solve options. -def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">; -def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">; -def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">; -def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">; - -def HLO_TransposeAttr : StrEnumAttr<"Transpose", - "Transpose options", - [ - HLO_TRANSPOSE_INVALID, - HLO_NO_TRANSPOSE, - HLO_TRANSPOSE, - HLO_ADJOINT - ]>; - class BASE_HLO_TriangularSolveOp { string summary = "TriangularSolve operator"; @@ -1363,7 +1345,7 @@ class BASE_HLO_BitcastOp { string description = [{ This op changes the shape of the input in the way that the physical - arranggment of elements are unchanged. + arrangement of elements are unchanged. However, the op needs layout information to make sense of "physical arrangement of elements". Layout support in MHLO is currently under diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h new file mode 100644 index 00000000000000..38414b49003f72 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines enums used in MHLO and LMHLO. +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_ENUMS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td new file mode 100644 index 00000000000000..eb1830aed8ca14 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.td @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef HLO_OPS_BASE_ENUMS +#define HLO_OPS_BASE_ENUMS + +//===----------------------------------------------------------------------===// +// Precision Config enum definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA PrecisionConfig proto enum. +def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">; +def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">; +def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">; + +def HLO_PrecisionAttr : StrEnumAttr<"Precision", + "XLA precision for an operand. Has backend specific meaning.", + [HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]> { + let cppNamespace = "::mlir::mhlo"; +} + +//===----------------------------------------------------------------------===// +// Fast Fourier Transform Type enum definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA FftType proto enum. +def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">; +def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">; +def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">; +def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">; + +def HLO_FftTypeAttr : StrEnumAttr<"FftType", + "XLA fast fourier transform type.", + [HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT, + HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]> { + let cppNamespace = "::mlir::mhlo"; +} + +//===----------------------------------------------------------------------===// +// Comparison op definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA ComparisonDirection enum. 
+def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">; +def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">; +def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">; +def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">; +def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">; +def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">; + +def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection", + "Which comparison operation to perform.", + [ + HLO_COMPARISON_DIRECTION_EQ, + HLO_COMPARISON_DIRECTION_NE, + HLO_COMPARISON_DIRECTION_GE, + HLO_COMPARISON_DIRECTION_GT, + HLO_COMPARISON_DIRECTION_LE, + HLO_COMPARISON_DIRECTION_LT + ]> { + let cppNamespace = "::mlir::mhlo"; +} + +def HLO_DEFAULT_COMPARISON_TYPE : NativeCodeCall<"StringAttr()">; +def HLO_COMPARISON_TYPE_FLOAT : StrEnumAttrCase<"FLOAT">; +def HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER : StrEnumAttrCase<"TOTALORDER">; +def HLO_COMPARISON_TYPE_SIGNED : StrEnumAttrCase<"SIGNED">; +def HLO_COMPARISON_TYPE_UNSIGNED : StrEnumAttrCase<"UNSIGNED">; + +def HLO_ComparisonTypeAttr : StrEnumAttr<"ComparisonType", + "Which comparison type to use.", + [ + HLO_COMPARISON_TYPE_FLOAT, + HLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, + HLO_COMPARISON_TYPE_SIGNED, + HLO_COMPARISON_TYPE_UNSIGNED + ]> { + let cppNamespace = "::mlir::mhlo"; +} + +// These mirror the XLA Dequantize mode string enum. +def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">; + +def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode", + "Dequantization mode. Only MIN_COMBINED is supported.", + [HLO_MIN_COMBINED]> { + let cppNamespace = "::mlir::mhlo"; +} + +// These mirror the XLA Transpose enum in Triangular Solve options. 
+def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">; +def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">; +def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">; +def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">; + +def HLO_TransposeAttr : StrEnumAttr<"Transpose", + "Transpose options", + [ + HLO_TRANSPOSE_INVALID, + HLO_NO_TRANSPOSE, + HLO_TRANSPOSE, + HLO_ADJOINT + ]> { + let cppNamespace = "::mlir::mhlo"; +} + +#endif // HLO_OPS_BASE_ENUMS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h index 3b78ff8a36723e..70247d76d1dc80 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h @@ -18,9 +18,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_STRUCTS_H_ -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Identifier.h" -#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" // Order matters, this .inc header is not self-contained, and relies on the diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td index d25eb5104c6228..d512a7cd221db4 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.td @@ -26,7 +26,7 @@ def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [ StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>, 
StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr> ]> { - let description = "Structure of dimension information for dot product"; + let summary = "Structure of dimension information for dot product"; } def ScatterDimensionNumbers : StructAttr< @@ -35,7 +35,7 @@ def ScatterDimensionNumbers : StructAttr< StructFieldAttr<"inserted_window_dims", I64ElementsAttr>, StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>, StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for scatter"; + let summary = "Structure of dimension information for scatter"; } def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ @@ -49,7 +49,7 @@ def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [ StructFieldAttr<"output_feature_dimension", I64Attr>, StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { - let description = "Structure of dimension information for conv op"; + let summary = "Structure of dimension information for conv op"; } def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, @@ -57,7 +57,7 @@ def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>, StructFieldAttr<"start_index_map", I64ElementsAttr>, StructFieldAttr<"index_vector_dim", I64Attr>]> { - let description = "Structure of dimension information for gather"; + let summary = "Structure of dimension information for gather"; } @@ -67,7 +67,7 @@ def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect, def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [ StructFieldAttr<"handle", I64Attr>, StructFieldAttr<"type", I64Attr>]> { - let description = "two 64-bit integers 'handle' and 'type'"; + let summary = "two 64-bit integers 'handle' and 'type'"; } #endif // HLO_OPS_BASE_STRUCTS diff --git 
a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h new file mode 100644 index 00000000000000..e5b4477758f915 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_common.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_ + +// This file defines functionality shared between chlo/mhlo/lhlo dialects. + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace hlo { + +// Verifies the source target pairs attached to collective permute. 
+LogicalResult VerifyCollectivePermuteSourceTargetPairs( + Operation *op, DenseIntElementsAttr attr); + +} // namespace hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_COMMON_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index 32940cbc623262..08f25693c6edf3 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -30,6 +30,15 @@ class ConstantSplat : NativeCodeCall< class HLO_ConstantLike : NativeCodeCall< "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; +def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + +def HLO_ConstantLikePosInfValue : NativeCodeCall< + "chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; + +def HLO_ConstantLikeNegInfValue : NativeCodeCall< + "chlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/true)">; + def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< @@ -43,4 +52,12 @@ def BinBroadcastDimensionsNonEmpty : NativeCodeCall< class GetScalarOfType : NativeCodeCall< "hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; +// Constraint that Attr has values [0, 1, ...]. 
+def IdentityBroadcastDims : AttrConstraint< + CPred<"hlo::IsSequenceStartingWith0($_self)">>; + +def NonComplexElementType : Type< + CPred<"!$_self.cast<ShapedType>().getElementType().isa<ComplexType>()">, + "Non-complex element type">; + #endif // HLO_UTILS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h index 00de1170f8a123..e26bf08fbd8b8b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td index f8e02d413e9db8..280c0a1c8a3a59 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td @@ -50,7 +50,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*args=*/(ins), /*methodBody=*/[{}], /*defaultImplementation=*/[{ - /// Return whether this op can be fused withh its consumers + /// Return whether this op can be fused with its consumers return true; }] >, @@ -64,21 +64,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether two inputs have the same shape.
Operation *op = this->getOperation(); - assert(lhs < op->getNumOperands() && lhs >= 0 && - rhs < op->getNumOperands() && rhs >= 0); + assert(lhs >= 0 && rhs >= 0); if (lhs == rhs) return true; - - // if both lhs and rhs have static shapes, check them directly - Type lhs_ty = op->getOperand(lhs).getType(); - Type rhs_ty = op->getOperand(rhs).getType(); - auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); - auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); - if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || - !rhs_shape_type || !rhs_shape_type.hasStaticShape() || - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - return false; - } - return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + return inferShapeEquality(op->getOperand(lhs), op->getOperand(rhs)); }] >, InterfaceMethod< @@ -91,21 +79,9 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether two outputs have the same shape. Operation *op = this->getOperation(); - assert(lhs < op->getNumResults() && lhs >= 0 && - rhs < op->getNumResults() && rhs >= 0); + assert(lhs >= 0 && rhs >= 0); if (lhs == rhs) return true; - - // if both lhs and rhs have static shapes, check them directly - Type lhs_ty = op->getResult(lhs).getType(); - Type rhs_ty = op->getResult(rhs).getType(); - auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); - auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); - if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || - !rhs_shape_type || !rhs_shape_type.hasStaticShape() || - lhs_shape_type.getRank() != rhs_shape_type.getRank()) { - return false; - } - return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + return inferShapeEquality(op->getResult(lhs), op->getResult(rhs)); }] >, InterfaceMethod< @@ -118,20 +94,8 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { /*defaultImplementation=*/[{ /// Return whether the input and the output have the same shape. 
Operation *op = this->getOperation(); - assert(input < op->getNumOperands() && input >= 0 && - output < op->getNumResults() && output >= 0); - - // if both input and output have static shapes, check them directly - Type input_ty = op->getOperand(input).getType(); - Type output_ty = op->getResult(output).getType(); - auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>(); - auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>(); - if (!input_shape_type || !input_shape_type.hasStaticShape() || - !output_shape_type || !output_shape_type.hasStaticShape() || - input_shape_type.getRank() != output_shape_type.getRank()) { - return false; - } - return input_shape_type.getShape() == output_shape_type.getShape(); + assert(input >= 0 && output >= 0); + return inferShapeEquality(op->getOperand(input), op->getResult(output)); }] >, InterfaceMethod< @@ -156,6 +120,21 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { }] >, ]; + + let extraClassDeclaration = [{ + // Returns whether the given values have the same static shape. + static bool inferShapeEquality(Value first, Value second) { + // If both lhs and rhs have static shapes, check them directly. + auto first_ty = first.getType().dyn_cast<RankedTensorType>(); + auto second_ty = second.getType().dyn_cast<RankedTensorType>(); + if (!first_ty || !first_ty.hasStaticShape() || + !second_ty || !second_ty.hasStaticShape() || + first_ty.getRank() != second_ty.getRank()) { + return false; + } + return first_ty.getShape() == second_ty.getShape(); + } + }]; } #endif diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td new file mode 100644 index 00000000000000..7cddf4e08b1d1c --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_DIALECT +#define LHLO_DIALECT + +include "mlir/IR/OpBase.td" + +// We define the dialect here so that both structs and ops can refer to it. +def LHLO_Dialect : Dialect { + let name = "lmhlo"; + let cppNamespace = "::mlir::lmhlo"; +} + +#endif // LHLO_DIALECT diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h index effa9ecc83b82e..3214ec6efb6fc6 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h @@ -22,14 +22,15 @@ limitations under the License. 
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td index b3708bf4ff12e9..b9fe5fb09e6cb9 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -23,9 +23,9 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td" - class LHLOGPU_Op traits = []> : Op], traits)>; @@ -47,14 +47,14 @@ def I32Buffer : MemRefOf<[I32]>; def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, BASE_HLO_BatchNormGradOp { let arguments = (ins - Arg:$operand, - Arg:$scale, - Arg:$mean, - Arg:$stddev, - Arg:$grad_output, - Arg:$grad_operand, // gradient of $operand. - Arg:$grad_scale, - Arg:$grad_offset, + Arg:$operand, + Arg:$scale, + Arg:$mean, + Arg:$stddev, + Arg:$grad_output, + Arg:$grad_operand, // gradient of $operand. 
+ Arg:$grad_scale, + Arg:$grad_offset, F32Attr:$epsilon, I64Attr:$feature_index ); @@ -63,12 +63,12 @@ def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, BASE_HLO_BatchNormInferenceOp { let arguments = (ins - Arg:$operand, - Arg:$scale, - Arg:$offset, - Arg:$mean, - Arg:$stddev, - Arg:$output, + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$stddev, + Arg:$output, F32Attr:$epsilon, I64Attr:$feature_index); } @@ -77,12 +77,12 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, BASE_HLO_BatchNormTrainingOp { let arguments = (ins - Arg:$operand, - Arg:$scale, - Arg:$offset, - Arg:$output, - Arg:$batch_mean, - Arg:$batch_stddev, + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$output, + Arg:$batch_mean, + Arg:$batch_stddev, F32Attr:$epsilon, I64Attr:$feature_index ); @@ -92,33 +92,11 @@ def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, // LMHLO ops representing convolution library functions. 
//===----------------------------------------------------------------------===// -def ActivationModeNone : StrEnumAttrCase<"None">; -def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; -def ActivationModeTanh : StrEnumAttrCase<"Relu">; -def ActivationModeRelu : StrEnumAttrCase<"Relu">; -def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; -def ActivationModeReluX : StrEnumAttrCase<"ReluX">; -def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; - -def ActivationAttr : StrEnumAttr<"Activation", - "Activation applied with fused convolution", - [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, - ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, - ActivationModeBandPass]>; - -def GpuConvolutionAttributes { +class GpuConvolutionAttributes { dag attributes = !con( - ConvolutionAttributes.attributes, + ConvolutionAttributes.attributes, (ins F64Attr:$result_scale), - (ins ConvolutionBackendConfigAttr:$backend_config)); -} - -def GpuFusedConvolutionAttributes { - dag attributes = !con( - ConvolutionAttributes.attributes, - (ins F64Attr:$result_scale, - ActivationAttr:$activation_mode, - F64Attr:$side_input_scale), + extraAttribs, (ins ConvolutionBackendConfigAttr:$backend_config)); } @@ -128,8 +106,8 @@ def LHLOGPU_ConvForwardOp : LHLOGPU_Op<"conv_forward"> { Arg:$input, Arg:$filter, Arg:$output, - Arg:$scratch), - GpuConvolutionAttributes.attributes); + Arg:$scratch), + GpuConvolutionAttributes<(ins)>.attributes); } def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { @@ -138,8 +116,8 @@ def LHLOGPU_ConvBackwardInputOp : LHLOGPU_Op<"conv_backwardinput"> { Arg:$d_output, Arg:$filter, Arg:$d_input, - Arg:$scratch), - GpuConvolutionAttributes.attributes); + Arg:$scratch), + GpuConvolutionAttributes<(ins)>.attributes); } def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { @@ -148,14 +126,27 @@ def LHLOGPU_ConvBackwardFilterOp : LHLOGPU_Op<"conv_backwardfilter"> { Arg:$input, Arg:$d_output, Arg:$d_filter, - 
Arg:$scratch), - GpuConvolutionAttributes.attributes); + Arg:$scratch), + GpuConvolutionAttributes<(ins)>.attributes); +} + +// output = activation(result_scale * conv(input, filter) + bias) +def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { + let arguments = !con( + (ins + Arg:$input, + Arg:$filter, + Arg:$bias, + Arg:$output, + Arg:$scratch), + GpuConvolutionAttributes<(ins + ActivationAttr:$activation_mode)>.attributes); } // output = activation(result_scale * conv(input, filter) + // side_input * side_input_scale + // bias) -def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { +def LHLOGPU_ConvForwardFusedSideInputOp : LHLOGPU_Op<"conv_forward_fused_with_side_input"> { let arguments = !con( (ins Arg:$input, @@ -163,8 +154,10 @@ def LHLOGPU_ConvForwardFusedOp : LHLOGPU_Op<"conv_forward_fused"> { Arg:$bias, Arg:$side_input, Arg:$output, - Arg:$scratch), - GpuFusedConvolutionAttributes.attributes); + Arg:$scratch), + GpuConvolutionAttributes<(ins + ActivationAttr:$activation_mode, + F64Attr:$side_input_scale)>.attributes); } //===----------------------------------------------------------------------===// @@ -179,9 +172,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { Arg:$rhs, Arg:$output, DotDimensionNumbers:$dot_dimension_numbers, - F64Attr:$alpha, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, I64Attr:$batch_size, - I64Attr:$algorithm); + OptionalAttr:$algorithm); } // output = alpha(lhs * rhs) + beta * bias @@ -192,19 +186,20 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { Arg:$bias, Arg:$output, DotDimensionNumbers:$dot_dimension_numbers, - F64Attr:$alpha, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, F64Attr:$beta, I64Attr:$batch_size, - I64Attr:$algorithm); + OptionalAttr:$algorithm); } def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { let arguments = (ins Arg:$input, Arg:$output, - Arg:$scratch, + Arg:$scratch, Arg:$info, - BoolAttr:$is_upper); + BoolAttr:$is_lower); } #endif // LHLO_GPU_OPS diff --git 
a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h new file mode 100644 index 00000000000000..724b413885f2f1 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines enums used in the LMHLO_GPU dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_ + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. 
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_ENUMS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td new file mode 100644 index 00000000000000..15f9ed67c192e1 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_enums.td @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef LHLO_GPU_OPS_ENUMS +#define LHLO_GPU_OPS_ENUMS + +include "mlir/IR/OpBase.td" + +def ActivationModeNone : StrEnumAttrCase<"None">; +def ActivationModeSigmoid : StrEnumAttrCase<"Sigmoid">; +def ActivationModeTanh : StrEnumAttrCase<"Tanh">; +def ActivationModeRelu : StrEnumAttrCase<"Relu">; +def ActivationModeRelu6 : StrEnumAttrCase<"Relu6">; +def ActivationModeReluX : StrEnumAttrCase<"ReluX">; +def ActivationModeBandPass : StrEnumAttrCase<"BandPass">; + +def ActivationAttr : StrEnumAttr<"Activation", + "Activation applied with fused convolution", + [ActivationModeNone, ActivationModeSigmoid, ActivationModeTanh, + ActivationModeRelu, ActivationModeRelu6, ActivationModeReluX, + ActivationModeBandPass]> { + let cppNamespace = "::mlir::lmhlo_gpu"; +} + +#endif // LHLO_GPU_OPS_ENUMS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h index ff642b82c22d95..6b94d40fd3b3a4 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h @@ -1,30 +1,30 @@ /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * ==============================================================================*/ + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ // This file defines structures used in the LMHLO_GPU dialect. -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ -#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Identifier.h" -#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" // Order matters, this .inc header is not self-contained, and relies on the // #includes above. 
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.h.inc" -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_GPU_OPS_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td index 2236fc38e29b47..963834bc936d36 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_structs.td @@ -1,4 +1,3 @@ - /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,8 +21,18 @@ include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops_base.td" def ConvolutionBackendConfigAttr : StructAttr<"ConvolutionBackendConfig", LHLO_GPU_Dialect, [ StructFieldAttr<"algorithm", I64Attr>, - StructFieldAttr<"tensor_ops_enabled", BoolAttr>]> { - let description = "GPU Convolution backend configuration"; + StructFieldAttr<"tensor_ops_enabled", BoolAttr>, + // The following 3 attributes describe the layout as an array of integers + // that list the dimensions in minor-to-major order similar to XLA's layout + // representation. operand_0_layout and operand_1_layout describe the layout + // of the first 2 operands of the convolution, and result_layout describes + // the layout of the primary output operand of the convolution. + // Note: Not using names like input_layout or filter_layout as `input` may be + // an input operand (for ConvForward) but output for ConvBackward.
+ StructFieldAttr<"operand_0_layout", I64ArrayAttr>, + StructFieldAttr<"operand_1_layout", I64ArrayAttr>, + StructFieldAttr<"result_layout", I64ArrayAttr>]> { + let summary = "GPU Convolution backend configuration"; } #endif // LHLO_GPU_OPS_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 9dc6d7aa0c079d..7d32cffb7f7341 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -20,13 +20,16 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 28e51351c7e2bc..db3aa43afed47b 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -33,16 +33,14 @@ limitations under the License. 
#ifndef LHLO_OPS #define LHLO_OPS +include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td" include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td" - -def LHLO_Dialect : Dialect { - let name = "lmhlo"; - let cppNamespace = "::mlir::lmhlo"; -} +include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td" //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. @@ -85,6 +83,8 @@ def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; def LHLO_BitcastConvertOp: LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp; +def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp; + def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp; def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp; @@ -112,6 +112,8 @@ def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFinit def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp; +def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer>, BASE_HLO_LogisticOp; + def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp; def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; @@ -197,10 +199,11 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO //===----------------------------------------------------------------------===// // TODO(b/139813999): specify required function signature in a type-safe way. 
-def LHLO_ReduceOp: LHLO_Op<"reduce", [ - SameVariadicOperandSize, - SingleBlockImplicitTerminator<"TerminatorOp"> - ]>, BASE_HLO_ReduceOp { +// +// The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are +// moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp. +// TODO(timshen): cleanup lmhlo.TerminatorOp. +def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp { let arguments = (ins Arg, "", [MemRead]>:$operands, Arg, "", [MemRead]>:$init_values, @@ -211,9 +214,7 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [ let regions = (region SizedRegion<1>:$body); } -def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ - SingleBlockImplicitTerminator<"TerminatorOp"> - ]>, BASE_HLO_ReduceWindowOp { +def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", []>, BASE_HLO_ReduceWindowOp { let arguments = (ins Arg:$operand, @@ -232,46 +233,36 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ let regions = (region SizedRegion<1>:$body); } -// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, -// A tuple-like pattern match syntax could work: -// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { -// ... -// }, { -// ... -// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> () +// TODO(timshen): Add a custom syntax for this. def LHLO_CaseOp: LHLO_Op<"case", [ - AttrSizedOperandSegments, SingleBlockImplicitTerminator<"TerminatorOp"> ]>, BASE_HLO_CaseOp { - let arguments = (ins - Arg:$index, - Arg, "", [MemRead]>:$branch_operands, - Arg, "", [MemWrite]>:$out - ); + let arguments = (ins Arg:$index); let regions = (region VariadicRegion>:$branches); } // TODO(timshen): Add a custom syntax for this. 
-def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, - BASE_HLO_WhileOp { +def LHLO_WhileOp: LHLO_Op<"while", []>, BASE_HLO_WhileOp { let arguments = (ins - Arg, "", [MemRead]>:$val, - Arg, "", [MemWrite]>:$output - ); + Arg, "", [MemWrite]>:$cond_val, + OptionalAttr:$trip_count); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); } -def LHLO_CustomCallOp : LHLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp { +def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>, + BASE_HLO_CustomCallOp { let arguments = (ins Arg, "", [MemRead]>:$args, - Arg:$output, + Arg, "", [MemWrite]>:$output, StrAttr:$call_target_name, DefaultValuedAttr:$has_side_effect, - DefaultValuedAttr:$backend_config + DefaultValuedAttr:$backend_config, + OptionalAttr:$target_arg_mapping ); + let verifier = [{ return Verify(*this); }]; } //===----------------------------------------------------------------------===// @@ -284,7 +275,8 @@ def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { Arg:$rhs, Arg:$out, OptionalAttr:$broadcast_dimensions, - HLO_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction, + OptionalAttr:$compare_type ); } @@ -304,176 +296,25 @@ def LHLO_SliceOp: LHLO_Op< ); } -def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { +def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice", + [AllElementTypesMatch<["operand", "output"]>]>, BASE_HLO_DynamicSliceOp { let arguments = (ins Arg:$operand, - Arg:$update, + Arg, "", [MemRead]>:$start_indices, Arg:$output, - Arg, "", [MemRead]>:$start_indices - ); -} - -//===----------------------------------------------------------------------===// -// StaticMemRefCastOp -//===----------------------------------------------------------------------===// - -def HLO_StaticMemRefCastOp: Op]> { - let summary = [{ - modifies the offset, sizes and strides of a statically shaped memref - }]; - let description = [{ - Casts the statically shaped memref 
operand to a memref with optionally - modified offsets, sizes and strides. - - Example: - ```mlir - %buf_transformed = - lmhlo.static_memref_cast %buf - : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]> - - // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and - // offset 2. - ``` - }]; - - let arguments = (ins Arg:$operand); - let results = (outs Res:$result); - - let builders = [OpBuilder<"MemRefType resultType, Value operand", - [{ - $_state.addOperands(operand); - $_state.types.push_back(resultType); - }]>]; - - let extraClassDeclaration = [{ - MemRefType getType() { return getResult().getType().cast(); } - }]; - - let verifier = [{ return Verify(*this); }]; - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `->` type($result) - }]; -} - -//===----------------------------------------------------------------------===// -// DynamicMemRefCastOp -//===----------------------------------------------------------------------===// - -def HLO_DynamicMemRefCastOp: Op]> { - let summary = "dynamic memref cast operation"; - let description = [{ - Change sizes and strides of a memref using the values computed in runtime. - - Example: - ```mlir - %buf_transformed = - lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y] - : memref -> memref - // The result of the op is a type-erased memref with `[%size_X, %size_Y]` - // shape and `[%step_X, %step_Y]` strides. The offset will be inherited - // from the input. 
- ``` - }]; - - let arguments = (ins - Arg:$operand, - Variadic:$sizes, - Variadic:$strides + I64ElementsAttr:$slice_sizes ); - let results = (outs Res:$result); - - let builders = [ - OpBuilder<"MemRefType resultType, Value operand, ValueRange sizes, " - "ValueRange strides", [{ - $_state.addOperands(operand); - $_state.addOperands(sizes); - $_state.addOperands(strides); - $_state.types.push_back(resultType); - }]>]; - - let extraClassDeclaration = [{ - MemRefType getType() { return getResult().getType().cast(); } - }]; - - let verifier = [{ return Verify(*this); }]; - let assemblyFormat = [{ - $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->` - type($result) - }]; } -//===----------------------------------------------------------------------===// -// ReshapeMemRefCastOp -//===----------------------------------------------------------------------===// - -def ReshapeMemRefCastOp: Op, - NoSideEffect]> { - let summary = "reshape memref cast operation"; - let description = [{ - The `reshape_memref_cast` operation converts a memref from one type to an - equivalent type with a provided shape. The data is never copied or moved. - The source and destination types are compatible if both have the same - element type, address space and identity layout map. The following - combinations are possible: - - a. Both are ranked memref types. - - ```mlir - // Reshape statically-shaped memref. - %dst = reshape_memref_cast %src(%shape) - : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32> - %dst0 = reshape_memref_cast %src(%shape0) - : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32> - ``` - - b. Source type is ranked, destination type is unranked. - - ```mlir - // Reshape dynamically-shaped 1D memref. - %dst = reshape_memref_cast %src(%shape) - : (memref, memref) to memref<*xf32> - ``` - - c. Source type is unranked, destination type is ranked. - - ```mlir - // Flatten unranked memref. 
- %dst = reshape_memref_cast %src(%shape) - : (memref<*xf32>, memref<1xi32>) to memref - ``` - - d. Both are unranked memref types. - - ```mlir - // Reshape unranked memref. - %dst = reshape_memref_cast %src(%shape) - : (memref<*xf32>, memref) to memref<*xf32> - ``` - }]; - +def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []>, BASE_HLO_DynamicUpdateSliceOp { let arguments = (ins - AnyRankedOrUnrankedMemRef:$operand, - LHLO_ExtentBuffer:$shape + Arg:$operand, + Arg:$update, + Arg, "", [MemRead]>:$start_indices, + Arg:$output ); - let results = (outs AnyRankedOrUnrankedMemRef:$result); - - let extraClassDeclaration = [{ - BaseMemRefType getType() { - return getResult().getType().cast(); } - }]; - - let verifier = [{ return Verify(*this); }]; - let assemblyFormat = [{ - $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape) - `)` `->` type($result) - }]; } - //===----------------------------------------------------------------------===// // LMHLO Other op definitions. //===----------------------------------------------------------------------===// @@ -526,12 +367,6 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>, ); } -// TODO(timshen): add a custom verifier. 
-def LHLO_BitcastOp: LHLO_Op<"bitcast", []> { - let arguments = (ins Arg:$input, - Arg:$output); -} - def LHLO_BroadcastOp : LHLO_Op<"broadcast", []>, BASE_HLO_BroadcastOp { let arguments = (ins @@ -573,7 +408,7 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { Arg:$lhs, Arg:$rhs, Arg:$output), - ConvolutionAttributes.attributes); + ConvolutionAttributes.attributes); } def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { @@ -690,19 +525,44 @@ def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>, ); } -def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>, - BASE_HLO_AllReduceOp { - let arguments = (ins - Arg:$operand, - Arg:$output, +// Common base class for AllReduce, AllGather, and AllToAll. +class LHLO_CollectiveCommunicationOp traits = []> : + LHLO_Op { + dag arguments_base = (ins + Arg, "", [MemRead]>:$operands, + Arg, "", [MemWrite]>:$results, I64ElementsAttr:$replica_groups, DefaultValuedAttr:$constrain_layout, OptionalAttr:$channel_id, DefaultValuedAttr:$use_global_device_ids ); + let verifier = [{ return Verify(*this); }]; + let extraClassDeclaration = [{ + // AllGather is cross replica if channel_id is not set. 
+ bool IsCrossReplica() { return !channel_id().hasValue(); } + }]; +} + +def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">, + BASE_HLO_AllGatherOp { + let arguments = !con( + arguments_base, + (ins I64Attr:$all_gather_dimension)); +} + +def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce">, + BASE_HLO_AllReduceOp { + let arguments = arguments_base; let regions = (region SizedRegion<1>:$computation); } +def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all">, + BASE_HLO_AllToAllOp { + let arguments = !con( + arguments_base, + (ins OptionalAttr:$split_dimension)); +} + def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, BASE_HLO_CollectivePermuteOp { @@ -712,6 +572,7 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, I64ElementsAttr:$source_target_pairs, OptionalAttr:$channel_id ); + let verifier = [{ return Verify(*this); }]; } def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp { @@ -731,16 +592,16 @@ def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_Ch ); } -def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp { +def LHLO_InfeedOp: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp { let arguments = (ins - Arg:$output, + Arg, "", [MemWrite]>:$outputs, DefaultValuedAttr:$config ); } -def LHLO_Outfeed: LHLO_Op<"outfeed", []> { +def LHLO_OutfeedOp: LHLO_Op<"outfeed", []> { let arguments = (ins - Arg:$operand, + Arg, "", [MemRead]>:$operands, DefaultValuedAttr:$config ); } @@ -749,6 +610,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp { let arguments = (ins Arg, "", [MemWrite]>); } +def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp { + let arguments = (ins Arg, "", [MemWrite]>); +} + def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>, BASE_HLO_TriangularSolveOp { let arguments = (ins @@ -758,7 +623,10 @@ def LHLO_TriangularSolveOp: 
LHLO_Op<"triangular_solve", [SameOperandsElementType BoolAttr:$left_side, BoolAttr:$lower, BoolAttr:$unit_diagonal, - HLO_TransposeAttr:$transpose_a + HLO_TransposeAttr:$transpose_a, + HLO_LayoutAttr:$layout_a, + HLO_LayoutAttr:$layout_b, + HLO_LayoutAttr:$layout_output ); } @@ -812,8 +680,46 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let skipDefaultBuilders = 1; let builders = [ - OpBuilder<"ArrayRef attributes"> + OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)> ]; + + let extraClassDeclaration = [{ + SmallVector getInputBuffers() { + SmallVector buffers; + this->region().walk([&](memref::TensorLoadOp load) { + if (load.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(load.memref()); + }); + return buffers; + } + + SmallVector getOutputBuffers() { + SmallVector buffers; + this->region().walk([&](memref::TensorStoreOp store) { + if (store.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(store.memref()); + }); + return buffers; + } + + SmallVector getFusionParameters() { + SmallVector buffers; + this->region().walk([&](memref::TensorLoadOp load) { + if (load.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(load); + }); + return buffers; + } + + SmallVector getFusionResults() { + SmallVector buffers; + this->region().walk([&](memref::TensorStoreOp store) { + if (store.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(store.tensor()); + }); + return buffers; + } + }]; } def TerminatorOp : @@ -822,9 +728,9 @@ def TerminatorOp : let description = [{ Terminator operation for the LHLO dialect. 
}]; - let builders = [OpBuilder<"ValueRange operands", - [{ build($_builder, $_state, llvm::None, operands, llvm::None); }] - >]; + let builders = [ + OpBuilder<(ins "ValueRange":$operands), + [{ build($_builder, $_state, llvm::None, operands, llvm::None); }]>]; } #endif // LHLO_OPS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td index 9cd77417ffd3e9..ba158d92054c6a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_base.td @@ -16,6 +16,7 @@ limitations under the License. #ifndef LHLO_OPS_BASE #define LHLO_OPS_BASE +include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/IR/OpBase.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" @@ -40,8 +41,6 @@ def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; -def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; - -def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; +def LHLO_Buffer : MemRefOf<[AnyFloat, AnyInteger, AnyComplex]>; #endif // LHLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h new file mode 100644 index 00000000000000..8b14843dbadcd8 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines structures used in LMHLO dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_ + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/Types.h" + +// Order matters, this .inc header is not self-contained, and relies on the +// #includes above. +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_STRUCTS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td new file mode 100644 index 00000000000000..d9ae1ca67ce7f8 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.td @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef LHLO_OPS_STRUCTS +#define LHLO_OPS_STRUCTS + +include "mlir-hlo/Dialect/mhlo/IR/lhlo_dialect.td" + +// This structure defines information about how arguments to the LHLO custom +// call operation relate to the arguments of the target function. In most cases +// the mapping will be 1:1, but in certain cases, it may not be. As an example, +// tokens are not represented in the LHLO dialect, but the custom call target +// might still expect to see buffer arguments corresponding to tokens, in which +// case the mapping will not be 1:1. +def CustomCallTargetArgMapping : StructAttr<"CustomCallTargetArgMapping", + LHLO_Dialect, [ + // number of buffers expected by the target for arguments. + StructFieldAttr<"num_args", I64Attr>, + // number of buffers expected by the target for results. + StructFieldAttr<"num_results", I64Attr>, + // map each custom call op arg to its position in target args. + StructFieldAttr<"args_to_target_args", I64ArrayAttr>, + // map each custom call op result to its position in target results. 
+ StructFieldAttr<"results_to_target_results", I64ArrayAttr>]> { + let summary = "Custom call operands to target argument mapping info"; +} + +#endif // LHLO_OPS_STRUCTS diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index 39b4ca650431f5..17c052472cc9ad 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> { } -def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> { - let summary = "Legalize from LHLO dialect to LLVM."; - let constructor = "createTestLhloToLLVMPass()"; -} - - def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { let summary = "Legalize from LHLO dialect to parallel loops."; let constructor = "createLegalizeLhloToParallelLoopsPass()"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h new file mode 100644 index 00000000000000..d9e637dfc6367e --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_CHLO_TO_MHLO_OP_H_ + +#include + +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace chlo { + +struct HloComplexAdaptor { + static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; +template +struct HloBinaryElementwiseAdaptor { + static ToOpTy CreateOp(FromOpTy from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; +struct HloCompareAdaptor { + static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create( + from_op.getLoc(), result_type, broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction(), from_op.compare_typeAttr()); + } +}; + +// Populate a pattern for each Broadcasting CHlo op. This requires the pattern +// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values. +template