diff --git a/.bazelrc b/.bazelrc
index 09e2ef5..f93742d 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -1,2 +1,11 @@
 # This flag is required for CUDA repo that @org_tensorflow depends on.
 common --experimental_repo_remote_exec
+
+build:manylinux_2_17_x86_64 --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
+build:manylinux_2_17_x86_64 --crosstool_top="@local_config_cuda//crosstool:toolchain"
+build:manylinux_2_17_x86_64 --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64"
+build:manylinux_2_17_x86_64 --repo_env=CC="/usr/lib/llvm-18/bin/clang"
+build:manylinux_2_17_x86_64 --repo_env=TF_SYSROOT="/dt9"
+build:manylinux_2_17_x86_64 --extra_execution_platforms="@sigbuild-r2.17-clang_config_platform//:platform"
+build:manylinux_2_17_x86_64 --host_platform="@sigbuild-r2.17-clang_config_platform//:platform"
+build:manylinux_2_17_x86_64 --platforms="@sigbuild-r2.17-clang_config_platform//:platform"
diff --git a/.bazelversion b/.bazelversion
new file mode 100644
index 0000000..ce1c163
--- /dev/null
+++ b/.bazelversion
@@ -0,0 +1,2 @@
+6.5.0
+# https://github.com/tensorflow/tensorflow/blob/master/.bazelversion
diff --git a/BUILD b/BUILD
index b5a09d4..f314455 100644
--- a/BUILD
+++ b/BUILD
@@ -36,29 +36,3 @@ py_library(
         "//tensorflow_compression/python/util:packed_tensors",
     ],
 )
-
-py_binary(
-    name = "build_api_docs",
-    srcs = ["tools/build_api_docs.py"],
-    deps = [":tensorflow_compression"],
-)
-
-py_binary(
-    name = "build_pip_pkg",
-    srcs = ["build_pip_pkg.py"],
-    data = [
-        "LICENSE",
-        "README.md",
-        "MANIFEST.in",
-        "requirements.txt",
-        "tensorflow_compression/all_tests.py",
-        ":tensorflow_compression",
-        # The following targets are for Python unit tests.
-        "//tensorflow_compression/python/datasets:py_src",
-        "//tensorflow_compression/python/distributions:py_src",
-        "//tensorflow_compression/python/entropy_models:py_src",
-        "//tensorflow_compression/python/layers:py_src",
-        "//tensorflow_compression/python/ops:py_src",
-        "//tensorflow_compression/python/util:py_src",
-    ],
-)
diff --git a/README.md b/README.md
index d46a5c7..192389d 100644
--- a/README.md
+++ b/README.md
@@ -13,8 +13,8 @@ started.
 
 For a more in-depth introduction from a classical data compression
 perspective, consider our [paper on nonlinear transform
-coding](https://arxiv.org/abs/2007.03034), or watch @jonycgn's [talk on learned
-image compression](https://www.youtube.com/watch?v=x_q7cZviXkY). For an
+coding](https://arxiv.org/abs/2007.03034), or watch @jonarchist's [talk on
+learned image compression](https://www.youtube.com/watch?v=x_q7cZviXkY). For an
 introduction to lossy data compression from a machine learning perspective,
 take a look at @yiboyang's [review paper](https://arxiv.org/abs/2202.06533).
@@ -40,6 +40,26 @@ docs](https://www.tensorflow.org/api_docs/python/tfc) for details):
 reparameterizing kernels and biases in the Fourier domain, and an
 implementation of generalized divisive normalization (GDN).
 
+**Important update:** As of February 1, 2024, TensorFlow Compression is in
+maintenance mode. This means concretely:
+
+- The full feature set of TFC is frozen. No new features will be developed, but
+  the repository will receive maintenance fixes.
+
+- Going forward, new TFC packages will only work with TensorFlow 2.14. This is
+  due to an incompatibility introduced in the Keras version shipped with TF
+  2.15, which would require a rewrite of our layer and entropy model classes.
+
+- To ensure existing models can still be run with TF 2.15 and later, we are
+  releasing a new package
+  [tensorflow-compression-ops](https://github.com/tensorflow/compression/tree/master/tensorflow_compression_ops),
+  which only contains the C++ ops. These will be updated for
+  newer TF versions as long as possible.
+
+- Binary packages are provided for both options on pypi.org:
+  [TFC](https://pypi.org/project/tensorflow-compression/) and
+  [TFC ops](https://pypi.org/project/tensorflow-compression-ops/).
+
 ## Documentation & getting help
@@ -75,7 +95,7 @@ releases](https://github.com/tensorflow/compression/releases).
 To install TFC via `pip`, run the following command:
 
 ```bash
-pip install tensorflow-compression
+python -m pip install tensorflow-compression
 ```
 
 To test that the installation works correctly, you can run the unit tests with:
@@ -95,7 +115,7 @@ installed TensorFlow version.
 Run it in a cell before executing your Python code:
 
 ```
-!pip install tensorflow-compression~=$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+\.\d+).*/\1.0/sg')
+%pip install tensorflow-compression~=$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+\.\d+).*/\1.0/sg')
 ```
 
 Note: The binary packages of TFC are tied to TF with the same minor version
@@ -114,7 +134,7 @@ host. For instance, you can use a command line like this:
 
 ```bash
 docker run tensorflow/tensorflow:latest bash -c \
-  "pip install tensorflow-compression &&
+  "python -m pip install tensorflow-compression &&
   python -m tensorflow_compression.all_tests"
 ```
@@ -132,7 +152,7 @@ installs TensorFlow and TensorFlow Compression:
 
 ```bash
 conda create --name ENV_NAME python cudatoolkit cudnn
 conda activate ENV_NAME
-pip install tensorflow-compression
+python -m pip install tensorflow-compression
 ```
 
 Depending on the requirements of the `tensorflow` pip package, you may need to
@@ -257,22 +277,19 @@ that TensorFlow uses.
 Inside the Docker container, the following steps need to be taken:
 
 1. Clone the `tensorflow/compression` repo from GitHub.
-2. Install Python dependencies.
-3. Run `:build_pip_pkg` inside the cloned repo.
+2. Run `tools/build_pip_pkg.sh` inside the cloned repo.
 
 For example:
 
 ```bash
-sudo docker run -i --rm -v /tmp/tensorflow_compression:/tmp/tensorflow_compression \
-    tensorflow/build:latest-python3.10 bash -c \
-    "git clone https://github.com/tensorflow/compression.git /tensorflow_compression &&
-     cd /tensorflow_compression &&
-     python -m pip install -U pip setuptools wheel &&
-     python -m pip install -r requirements.txt &&
-     bazel build -c opt --copt=-mavx --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain :build_pip_pkg &&
-     pushd bazel-bin/build_pip_pkg.runfiles/tensorflow_compression &&
-     python build_pip_pkg.py . /tmp/tensorflow_compression &&
-     popd"
+git clone https://github.com/tensorflow/compression.git /tensorflow_compression
+docker run -i --rm \
+  -v /tmp/tensorflow_compression:/tmp/tensorflow_compression \
+  -v /tensorflow_compression:/tensorflow_compression \
+  -w /tensorflow_compression \
+  -e "BAZEL_OPT=--config=manylinux_2_17_x86_64" \
+  tensorflow/build:latest-python3.10 \
+  bash tools/build_pip_pkg.sh /tmp/tensorflow_compression
 ```
 
 For Darwin, the Docker image and specifying the toolchain are not necessary.
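(In both the Linux command above and the Darwin command below, extra Bazel flags are threaded through the `BAZEL_OPT` environment variable, which `tools/build_pip_pkg.sh` presumably forwards to `bazel build`; for the Linux case this should amount to something like `bazel build -c opt --config=manylinux_2_17_x86_64 <targets>`, using the config group defined in `.bazelrc` above. The script and its exact targets are not part of this diff, so treat that expansion as an assumption.)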
We @@ -282,12 +299,9 @@ Python virtual environment to do this): ```bash git clone https://github.com/tensorflow/compression.git /tensorflow_compression cd /tensorflow_compression -python -m pip install -U pip setuptools wheel -python -m pip install -r requirements.txt -bazel build -c opt --copt=-mavx --macos_minimum_os=10.14 :build_pip_pkg -pushd bazel-bin/build_pip_pkg.runfiles/tensorflow_compression -python build_pip_pkg.py . /tmp/tensorflow_compression " -popd +BAZEL_OPT="--macos_minimum_os=10.14" bash \ + tools/build_pip_pkg.sh \ + /tmp/tensorflow_compression ``` In both cases, the wheel file is created inside `/tmp/tensorflow_compression`. @@ -295,7 +309,7 @@ In both cases, the wheel file is created inside `/tmp/tensorflow_compression`. To test the created package, first install the resulting wheel file: ```bash -pip install /tmp/tensorflow_compression/tensorflow_compression-*.whl +python -m pip install /tmp/tensorflow_compression/tensorflow_compression-*.whl ``` Then run the unit tests (Do not run the tests in the workspace directory where @@ -312,7 +326,7 @@ popd When done, you can uninstall the pip package again: ```bash -pip uninstall tensorflow-compression +python -m pip uninstall tensorflow-compression ``` ## Evaluation @@ -327,11 +341,11 @@ for more information. If you use this library for research purposes, please cite: ``` @software{tfc_github, - author = "Ballé, Johannes and Hwang, Sung Jin and Agustsson, Eirikur", + author = "Ballé, Jona and Hwang, Sung Jin and Agustsson, Eirikur", title = "{T}ensor{F}low {C}ompression: Learned Data Compression", url = "http://github.com/tensorflow/compression", - version = "2.14.0", - year = "2023", + version = "2.14.1", + year = "2024", } ``` In the above BibTeX entry, names are top contributors sorted by number of diff --git a/WORKSPACE b/WORKSPACE index fecbff3..af71ba9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -5,59 +5,36 @@ tensorflow_compression_workspace() load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -# `bazel_skylib` and `rules_python` versions should match the ones used in -# `org_tensorflow`. -http_archive( - name = "bazel_skylib", - sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - ], -) - -http_archive( - name = "rules_python", - sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", - strip_prefix = "rules_python-0.26.0", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", -) - http_archive( name = "org_tensorflow", - sha256 = "ce357fd0728f0d1b0831d1653f475591662ec5bca736a94ff789e6b1944df19f", - strip_prefix = "tensorflow-2.14.0", + sha256 = "d7876f4bb0235cac60eb6316392a7c48676729860da1ab659fb440379ad5186d", + strip_prefix = "tensorflow-2.18.0", urls = [ - "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.14.0.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.18.0.tar.gz", ], ) # Copied from `@org_tensorflow//:WORKSPACE`. 
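# (Context for the surrounding changes, inferred from this PR rather than
# verified: the pinned bazel_skylib and rules_python archives above were
# dropped because, with TF 2.18, the tf_workspace3() / python_init_*()
# sequence below is expected to pull in compatible versions by itself.)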
-load( - "@rules_python//python:repositories.bzl", - "py_repositories", - "python_register_toolchains", -) -py_repositories() - -load( - "@org_tensorflow//tensorflow/tools/toolchains/python:python_repo.bzl", - "python_repository", -) -python_repository(name = "python_version_repo") -load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") +load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") +tf_workspace3() -# TF workspace scripts below requires `@python` toolchains repo. -# Toolchain setup here is to please the TF workspace scripts, -# and we do not use this Python version to build pip packages. -python_register_toolchains( - name = "python", - ignore_root_user_error = True, - python_version = HERMETIC_PYTHON_VERSION, +# Initialize hermetic Python +load("@local_tsl//third_party/py:python_init_rules.bzl", "python_init_rules") +python_init_rules() + +load("@local_tsl//third_party/py:python_init_repositories.bzl", "python_init_repositories") +python_init_repositories( + default_python_version = "system", + requirements = { + "3.9": "@org_tensorflow//:requirements_lock_3_9.txt", + "3.10": "@org_tensorflow//:requirements_lock_3_10.txt", + "3.11": "@org_tensorflow//:requirements_lock_3_11.txt", + "3.12": "@org_tensorflow//:requirements_lock_3_12.txt", + }, ) -load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3") -tf_workspace3() +load("@local_tsl//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") +python_init_toolchains() load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2") tf_workspace2() @@ -67,3 +44,9 @@ tf_workspace1() load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") tf_workspace0() + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) +cuda_configure(name = "local_config_cuda") diff --git a/models/bls2017.py b/models/bls2017.py index 49839bf..715059f 100644 --- a/models/bls2017.py +++ b/models/bls2017.py @@ -107,7 +107,7 @@ def call(self, x, training): """Computes rate and distortion losses.""" entropy_model = tfc.ContinuousBatchedEntropyModel( self.prior, coding_rank=3, compression=False) - x = tf.cast(x, self.compute_dtype) # TODO(jonycgn): Why is this necessary? + x = tf.cast(x, self.compute_dtype) # TODO(jonarchist): Why is this necessary? y = self.analysis_transform(x) y_hat, bits = entropy_model(y, training=training) x_hat = self.synthesis_transform(y_hat) diff --git a/models/bmshj2018.py b/models/bmshj2018.py index c0b18ab..5539244 100644 --- a/models/bmshj2018.py +++ b/models/bmshj2018.py @@ -162,7 +162,7 @@ def call(self, x, training): side_entropy_model = tfc.ContinuousBatchedEntropyModel( self.hyperprior, coding_rank=3, compression=False) - x = tf.cast(x, self.compute_dtype) # TODO(jonycgn): Why is this necessary? + x = tf.cast(x, self.compute_dtype) # TODO(jonarchist): Why is this necessary? y = self.analysis_transform(x) z = self.hyper_analysis_transform(abs(y)) z_hat, side_bits = side_entropy_model(z, training=training) diff --git a/models/ms2020.py b/models/ms2020.py index 17f2009..cccb439 100644 --- a/models/ms2020.py +++ b/models/ms2020.py @@ -199,7 +199,7 @@ def __init__(self, lmbda, def call(self, x, training): """Computes rate and distortion losses.""" - x = tf.cast(x, self.compute_dtype) # TODO(jonycgn): Why is this necessary? + x = tf.cast(x, self.compute_dtype) # TODO(jonarchist): Why is this necessary? # Build the encoder (analysis) half of the hierarchical autoencoder. 
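    # (Re the TODO above: a plausible answer is that under a Keras
    # mixed-precision policy the model's compute_dtype can be float16 while
    # the data pipeline yields float32, and nothing upstream casts the input
    # automatically; this is an assumption, not a verified explanation.)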
y = self.analysis_transform(x) y_shape = tf.shape(y)[1:-1] diff --git a/requirements.txt b/requirements.txt index df05d35..f5ae19e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -scipy ~= 1.4 +scipy ~= 1.11.0 tensorflow ~= 2.14.0 -tensorflow-probability ~= 0.15 +tensorflow-probability >= 0.15, < 0.23 diff --git a/tensorflow_compression/cc/BUILD b/tensorflow_compression/cc/BUILD index 7f422d6..9b08f29 100644 --- a/tensorflow_compression/cc/BUILD +++ b/tensorflow_compression/cc/BUILD @@ -37,7 +37,7 @@ cc_library( srcs = ["lib/bit_coder.cc"], hdrs = ["lib/bit_coder.h"], deps = [ - "@com_google_absl//absl/base:endian", + "@com_google_absl//absl/base:config", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow_compression/cc/kernels/range_coder_kernels.cc b/tensorflow_compression/cc/kernels/range_coder_kernels.cc index ae297d6..fae3d80 100644 --- a/tensorflow_compression/cc/kernels/range_coder_kernels.cc +++ b/tensorflow_compression/cc/kernels/range_coder_kernels.cc @@ -104,7 +104,7 @@ Status CheckInRange(absl::string_view name, int64_t value, int64_t min, return errors::InvalidArgument( absl::Substitute("$0=$1 not in range [$2, $3)", name, value, min, max)); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status ScanCDF(const int32_t* const end, const int32_t** current, @@ -133,7 +133,7 @@ Status ScanCDF(const int32_t* const end, const int32_t** current, ++p; } *current = p; - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status IndexCDFVector(const TTypes::ConstFlat& table, @@ -144,7 +144,7 @@ Status IndexCDFVector(const TTypes::ConstFlat& table, for (const int32_t* current = start; current != end;) { TF_RETURN_IF_ERROR(ScanCDF(end, ¤t, lookup)); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status IndexCDFMatrix(const TTypes::ConstMatrix& table, @@ -160,12 +160,12 @@ Status IndexCDFMatrix(const TTypes::ConstMatrix& table, return errors::InvalidArgument("CDF must end with 1 << precision."); } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } class RangeEncoderInterface : public EntropyEncoderInterface { public: - static tensorflow::StatusOr> Make( + static absl::StatusOr> Make( tensorflow::OpKernelContext* context, const TensorShape& handle_shape) { std::shared_ptr p(new RangeEncoderInterface); @@ -199,7 +199,7 @@ class RangeEncoderInterface : public EntropyEncoderInterface { // when index tensor was not provided. 
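  // The mutex/status pair below collects the first error raised while
  // encoding: REQUIRE_IN_RANGE (defined below) records a CheckInRange
  // failure under the lock so it can be reported once at the end, which
  // keeps the hot loop free of per-element error plumbing. (Description
  // inferred from this hunk.)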
tensorflow::mutex mu; - tensorflow::Status status ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + absl::Status status ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); #define REQUIRE_IN_RANGE(name, value, min, max) \ if (auto s = CheckInRange(name, value, min, max); \ @@ -283,7 +283,7 @@ class RangeEncoderInterface : public EntropyEncoderInterface { encoder_[i].Finalize(&encoded_[i]); output(i) = std::move(encoded_[i]); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } private: @@ -331,7 +331,7 @@ class RangeEncoderInterface : public EntropyEncoderInterface { class RangeDecoderInterface : public EntropyDecoderInterface { public: - static StatusOr> Make( + static absl::StatusOr> Make( tensorflow::OpKernelContext* context) { std::shared_ptr p(new RangeDecoderInterface); @@ -364,7 +364,7 @@ class RangeDecoderInterface : public EntropyDecoderInterface { CHECK_EQ(decoder_.size(), output.dimension(0)); tensorflow::mutex mu; - tensorflow::Status status ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); + absl::Status status ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu); #define REQUIRE_IN_RANGE(name, value, min, max) \ if (auto s = CheckInRange(name, value, min, max); \ @@ -442,7 +442,7 @@ class RangeDecoderInterface : public EntropyDecoderInterface { VLOG(0) << "RangeDecoder #" << i << " final status was an error"; } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } private: diff --git a/tensorflow_compression/cc/kernels/range_coding_kernels.cc b/tensorflow_compression/cc/kernels/range_coding_kernels.cc index 58563a1..f2487e2 100644 --- a/tensorflow_compression/cc/kernels/range_coding_kernels.cc +++ b/tensorflow_compression/cc/kernels/range_coding_kernels.cc @@ -144,11 +144,11 @@ Status CheckCdfShape(const TensorShape& data_shape, "The last dimension of `cdf` should be > 1: ", cdf_shape.DebugString()); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } -tensorflow::Status CheckCdfValues(int precision, - const tensorflow::Tensor& cdf_tensor) { +absl::Status CheckCdfValues(int precision, + const tensorflow::Tensor& cdf_tensor) { const auto cdf = cdf_tensor.flat_inner_dims(); const auto size = cdf.dimension(1); if (size <= 2) { @@ -169,7 +169,7 @@ tensorflow::Status CheckCdfValues(int precision, } } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } class RangeEncodeOp : public OpKernel { @@ -230,11 +230,11 @@ class RangeEncodeOp : public OpKernel { private: template - tensorflow::Status RangeEncodeImpl(TTypes::ConstFlat data, - absl::Span data_shape, - TTypes::ConstMatrix cdf, - absl::Span cdf_shape, - std::string* output) const { + absl::Status RangeEncodeImpl(TTypes::ConstFlat data, + absl::Span data_shape, + TTypes::ConstMatrix cdf, + absl::Span cdf_shape, + std::string* output) const { const int64_t data_size = data.size(); const int64_t cdf_size = cdf.size(); const int64_t chip_size = cdf.dimension(1); @@ -265,7 +265,7 @@ class RangeEncodeOp : public OpKernel { } encoder.Finalize(output); - return tensorflow::OkStatus(); + return absl::OkStatus(); } int precision_; @@ -343,11 +343,11 @@ class RangeDecodeOp : public OpKernel { private: template - tensorflow::Status RangeDecodeImpl(TTypes::Flat output, - absl::Span output_shape, - TTypes::ConstMatrix cdf, - absl::Span cdf_shape, - const tstring& encoded) const { + absl::Status RangeDecodeImpl(TTypes::Flat output, + absl::Span output_shape, + TTypes::ConstMatrix cdf, + absl::Span cdf_shape, + const tstring& encoded) const { BroadcastRange view{output.data(), output_shape, cdf.data(), cdf_shape}; @@ -369,7 +369,7 @@ class RangeDecodeOp : public 
OpKernel { *data = decoder.Decode({cdf_slice, chip_size}, precision_); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } int precision_; diff --git a/tensorflow_compression/cc/kernels/range_coding_kernels_test.cc b/tensorflow_compression/cc/kernels/range_coding_kernels_test.cc index f3ee9f8..823fd2c 100644 --- a/tensorflow_compression/cc/kernels/range_coding_kernels_test.cc +++ b/tensorflow_compression/cc/kernels/range_coding_kernels_test.cc @@ -130,7 +130,7 @@ class RangeCoderOpsTest : public OpsTestBase { *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status RunDecodeOp(int precision, absl::Span input, @@ -166,7 +166,7 @@ class RangeCoderOpsTest : public OpsTestBase { *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } void TestEncodeAndDecode(int precision, const Tensor& data, diff --git a/tensorflow_compression/cc/kernels/range_coding_kernels_util.cc b/tensorflow_compression/cc/kernels/range_coding_kernels_util.cc index 1440816..fa2f2d2 100644 --- a/tensorflow_compression/cc/kernels/range_coding_kernels_util.cc +++ b/tensorflow_compression/cc/kernels/range_coding_kernels_util.cc @@ -87,7 +87,7 @@ Status MergeAxes(const TensorShape& broadcast_shape, } merged_storage_shape.push_back(storage_stride); - return tensorflow::OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow_compression diff --git a/tensorflow_compression/cc/kernels/run_length_gamma_kernels.cc b/tensorflow_compression/cc/kernels/run_length_gamma_kernels.cc index fa41727..588d3db 100644 --- a/tensorflow_compression/cc/kernels/run_length_gamma_kernels.cc +++ b/tensorflow_compression/cc/kernels/run_length_gamma_kernels.cc @@ -24,8 +24,10 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" @@ -46,13 +48,6 @@ using tensorflow::TensorShape; using tensorflow::TensorShapeUtils; using tensorflow::tstring; -#define OP_REQUIRES_OK_ABSL(context, status) \ - { \ - auto s = (status); \ - OP_REQUIRES(context, s.ok(), tensorflow::Status( \ - static_cast(s.code()), s.message())); \ - } - class RunLengthGammaEncodeOp : public OpKernel { public: explicit RunLengthGammaEncodeOp(OpKernelConstruction* context) @@ -140,7 +135,7 @@ class RunLengthGammaDecodeOp : public OpKernel { for (int64_t i = 0; i < data.size(); i++) { // Get number of zeros. auto num_zeros = dec.ReadGamma(); - OP_REQUIRES_OK_ABSL(context, num_zeros.status()); + OP_REQUIRES_OK(context, num_zeros.status()); // Advance the index to the next non-zero element. i += *num_zeros - 1; @@ -155,11 +150,11 @@ class RunLengthGammaDecodeOp : public OpKernel { // Get sign of value. auto positive = dec.ReadOneBit(); - OP_REQUIRES_OK_ABSL(context, positive.status()); + OP_REQUIRES_OK(context, positive.status()); // Get magnitude. auto magnitude = dec.ReadGamma(); - OP_REQUIRES_OK_ABSL(context, magnitude.status()); + OP_REQUIRES_OK(context, magnitude.status()); // Write value to data tensor element at index. data(i) = *positive ? 
*magnitude : -*magnitude; @@ -170,7 +165,5 @@ class RunLengthGammaDecodeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("RunLengthGammaDecode").Device(DEVICE_CPU), RunLengthGammaDecodeOp); -#undef OP_REQUIRES_OK_ABSL - } // namespace } // namespace tensorflow_compression diff --git a/tensorflow_compression/cc/kernels/run_length_gamma_kernels_test.cc b/tensorflow_compression/cc/kernels/run_length_gamma_kernels_test.cc index 76f3bc2..caae8f5 100644 --- a/tensorflow_compression/cc/kernels/run_length_gamma_kernels_test.cc +++ b/tensorflow_compression/cc/kernels/run_length_gamma_kernels_test.cc @@ -73,7 +73,7 @@ class BitCodingOpsTest : public OpsTestBase { TF_RETURN_IF_ERROR(RunOpKernel()); *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status RunDecodeOp(absl::Span inputs, Tensor* output) { @@ -91,7 +91,7 @@ class BitCodingOpsTest : public OpsTestBase { TF_RETURN_IF_ERROR(RunOpKernel()); *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } void TestEncodeAndDecode(const Tensor& data_tensor) { @@ -271,7 +271,7 @@ TEST_F(BitCodingOpsTest, DecodeConsistent) { test::ExpectTensorEqual(data_tensor, expected_data_tensor); } -// TODO(nicolemitchell,jonycgn) Add more corner cases to unit tests. +// TODO(nicolemitchell,jonarchist) Add more corner cases to unit tests. // Examples: decode empty string (null pointer), decode strings that end // prematurely, decode long string of zeros that causes overflow in ReadGamma, // decode incorrect run length that exceeds tensor size, encode int32::min diff --git a/tensorflow_compression/cc/kernels/run_length_kernels.cc b/tensorflow_compression/cc/kernels/run_length_kernels.cc index 0069c35..8947365 100644 --- a/tensorflow_compression/cc/kernels/run_length_kernels.cc +++ b/tensorflow_compression/cc/kernels/run_length_kernels.cc @@ -24,8 +24,10 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" @@ -46,16 +48,7 @@ using tensorflow::TensorShape; using tensorflow::TensorShapeUtils; using tensorflow::tstring; -#define OP_REQUIRES_OK_ABSL(context, status) \ - { \ - auto s = (status); \ - OP_REQUIRES( \ - context, s.ok(), \ - tensorflow::Status(static_cast(s.code()), \ - s.message())); \ - } - -// TODO(jonycgn): Try to avoid in-loop branches based on attributes. +// TODO(jonarchist): Try to avoid in-loop branches based on attributes. class RunLengthEncodeOp : public OpKernel { public: @@ -227,7 +220,7 @@ class RunLengthDecodeOp : public OpKernel { while (p < end) { // Skip to the next non-zero element. 
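      // (Decode-loop shape, as read from this hunk: a zero-run length
      // advances p past that many zeros; then either one non-zero value is
      // decoded, or, when use_run_length_for_non_zeros_ is set, a second
      // run length gives the count of consecutive non-zeros.
      // run_length_offset compensates for the zero that implicitly
      // terminates a non-zero run.)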
auto run_length = ReadRunLength(context, dec); - OP_REQUIRES_OK_ABSL(context, run_length.status()); + OP_REQUIRES_OK(context, run_length.status()); p += *run_length + run_length_offset; @@ -240,19 +233,19 @@ class RunLengthDecodeOp : public OpKernel { if (use_run_length_for_non_zeros_) { run_length = ReadRunLength(context, dec); - OP_REQUIRES_OK_ABSL(context, run_length.status()); + OP_REQUIRES_OK(context, run_length.status()); const int32_t* const next_zero = p + *run_length + 1; OP_REQUIRES(context, next_zero <= end, errors::DataLoss("Decoded past end of tensor.")); while (p < next_zero) { auto nonzero = ReadNonZero(context, dec); - OP_REQUIRES_OK_ABSL(context, nonzero.status()); + OP_REQUIRES_OK(context, nonzero.status()); *p++ = *nonzero; } run_length_offset = 1; } else { auto nonzero = ReadNonZero(context, dec); - OP_REQUIRES_OK_ABSL(context, nonzero.status()); + OP_REQUIRES_OK(context, nonzero.status()); *p++ = *nonzero; } } @@ -267,7 +260,5 @@ class RunLengthDecodeOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("RunLengthDecode").Device(DEVICE_CPU), RunLengthDecodeOp); -#undef OP_REQUIRES_OK_ABSL - } // namespace } // namespace tensorflow_compression diff --git a/tensorflow_compression/cc/kernels/run_length_kernels_test.cc b/tensorflow_compression/cc/kernels/run_length_kernels_test.cc index a5ff8cf..f342a18 100644 --- a/tensorflow_compression/cc/kernels/run_length_kernels_test.cc +++ b/tensorflow_compression/cc/kernels/run_length_kernels_test.cc @@ -78,7 +78,7 @@ class BitCodingOpsTest : public OpsTestBase { TF_RETURN_IF_ERROR(RunOpKernel()); *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status RunDecodeOp(absl::Span inputs, Tensor* output, @@ -102,7 +102,7 @@ class BitCodingOpsTest : public OpsTestBase { TF_RETURN_IF_ERROR(RunOpKernel()); *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } void TestEncodeAndDecode(const Tensor& data_tensor, const int run_length_code, @@ -304,7 +304,7 @@ TEST_F(BitCodingOpsTest, DecodeConsistent) { test::ExpectTensorEqual(data_tensor, expected_data_tensor); } -// TODO(nicolemitchell,jonycgn) Add more corner cases to unit tests. +// TODO(nicolemitchell,jonarchist) Add more corner cases to unit tests. 
// Examples: decode empty string (null pointer), decode strings that end // prematurely, decode long string of zeros that causes overflow in ReadGamma, // decode incorrect run length that exceeds tensor size, encode int32::min diff --git a/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc b/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc index ca5e0fa..969d2d4 100644 --- a/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc +++ b/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc @@ -51,7 +51,7 @@ using tensorflow::TensorShapeUtils; using tensorflow::tstring; using tensorflow::TTypes; -tensorflow::Status CheckIndex(int64_t upper_bound, const Tensor& index) { +absl::Status CheckIndex(int64_t upper_bound, const Tensor& index) { auto flat = index.flat(); for (int64_t i = 0; i < flat.size(); ++i) { if (flat(i) < 0 || upper_bound <= flat(i)) { @@ -59,10 +59,10 @@ tensorflow::Status CheckIndex(int64_t upper_bound, const Tensor& index) { upper_bound, "): value=", flat(i)); } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } -tensorflow::Status CheckCdfSize(int64_t upper_bound, const Tensor& cdf_size) { +absl::Status CheckCdfSize(int64_t upper_bound, const Tensor& cdf_size) { auto flat = cdf_size.vec(); for (int64_t i = 0; i < flat.size(); ++i) { if (flat(i) < 3 || upper_bound < flat(i)) { @@ -70,11 +70,11 @@ tensorflow::Status CheckCdfSize(int64_t upper_bound, const Tensor& cdf_size) { upper_bound, "]: value=", flat(i)); } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } -tensorflow::Status CheckCdf(int precision, const Tensor& cdf, - const Tensor& cdf_size) { +absl::Status CheckCdf(int precision, const Tensor& cdf, + const Tensor& cdf_size) { auto matrix = cdf.matrix(); auto size = cdf_size.vec(); CHECK_EQ(matrix.dimension(0), size.size()); @@ -96,23 +96,21 @@ tensorflow::Status CheckCdf(int precision, const Tensor& cdf, } } } - return tensorflow::OkStatus(); + return absl::OkStatus(); } // Assumes that CheckArgumentShapes().ok(). 
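// Note on the signature rewrites below and throughout this PR:
// tensorflow::Status is an alias of absl::Status in the TF versions
// targeted here, so switching return types, tensorflow::OkStatus() to
// absl::OkStatus(), and errors::IsOutOfRange to absl::IsOutOfRange is a
// spelling cleanup rather than a behavior change.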
-tensorflow::Status CheckArgumentValues(int precision, const Tensor& index, - const Tensor& cdf, - const Tensor& cdf_size, - const Tensor& offset) { +absl::Status CheckArgumentValues(int precision, const Tensor& index, + const Tensor& cdf, const Tensor& cdf_size, + const Tensor& offset) { TF_RETURN_IF_ERROR(CheckIndex(cdf.dim_size(0), index)); TF_RETURN_IF_ERROR(CheckCdfSize(cdf.dim_size(1), cdf_size)); TF_RETURN_IF_ERROR(CheckCdf(precision, cdf, cdf_size)); - return tensorflow::OkStatus(); + return absl::OkStatus(); } -tensorflow::Status CheckArgumentShapes(const Tensor& index, const Tensor& cdf, - const Tensor& cdf_size, - const Tensor& offset) { +absl::Status CheckArgumentShapes(const Tensor& index, const Tensor& cdf, + const Tensor& cdf_size, const Tensor& offset) { if (!TensorShapeUtils::IsMatrix(cdf.shape()) || cdf.dim_size(1) < 3) { return errors::InvalidArgument( "'cdf' should be 2-D and cdf.dim_size(1) >= 3: ", cdf.shape()); @@ -131,7 +129,7 @@ tensorflow::Status CheckArgumentShapes(const Tensor& index, const Tensor& cdf, "should match the number of rows in 'cdf': offset.shape=", offset.shape(), ", cdf.shape=", cdf.shape()); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } class UnboundedIndexRangeEncodeOp : public OpKernel { @@ -306,12 +304,12 @@ class UnboundedIndexRangeDecodeOp : public OpKernel { } private: - tensorflow::Status RangeDecodeImpl(TTypes::Flat output, - TTypes::ConstFlat index, - TTypes::ConstMatrix cdf, - TTypes::ConstVec cdf_size, - TTypes::ConstVec offset, - TTypes::ConstFlat encoded) const { + absl::Status RangeDecodeImpl(TTypes::Flat output, + TTypes::ConstFlat index, + TTypes::ConstMatrix cdf, + TTypes::ConstVec cdf_size, + TTypes::ConstVec offset, + TTypes::ConstFlat encoded) const { RangeDecoder decoder(encoded(0)); DCHECK_GE(cdf.dimension(1), 2); @@ -365,7 +363,7 @@ class UnboundedIndexRangeDecodeOp : public OpKernel { output(i) = value; } - return tensorflow::OkStatus(); + return absl::OkStatus(); } int precision_; diff --git a/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels_test.cc b/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels_test.cc index c7cbf25..2c2dd28 100644 --- a/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels_test.cc +++ b/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels_test.cc @@ -167,7 +167,7 @@ class UnboundedIndexRangeCoderOpsTest : public OpsTestBase { *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } Status RunDecodeOpDebug(int precision, int overflow_width, @@ -205,7 +205,7 @@ class UnboundedIndexRangeCoderOpsTest : public OpsTestBase { *output = *GetOutput(0); inputs_.clear(); - return tensorflow::OkStatus(); + return absl::OkStatus(); } void TestEncodeAndDecode(int precision, int overflow_width, diff --git a/tensorflow_compression/cc/kernels/y4m_dataset_kernels.cc b/tensorflow_compression/cc/kernels/y4m_dataset_kernels.cc index 3baf24e..cddd3eb 100644 --- a/tensorflow_compression/cc/kernels/y4m_dataset_kernels.cc +++ b/tensorflow_compression/cc/kernels/y4m_dataset_kernels.cc @@ -13,63 +13,65 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include - -#include "absl/strings/str_join.h" +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/tsl/platform/mutex.h" +#include "tensorflow/tsl/platform/thread_annotations.h" +#include "tensorflow/tsl/platform/tstring.h" namespace tensorflow_compression { namespace { -namespace errors = tensorflow::errors; -using absl::string_view; -using std::string; -using std::vector; -using tensorflow::DataTypeVector; +namespace errors = tsl::errors; using tensorflow::DT_UINT8; -using tensorflow::mutex; -using tensorflow::mutex_lock; -using tensorflow::Node; -using tensorflow::OpKernelContext; -using tensorflow::PartialTensorShape; -using tensorflow::RandomAccessFile; -using tensorflow::Status; using tensorflow::Tensor; -using tensorflow::data::DatasetBase; -using tensorflow::data::DatasetIterator; -using tensorflow::data::DatasetOpKernel; -using tensorflow::data::IteratorContext; -using tensorflow::data::SerializationContext; -class Y4MDatasetOp : public DatasetOpKernel { +class Y4MDatasetOp : public tensorflow::data::DatasetOpKernel { public: explicit Y4MDatasetOp(tensorflow::OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + void MakeDataset(tensorflow::OpKernelContext* ctx, + tensorflow::data::DatasetBase** output) override { const Tensor* filenames_tensor; OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); - OP_REQUIRES( - ctx, filenames_tensor->dims() <= 1, - errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + OP_REQUIRES(ctx, filenames_tensor->dims() <= 1, + absl::InvalidArgumentError( + "`filenames` must be a scalar or a vector.")); - vector filenames; + std::vector filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.emplace_back(filenames_tensor->flat()(i)); + filenames.emplace_back(filenames_tensor->flat()(i)); } *output = new Dataset(ctx, std::move(filenames)); } private: - class Dataset : public DatasetBase { + class Dataset : public tensorflow::data::DatasetBase { public: - explicit Dataset(OpKernelContext* ctx, vector filenames) + explicit Dataset(tensorflow::OpKernelContext* ctx, + std::vector filenames) : DatasetBase(tensorflow::data::DatasetContext(ctx)), filenames_(std::move(filenames)) {} @@ -79,55 +81,59 @@ class Y4MDatasetOp : public DatasetOpKernel { new Iterator({this, absl::StrCat(prefix, "::Y4M")})); } - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_UINT8, DT_UINT8}); + const tensorflow::DataTypeVector& output_dtypes() const override { + static tensorflow::DataTypeVector* dtypes = + new tensorflow::DataTypeVector({DT_UINT8, 
DT_UINT8}); return *dtypes; } - const vector& output_shapes() const override { - static vector* shapes = - new vector{{-1, -1, 1}, {-1, -1, 2}}; + const std::vector& output_shapes() + const override { + static std::vector* shapes = + new std::vector{{-1, -1, 1}, + {-1, -1, 2}}; return *shapes; } - string DebugString() const override { return "Y4MDatasetOp::Dataset"; } + std::string DebugString() const override { return "Y4MDatasetOp::Dataset"; } - Status InputDatasets(vector* inputs) const override { - return tensorflow::OkStatus(); + absl::Status InputDatasets( + std::vector* inputs) const override { + return absl::OkStatus(); } - Status CheckExternalState() const override { - return tensorflow::OkStatus(); + absl::Status CheckExternalState() const override { + return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* filenames = nullptr; + absl::Status AsGraphDefInternal(tensorflow::data::SerializationContext* ctx, + DatasetGraphDefBuilder* b, + tensorflow::Node** output) const override { + tensorflow::Node* filenames = nullptr; TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); - return tensorflow::OkStatus(); + return absl::OkStatus(); } private: - class Iterator : public DatasetIterator { + class Iterator : public tensorflow::data::DatasetIterator { public: explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - vector* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); + absl::Status GetNextInternal(tensorflow::data::IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + tsl::mutex_lock l(mu_); do { if (file_) { - const string_view frame_header = "FRAME\n"; + const absl::string_view frame_header = "FRAME\n"; size_t frame_size = width_ * height_ * 3; int64_t cbcr_width = width_; int64_t cbcr_height = height_; - string_view frame_buffer; + absl::string_view frame_buffer; if (chroma_format_ == ChromaFormat::I420) { frame_size /= 2; @@ -140,8 +146,9 @@ class Y4MDatasetOp : public DatasetOpKernel { buffer_.resize(frame_header.size() + frame_size); // Try to read the next frame. - Status status = file_->Read(file_pos_, buffer_.size(), - &frame_buffer, &buffer_[0]); + absl::Status status = + file_->Read(file_pos_, frame_buffer, + absl::MakeSpan(&buffer_[0], buffer_.size())); // Yield frame on successful read of a complete frame. 
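          // (API note: the Read call above uses the newer
          // tsl::RandomAccessFile signature, which takes the result
          // string_view and a caller-owned scratch span; the read size is
          // implied by the span length rather than passed explicitly as in
          // the removed overload.)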
if (status.ok()) { @@ -150,7 +157,8 @@ class Y4MDatasetOp : public DatasetOpKernel { if (!absl::ConsumePrefix(&frame_buffer, frame_header)) { return errors::InvalidArgument( "Input file '", dataset()->filenames_[file_index_], - "' has a FRAME marker at byte ", file_pos_, " which is " + "' has a FRAME marker at byte ", file_pos_, + " which is " "either invalid or has unsupported frame parameters."); } @@ -163,8 +171,8 @@ class Y4MDatasetOp : public DatasetOpKernel { std::memcpy(flat_y.data(), frame_buffer.data(), flat_y.size()); frame_buffer.remove_prefix(flat_y.size()); for (int i = 0; i < cbcr_size; i++) { - flat_cbcr.data()[2*i] = frame_buffer[i]; - flat_cbcr.data()[2*i+1] = frame_buffer[cbcr_size+i]; + flat_cbcr.data()[2 * i] = frame_buffer[i]; + flat_cbcr.data()[2 * i + 1] = frame_buffer[cbcr_size + i]; } out_tensors->push_back(std::move(y_tensor)); out_tensors->push_back(std::move(cbcr_tensor)); @@ -176,7 +184,7 @@ class Y4MDatasetOp : public DatasetOpKernel { // Catch any other errors than out of range, which needs special // treatment. - if (!errors::IsOutOfRange(status)) { + if (!absl::IsOutOfRange(status)) { return status; } @@ -199,7 +207,7 @@ class Y4MDatasetOp : public DatasetOpKernel { // Exit if there are no more files to process. if (file_index_ >= dataset()->filenames_.size()) { *end_of_sequence = true; - return tensorflow::OkStatus(); + return absl::OkStatus(); } // Open next file. @@ -215,10 +223,10 @@ class Y4MDatasetOp : public DatasetOpKernel { } protected: - Status SaveInternal( - SerializationContext* ctx, + absl::Status SaveInternal( + tensorflow::data::SerializationContext* ctx, tensorflow::data::IteratorStateWriter* writer) override { - mutex_lock l(mu_); + tsl::mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("file_index"), file_index_)); @@ -227,13 +235,13 @@ class Y4MDatasetOp : public DatasetOpKernel { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("file_pos"), file_pos)); - return tensorflow::OkStatus(); + return absl::OkStatus(); } - Status RestoreInternal( - IteratorContext* ctx, + absl::Status RestoreInternal( + tensorflow::data::IteratorContext* ctx, tensorflow::data::IteratorStateReader* reader) override { - mutex_lock l(mu_); + tsl::mutex_lock l(mu_); int64_t file_index; int64_t file_pos; @@ -252,14 +260,14 @@ class Y4MDatasetOp : public DatasetOpKernel { chroma_format_)); file_pos_ = file_pos; } - return tensorflow::OkStatus(); + return absl::OkStatus(); } private: enum class ChromaFormat { undefined, I420, I444 }; - Status ReadHeader(const RandomAccessFile& file, const size_t file_index, - string& header) { + absl::Status ReadHeader(const tsl::RandomAccessFile& file, + const size_t file_index, std::string& header) { // 256 bytes should be more than enough in most cases. If not, keep // reading chunks until header is complete. const size_t chunk_size = 256; @@ -267,11 +275,11 @@ class Y4MDatasetOp : public DatasetOpKernel { do { const uint64_t offset = header.size(); header.resize(offset + chunk_size); - string_view chunk; - Status status = file.Read( - offset, chunk_size, &chunk, &header[offset]); + absl::string_view chunk; + absl::Status status = file.Read( + offset, chunk, absl::MakeSpan(&header[offset], chunk_size)); // End of file error is fine, as long as the header is complete. 
- if (!(status.ok() || errors::IsOutOfRange(status))) { + if (!(status.ok() || absl::IsOutOfRange(status))) { return status; } const size_t pos = chunk.find('\n'); @@ -280,7 +288,7 @@ class Y4MDatasetOp : public DatasetOpKernel { std::memcpy(&header[offset], chunk.data(), pos + 1); } header.resize(offset + pos + 1); - return tensorflow::OkStatus(); + return absl::OkStatus(); } // We reached the end of the file, and the header is not complete. if (!status.ok()) { @@ -294,10 +302,10 @@ class Y4MDatasetOp : public DatasetOpKernel { } while (true); } - Status ParseHeader(string_view header, const size_t file_index, - int64_t& width, int64_t& height, - ChromaFormat& chroma_format) { - const string_view digits = "0123456789"; + absl::Status ParseHeader(absl::string_view header, + const size_t file_index, int64_t& width, + int64_t& height, ChromaFormat& chroma_format) { + const absl::string_view digits = "0123456789"; width = 0; height = 0; @@ -308,9 +316,9 @@ class Y4MDatasetOp : public DatasetOpKernel { header.remove_suffix(1); if (!absl::ConsumePrefix(&header, "YUV4MPEG2")) { - return errors::InvalidArgument( - "Input file '", dataset()->filenames_[file_index], - "' does not have a YUV4MPEG2 marker."); + return errors::InvalidArgument("Input file '", + dataset()->filenames_[file_index], + "' does not have a YUV4MPEG2 marker."); } while (!header.empty()) { @@ -375,39 +383,39 @@ class Y4MDatasetOp : public DatasetOpKernel { } if (!width) { - return errors::InvalidArgument( - "Input file '", dataset()->filenames_[file_index], - "' has no width specifier."); + return errors::InvalidArgument("Input file '", + dataset()->filenames_[file_index], + "' has no width specifier."); } if (!height) { - return errors::InvalidArgument( - "Input file '", dataset()->filenames_[file_index], - "' has no height specifier."); + return errors::InvalidArgument("Input file '", + dataset()->filenames_[file_index], + "' has no height specifier."); } if (chroma_format == ChromaFormat::undefined) { - return errors::InvalidArgument( - "Input file '", dataset()->filenames_[file_index], - "' has no chroma format specifier."); + return errors::InvalidArgument("Input file '", + dataset()->filenames_[file_index], + "' has no chroma format specifier."); } if (chroma_format == ChromaFormat::I420 && (width & 1 || height & 1)) { return errors::InvalidArgument( "Input file '", dataset()->filenames_[file_index], "' has 4:2:0 chroma format, but odd width or height."); } - return tensorflow::OkStatus(); + return absl::OkStatus(); } - mutex mu_; + tsl::mutex mu_; size_t file_index_ TF_GUARDED_BY(mu_) = 0; - std::unique_ptr file_ TF_GUARDED_BY(mu_); + std::unique_ptr file_ TF_GUARDED_BY(mu_); uint64_t file_pos_ TF_GUARDED_BY(mu_); - string buffer_ TF_GUARDED_BY(mu_); + std::string buffer_ TF_GUARDED_BY(mu_); int64_t width_ TF_GUARDED_BY(mu_); int64_t height_ TF_GUARDED_BY(mu_); ChromaFormat chroma_format_ TF_GUARDED_BY(mu_); }; - const vector filenames_; + const std::vector filenames_; }; }; diff --git a/tensorflow_compression/cc/lib/bit_coder.cc b/tensorflow_compression/cc/lib/bit_coder.cc index ae62d93..b4b1d7d 100644 --- a/tensorflow_compression/cc/lib/bit_coder.cc +++ b/tensorflow_compression/cc/lib/bit_coder.cc @@ -15,16 +15,37 @@ limitations under the License. 
#include "tensorflow_compression/cc/lib/bit_coder.h" #include + #include #include #include -#include +#include -#include "absl/base/internal/endian.h" +#include "absl/base/config.h" #include "absl/status/status.h" namespace tensorflow_compression { + +namespace little_endian { +namespace { +#ifndef ABSL_IS_LITTLE_ENDIAN +#error BitWriter assumes little endian +#endif + +uint64_t Load64(const void* p) { + uint64_t v; + std::memcpy(&v, p, sizeof(v)); + return v; +} + +void Store64(void* p, uint64_t v) { std::memcpy(p, &v, sizeof(v)); } + +uint64_t ToHost64(uint64_t v) { return v; } + +} // namespace +} // namespace little_endian + BitWriter::BitWriter() : next_index_(0), bits_in_buffer_(0), @@ -37,9 +58,9 @@ void BitWriter::WriteBits(uint32_t count, uint64_t bits) { bits &= (uint64_t{1} << count) - 1; buffer_ |= bits << bits_in_buffer_; bits_in_buffer_ += count; - // TODO(jonycgn): Investigate performance of buffer resizing. + // TODO(jonarchist): Investigate performance of buffer resizing. data_.resize(next_index_ + 8); - absl::little_endian::Store64(&data_[next_index_], buffer_); + little_endian::Store64(&data_[next_index_], buffer_); size_t bytes_in_buffer = bits_in_buffer_ / 8; bits_in_buffer_ -= bytes_in_buffer * 8; buffer_ >>= bytes_in_buffer * 8; @@ -100,14 +121,14 @@ void BitReader::Refill() { if (!bytes_to_copy) return; uint64_t x = 0; memcpy(&x, next_byte_, bytes_to_copy); - buffer_ |= absl::little_endian::ToHost(x) << bits_in_buffer_; + buffer_ |= little_endian::ToHost64(x) << bits_in_buffer_; next_byte_ += bytes_to_copy; bits_in_buffer_ += bytes_to_copy * 8; assert(bits_in_buffer_ < 64); } else { // It's safe to load 64 bits; insert valid (possibly nonzero) bits above // bits_in_buffer_. The shift requires bits_in_buffer_ < 64. - buffer_ |= absl::little_endian::Load64(next_byte_) << bits_in_buffer_; + buffer_ |= little_endian::Load64(next_byte_) << bits_in_buffer_; // Advance by bytes fully absorbed into the buffer. next_byte_ += (63 - bits_in_buffer_) / 8; // We absorbed a multiple of 8 bits, so the lower 3 bits of bits_in_buffer_ diff --git a/tensorflow_compression/cc/lib/range_coder.cc b/tensorflow_compression/cc/lib/range_coder.cc index 04cf08e..125b8ef 100644 --- a/tensorflow_compression/cc/lib/range_coder.cc +++ b/tensorflow_compression/cc/lib/range_coder.cc @@ -23,13 +23,14 @@ limitations under the License. #include "tensorflow_compression/cc/lib/range_coder.h" #include -#include #include +#include "absl/base/optimization.h" +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "tensorflow/core/platform/logging.h" namespace tensorflow_compression { diff --git a/tensorflow_compression/cc/ops/pmf_to_cdf_ops.cc b/tensorflow_compression/cc/ops/pmf_to_cdf_ops.cc index 78102cd..db082eb 100644 --- a/tensorflow_compression/cc/ops/pmf_to_cdf_ops.cc +++ b/tensorflow_compression/cc/ops/pmf_to_cdf_ops.cc @@ -37,7 +37,7 @@ REGISTER_OP("PmfToQuantizedCdf") ShapeHandle out; TF_RETURN_IF_ERROR(c->ReplaceDim(in, -1, last, &out)); c->set_output(0, out); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Converts a PMF into a quantized CDF for range coding. 
diff --git a/tensorflow_compression/cc/ops/range_coder_ops.cc b/tensorflow_compression/cc/ops/range_coder_ops.cc index fbe17ce..9bdb190 100644 --- a/tensorflow_compression/cc/ops/range_coder_ops.cc +++ b/tensorflow_compression/cc/ops/range_coder_ops.cc @@ -166,7 +166,7 @@ REGISTER_OP("EntropyDecodeChannel") TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &suffix_shape)); TF_RETURN_IF_ERROR(c->Concatenate(shape, suffix_shape, &shape)); c->set_output(1, shape); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Decodes the encoded stream inside `handle`. @@ -208,7 +208,7 @@ REGISTER_OP("EntropyDecodeIndex") TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &suffix_shape)); TF_RETURN_IF_ERROR(c->Concatenate(shape, suffix_shape, &shape)); c->set_output(1, shape); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Decodes the encoded stream inside `handle`. diff --git a/tensorflow_compression/cc/ops/range_coding_ops.cc b/tensorflow_compression/cc/ops/range_coding_ops.cc index 08373a1..30af91f 100644 --- a/tensorflow_compression/cc/ops/range_coding_ops.cc +++ b/tensorflow_compression/cc/ops/range_coding_ops.cc @@ -100,7 +100,7 @@ REGISTER_OP("RangeDecode") ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); c->set_output(0, out); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Range-decodes `code` into an int32 tensor of shape `shape`. @@ -215,7 +215,7 @@ REGISTER_OP("UnboundedIndexRangeDecode") .Attr("debug_level: int = 1") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->input(1)); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Range decodes `encoded` using an indexed probability table. diff --git a/tensorflow_compression/cc/ops/run_length_gamma_ops.cc b/tensorflow_compression/cc/ops/run_length_gamma_ops.cc index f40c163..f215cc9 100644 --- a/tensorflow_compression/cc/ops/run_length_gamma_ops.cc +++ b/tensorflow_compression/cc/ops/run_length_gamma_ops.cc @@ -42,7 +42,7 @@ REGISTER_OP("RunLengthGammaDecode") ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); c->set_output(0, out); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Decodes `data` using run-length and Elias gamma coding. diff --git a/tensorflow_compression/cc/ops/run_length_ops.cc b/tensorflow_compression/cc/ops/run_length_ops.cc index 3786408..337db8f 100644 --- a/tensorflow_compression/cc/ops/run_length_ops.cc +++ b/tensorflow_compression/cc/ops/run_length_ops.cc @@ -38,9 +38,9 @@ calling RunLengthEncode with run_length_code = -1, magnitude_code = -1, and use_run_length_for_non_zeros = false. run_length_code: If >= 0, use Rice code with this parameter to encode run - lengths, else use Golomb code. + lengths, else use Elias gamma code. magnitude_code: If >= 0, use Rice code with this parameter to encode magnitudes, - else use Golomb code. + else use Elias gamma code. use_run_length_for_non_zeros: If true, alternate between coding run lengths of zeros and non-zeros. If false, only encode run lengths of zeros, and encode non-zeros one by one. @@ -59,7 +59,7 @@ REGISTER_OP("RunLengthDecode") ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); c->set_output(0, out); - return tensorflow::OkStatus(); + return absl::OkStatus(); }) .Doc(R"doc( Decodes `data` using run-length coding. @@ -72,9 +72,9 @@ calling RunLengthDecode with run_length_code = -1, magnitude_code = -1, and use_run_length_for_non_zeros = false. 
run_length_code: If >= 0, use Rice code with this parameter to decode run - lengths, else use Golomb code. + lengths, else use Elias gamma code. magnitude_code: If >= 0, use Rice code with this parameter to decode magnitudes, - else use Golomb code. + else use Elias gamma code. use_run_length_for_non_zeros: If true, alternate between coding run lengths of zeros and non-zeros. If false, only decode run lengths of zeros, and decode non-zeros one by one. diff --git a/tensorflow_compression/python/__init__.py b/tensorflow_compression/python/__init__.py new file mode 100644 index 0000000..75a85b7 --- /dev/null +++ b/tensorflow_compression/python/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + diff --git a/tensorflow_compression/python/datasets/BUILD b/tensorflow_compression/python/datasets/BUILD index 1ccbdb4..9f714ee 100644 --- a/tensorflow_compression/python/datasets/BUILD +++ b/tensorflow_compression/python/datasets/BUILD @@ -23,8 +23,3 @@ py_test( srcs = ["y4m_dataset_test.py"], deps = [":y4m_dataset"], ) - -filegroup( - name = "py_src", - srcs = glob(["*.py"]), -) diff --git a/tensorflow_compression/python/distributions/BUILD b/tensorflow_compression/python/distributions/BUILD index 3a5d67c..48761d7 100644 --- a/tensorflow_compression/python/distributions/BUILD +++ b/tensorflow_compression/python/distributions/BUILD @@ -81,8 +81,3 @@ py_test( ":round_adapters", ], ) - -filegroup( - name = "py_src", - srcs = glob(["*.py"]), -) diff --git a/tensorflow_compression/python/distributions/helpers.py b/tensorflow_compression/python/distributions/helpers.py index 027e87b..b5f0d6c 100644 --- a/tensorflow_compression/python/distributions/helpers.py +++ b/tensorflow_compression/python/distributions/helpers.py @@ -25,7 +25,7 @@ ] -# TODO(jonycgn): Consider wrapping in tf.function. +# TODO(jonarchist): Consider wrapping in tf.function. def estimate_tails(func, target, shape, dtype): """Estimates approximate tail quantiles. 
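Aside on the doc corrections above (Golomb to Elias gamma): the run-length ops choose between a Rice code (parameter >= 0) and the Elias gamma code (parameter < 0). A small sketch of the two codes under one common bit-string convention; the kernels' exact bit conventions are not shown in this diff:

```python
def elias_gamma(n):
    # n >= 1: floor(log2(n)) zeros, then the binary digits of n.
    b = format(n, "b")
    return "0" * (len(b) - 1) + b

def rice(n, k):
    # n >= 0, parameter k: unary quotient (q ones, then a zero),
    # followed by the k low-order remainder bits.
    q = n >> k
    r = format(n & ((1 << k) - 1), "b").zfill(k) if k else ""
    return "1" * q + "0" + r

print(elias_gamma(5))  # 00101
print(rice(5, 2))      # 1001 (quotient 1, remainder 01)
```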
diff --git a/tensorflow_compression/python/entropy_models/BUILD b/tensorflow_compression/python/entropy_models/BUILD index 166c2e3..47185e6 100644 --- a/tensorflow_compression/python/entropy_models/BUILD +++ b/tensorflow_compression/python/entropy_models/BUILD @@ -10,6 +10,7 @@ py_library( deps = [ ":continuous_batched", ":continuous_indexed", + ":laplace", ":power_law", ":universal", ], @@ -68,6 +69,21 @@ py_test( ], ) +py_library( + name = "laplace", + srcs = ["laplace.py"], + deps = [ + "//tensorflow_compression/python/ops:gen_ops", + "//tensorflow_compression/python/ops:round_ops", + ], +) + +py_test( + name = "laplace_test", + srcs = ["laplace_test.py"], + deps = [":laplace"], +) + py_library( name = "power_law", srcs = ["power_law.py"], @@ -104,8 +120,3 @@ py_test( "//tensorflow_compression/python/distributions:uniform_noise", ], ) - -filegroup( - name = "py_src", - srcs = glob(["*.py"]), -) diff --git a/tensorflow_compression/python/entropy_models/__init__.py b/tensorflow_compression/python/entropy_models/__init__.py index 21e0685..51619ad 100644 --- a/tensorflow_compression/python/entropy_models/__init__.py +++ b/tensorflow_compression/python/entropy_models/__init__.py @@ -16,5 +16,6 @@ from tensorflow_compression.python.entropy_models.continuous_batched import * from tensorflow_compression.python.entropy_models.continuous_indexed import * +from tensorflow_compression.python.entropy_models.laplace import * from tensorflow_compression.python.entropy_models.power_law import * from tensorflow_compression.python.entropy_models.universal import * diff --git a/tensorflow_compression/python/entropy_models/continuous_batched_test.py b/tensorflow_compression/python/entropy_models/continuous_batched_test.py index c659b2f..acd38b3 100644 --- a/tensorflow_compression/python/entropy_models/continuous_batched_test.py +++ b/tensorflow_compression/python/entropy_models/continuous_batched_test.py @@ -215,7 +215,7 @@ def test_dtypes_are_correct_with_mixed_precision(self): self.assertEqual(bits.shape, (2,)) self.assertAllGreaterEqual(bits, 0.) finally: - tf.keras.mixed_precision.set_global_policy(None) + tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx()) def test_small_cdfs_for_dirac_prior_without_quantization_offset(self): prior = uniform_noise.NoisyNormal(loc=100. * tf.range(16.), scale=1e-10) diff --git a/tensorflow_compression/python/entropy_models/continuous_indexed_test.py b/tensorflow_compression/python/entropy_models/continuous_indexed_test.py index 07843a5..130c29a 100644 --- a/tensorflow_compression/python/entropy_models/continuous_indexed_test.py +++ b/tensorflow_compression/python/entropy_models/continuous_indexed_test.py @@ -220,7 +220,7 @@ def test_dtypes_are_correct_with_mixed_precision(self): self.assertEqual(bits.shape, (2,)) self.assertAllGreaterEqual(bits, 0.) finally: - tf.keras.mixed_precision.set_global_policy(None) + tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx()) class LocationScaleIndexedEntropyModelTest(tf.test.TestCase): diff --git a/tensorflow_compression/python/entropy_models/laplace.py b/tensorflow_compression/python/entropy_models/laplace.py new file mode 100644 index 0000000..c375c07 --- /dev/null +++ b/tensorflow_compression/python/entropy_models/laplace.py @@ -0,0 +1,233 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An entropy model for the run-length Rice code."""
+
+import tensorflow as tf
+from tensorflow_compression.python.ops import gen_ops
+from tensorflow_compression.python.ops import round_ops
+
+
+__all__ = [
+    "LaplaceEntropyModel",
+]
+
+
+class LaplaceEntropyModel(tf.Module):
+  """Entropy model for Laplace distributed random variables.
+
+  This entropy model handles quantization and compression of a bottleneck
+  tensor and implements a penalty that encourages compressibility under the
+  Rice code.
+
+  Given a tensor of signed integers, `run_length_encode` encodes runs of
+  zeros using a run-length code, encodes each sign using a uniform bit, and
+  applies the Rice code to each magnitude.
+
+  The penalty applied by this class is given by:
+  ```
+  l1 * reduce_sum(abs(x))
+  ```
+  This encourages `x` to follow a symmetrized Laplace distribution.
+  """
+
+  def __init__(self,
+               coding_rank,
+               l1=0.01,
+               run_length_code=-1,
+               magnitude_code=0,
+               use_run_length_for_non_zeros=False,
+               bottleneck_dtype=None):
+    """Initializes the instance.
+
+    Args:
+      coding_rank: Integer. Number of innermost dimensions considered a coding
+        unit. Each coding unit is compressed to its own bit string, and the
+        penalty computed by `penalty()` is summed over each coding unit.
+      l1: Float. L1 regularization factor.
+      run_length_code: Int. If >= 0, use a Rice code with this parameter for
+        run lengths, else use the Elias gamma code.
+      magnitude_code: Int. If >= 0, use a Rice code with this parameter for
+        magnitudes, else use the Elias gamma code.
+      use_run_length_for_non_zeros: Bool. If true, alternate between coding
+        run lengths of zeros and non-zeros; if false, only code run lengths
+        of zeros.
+      bottleneck_dtype: `tf.dtypes.DType`. Data type of bottleneck tensor.
+        Defaults to `tf.keras.mixed_precision.global_policy().compute_dtype`.
+ """ + self._coding_rank = int(coding_rank) + if self.coding_rank < 0: + raise ValueError("`coding_rank` must be at least 0.") + self._l1 = float(l1) + if self.l1 <= 0: + raise ValueError("`l1` must be greater than 0.") + self._run_length_code = run_length_code + self._magnitude_code = magnitude_code + self._use_run_length_for_non_zeros = use_run_length_for_non_zeros + if bottleneck_dtype is None: + bottleneck_dtype = tf.keras.mixed_precision.global_policy().compute_dtype + if bottleneck_dtype is None: + bottleneck_dtype = tf.keras.backend.floatx() + self._bottleneck_dtype = tf.as_dtype(bottleneck_dtype) + super().__init__() + + @property + def l1(self): + """L1 parameter.""" + return self._l1 + + @property + def run_length_code(self): + """run_length_code parameter.""" + return self._run_length_code + + @property + def magnitude_code(self): + """magnitude_code parameter.""" + return self._magnitude_code + + @property + def use_run_length_for_non_zeros(self): + """use_run_length_for_non_zeros parameter.""" + return self._use_run_length_for_non_zeros + + @property + def bottleneck_dtype(self): + """Data type of the bottleneck tensor.""" + return self._bottleneck_dtype + + @property + def coding_rank(self): + """Number of innermost dimensions considered a coding unit.""" + return self._coding_rank + + def encode_fn(self, x): + return gen_ops.run_length_encode( + data=x, + run_length_code=self.run_length_code, + magnitude_code=self.magnitude_code, + use_run_length_for_non_zeros=self.use_run_length_for_non_zeros) + + def decode_fn(self, x, shape): + return gen_ops.run_length_decode( + code=x, + shape=shape, + run_length_code=self.run_length_code, + magnitude_code=self.magnitude_code, + use_run_length_for_non_zeros=self.use_run_length_for_non_zeros) + + @tf.Module.with_name_scope + def __call__(self, bottleneck): + """Perturbs a tensor with (quantization) noise and computes penalty. + + Args: + bottleneck: `tf.Tensor` containing the data to be compressed. Must have at + least `self.coding_rank` dimensions. + + Returns: + A tuple `(self.quantize(bottleneck), self.penalty(bottleneck))`. + """ + bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype) + return self.quantize(bottleneck), self.penalty(bottleneck) + + @tf.Module.with_name_scope + def penalty(self, bottleneck): + """Computes penalty encouraging compressibility. + + Args: + bottleneck: `tf.Tensor` containing the data to be compressed. Must have at + least `self.coding_rank` dimensions. + + Returns: + Penalty value, which has the same shape as `bottleneck` without the + `self.coding_rank` innermost dimensions. + """ + bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype) + return self.l1 * tf.reduce_sum(abs(bottleneck), + axis=tuple(range(-self.coding_rank, 0))) + + @tf.Module.with_name_scope + def quantize(self, bottleneck): + """Quantizes a floating-point bottleneck tensor. + + The tensor is rounded to integer values. The gradient of this rounding + operation is overridden with the identity (straight-through gradient + estimator). + + Args: + bottleneck: `tf.Tensor` containing the data to be quantized. + + Returns: + A `tf.Tensor` containing the quantized values. + """ + bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype) + return round_ops.round_st(bottleneck) + + @tf.Module.with_name_scope + def compress(self, bottleneck): + """Compresses a floating-point tensor. + + Compresses the tensor to bit strings. 
`bottleneck` is first quantized
+    as in `quantize()`, and then compressed using the run-length Rice code.
+    The quantized tensor can later be recovered by calling `decompress()`.
+
+    The innermost `self.coding_rank` dimensions are treated as one coding
+    unit, i.e., they are compressed into one string each. Any additional
+    dimensions to the left are treated as batch dimensions.
+
+    Args:
+      bottleneck: `tf.Tensor` containing the data to be compressed. Must have
+        at least `self.coding_rank` dimensions.
+
+    Returns:
+      A `tf.Tensor` having the same shape as `bottleneck` without the
+      `self.coding_rank` innermost dimensions, containing a string for each
+      coding unit.
+    """
+    bottleneck = tf.convert_to_tensor(bottleneck, dtype=self.bottleneck_dtype)
+
+    shape = tf.shape(bottleneck)
+    if self.coding_rank == 0:
+      flat_shape = [-1]
+      strings_shape = shape
+    else:
+      flat_shape = tf.concat([[-1], shape[-self.coding_rank:]], 0)
+      strings_shape = shape[:-self.coding_rank]
+
+    symbols = tf.cast(tf.round(bottleneck), tf.int32)
+    symbols = tf.reshape(symbols, flat_shape)
+
+    strings = tf.map_fn(
+        self.encode_fn, symbols,
+        fn_output_signature=tf.TensorSpec((), dtype=tf.string))
+    return tf.reshape(strings, strings_shape)
+
+  @tf.Module.with_name_scope
+  def decompress(self, strings, code_shape):
+    """Decompresses a tensor.
+
+    Reconstructs the quantized tensor from bit strings produced by
+    `compress()`.
+
+    Args:
+      strings: `tf.Tensor` containing the compressed bit strings.
+      code_shape: Shape of innermost dimensions of the output `tf.Tensor`.
+
+    Returns:
+      A `tf.Tensor` of shape `tf.shape(strings) + code_shape`.
+    """
+    strings = tf.convert_to_tensor(strings, dtype=tf.string)
+    strings_shape = tf.shape(strings)
+    symbols = tf.map_fn(
+        lambda x: self.decode_fn(x, code_shape),
+        tf.reshape(strings, [-1]),
+        fn_output_signature=tf.TensorSpec(
+            [None] * self.coding_rank, dtype=tf.int32))
+    symbols = tf.reshape(symbols, tf.concat([strings_shape, code_shape], 0))
+    return tf.cast(symbols, self.bottleneck_dtype)
diff --git a/tensorflow_compression/python/entropy_models/laplace_test.py b/tensorflow_compression/python/entropy_models/laplace_test.py
new file mode 100644
index 0000000..18bbb7a
--- /dev/null
+++ b/tensorflow_compression/python/entropy_models/laplace_test.py
@@ -0,0 +1,122 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests of the Laplace entropy model."""
+
+import numpy as np
+import tensorflow as tf
+from tensorflow_compression.python.entropy_models.laplace import LaplaceEntropyModel
+
+
+class LaplaceEntropyModelTest(tf.test.TestCase):
+
+  def test_can_instantiate(self):
+    em = LaplaceEntropyModel(coding_rank=1)
+    self.assertEqual(em.coding_rank, 1)
+    self.assertEqual(em.bottleneck_dtype, tf.float32)
+
+  def test_requires_coding_rank_greater_equal_zero(self):
+    with self.assertRaises(ValueError):
+      LaplaceEntropyModel(coding_rank=-1)
+
+  def test_quantizes_to_integers(self):
+    em = LaplaceEntropyModel(coding_rank=1)
+    x = tf.range(-20., 20.)
+    x_perturbed = x + tf.random.uniform(x.shape, -.49, .49)
+    x_quantized = em.quantize(x_perturbed)
+    self.assertAllEqual(x, x_quantized)
+
+  def test_gradients_are_straight_through(self):
+    em = LaplaceEntropyModel(coding_rank=1)
+    x = tf.range(-20., 20.)
+    x_perturbed = x + tf.random.uniform(x.shape, -.49, .49)
+    with tf.GradientTape() as tape:
+      tape.watch(x_perturbed)
+      x_quantized = em.quantize(x_perturbed)
+    gradients = tape.gradient(x_quantized, x_perturbed)
+    self.assertAllEqual(gradients, tf.ones_like(gradients))
+
+  def test_compression_consistent_with_quantization(self):
+    em = LaplaceEntropyModel(coding_rank=1)
+    x = tf.range(-20., 20.)
+    x += tf.random.uniform(x.shape, -.49, .49)
+    x_quantized = em.quantize(x)
+    x_decompressed = em.decompress(em.compress(x), x.shape)
+    self.assertAllEqual(x_decompressed, x_quantized)
+
+  def test_penalty_is_proportional_to_code_length(self):
+    em = LaplaceEntropyModel(coding_rank=1)
+    x = tf.range(-20., 20.)[:, None]
+    x += tf.random.uniform(x.shape, -.49, .49)
+    strings = em.compress(tf.broadcast_to(x, (40, 100)))
+    code_lengths = tf.cast(tf.strings.length(strings, unit="BYTE"), tf.float32)
+    code_lengths *= 8 / 100
+    penalties = em.penalty(x)
+    # There are some fluctuations due to the run-length coding, sign bits, and
+    # rounding, but we expect a high degree of correlation between code
+    # lengths and penalties.
+    self.assertGreater(np.corrcoef(code_lengths, penalties)[0, 1], .96)
+
+  def test_penalty_is_nonnegative_and_differentiable(self):
+    em = LaplaceEntropyModel(coding_rank=1)
+    x = tf.range(-20., 20.)[:, None]
+    x += tf.random.uniform(x.shape, -.49, .49)
+    with tf.GradientTape() as tape:
+      tape.watch(x)
+      penalties = em.penalty(x)
+    gradients = tape.gradient(penalties, x)
+    self.assertAllGreaterEqual(penalties, 0)
+    self.assertAllEqual(tf.sign(gradients), tf.sign(x))
+
+  def test_compression_works_in_tf_function(self):
+    samples = tf.random.stateless_normal([100], (34, 232))
+
+    # Since tf.function traces each function twice, and only allows variable
+    # creation in the first call, we need to have a stateful object in which we
+    # create the entropy model only the first time the function is called, and
+    # store it for the second time.
+ + class Compressor: + + def compress(self, values): + if not hasattr(self, "em"): + self.em = LaplaceEntropyModel(coding_rank=1) + compressed = self.em.compress(values) + return self.em.decompress(compressed, [100]) + + values_eager = Compressor().compress(samples) + values_function = tf.function(Compressor().compress)(samples) + self.assertAllClose(samples, values_eager, rtol=0., atol=.5) + self.assertAllEqual(values_eager, values_function) + + def test_dtypes_are_correct_with_mixed_precision(self): + tf.keras.mixed_precision.set_global_policy("mixed_float16") + try: + em = LaplaceEntropyModel(coding_rank=1) + self.assertEqual(em.bottleneck_dtype, tf.float16) + x = tf.random.stateless_normal((2, 5), seed=(0, 1), dtype=tf.float16) + x_tilde, penalty = em(x) + bitstring = em.compress(x) + x_hat = em.decompress(bitstring, (5,)) + self.assertEqual(x_hat.dtype, tf.float16) + self.assertAllClose(x, x_hat, rtol=0, atol=.5) + self.assertEqual(x_tilde.dtype, tf.float16) + self.assertAllClose(x, x_tilde, rtol=0, atol=.5) + self.assertEqual(penalty.dtype, tf.float16) + self.assertEqual(penalty.shape, (2,)) + finally: + tf.keras.mixed_precision.set_global_policy(None) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_compression/python/entropy_models/power_law_test.py b/tensorflow_compression/python/entropy_models/power_law_test.py index 654ca5b..bb65ae8 100644 --- a/tensorflow_compression/python/entropy_models/power_law_test.py +++ b/tensorflow_compression/python/entropy_models/power_law_test.py @@ -115,7 +115,7 @@ def test_dtypes_are_correct_with_mixed_precision(self): self.assertEqual(penalty.dtype, tf.float16) self.assertEqual(penalty.shape, (2,)) finally: - tf.keras.mixed_precision.set_global_policy(None) + tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx()) if __name__ == "__main__": diff --git a/tensorflow_compression/python/layers/BUILD b/tensorflow_compression/python/layers/BUILD index c1f4e29..b832e98 100644 --- a/tensorflow_compression/python/layers/BUILD +++ b/tensorflow_compression/python/layers/BUILD @@ -89,8 +89,3 @@ py_test( "//tensorflow_compression/python/ops:round_ops", ], ) - -filegroup( - name = "py_src", - srcs = glob(["*.py"]), -) diff --git a/tensorflow_compression/python/layers/gdn_test.py b/tensorflow_compression/python/layers/gdn_test.py index 7c5c7e4..1a63bc1 100644 --- a/tensorflow_compression/python/layers/gdn_test.py +++ b/tensorflow_compression/python/layers/gdn_test.py @@ -207,7 +207,7 @@ def test_dtypes_are_correct_with_mixed_precision(self): self.assertEqual(variable.dtype, tf.float32) self.assertEqual(y.dtype, tf.float16) finally: - tf.keras.mixed_precision.set_global_policy(None) + tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx()) if __name__ == "__main__": diff --git a/tensorflow_compression/python/layers/parameters_test.py b/tensorflow_compression/python/layers/parameters_test.py index c7bc76f..2e352ef 100644 --- a/tensorflow_compression/python/layers/parameters_test.py +++ b/tensorflow_compression/python/layers/parameters_test.py @@ -57,7 +57,7 @@ class RDFTParameterTest(ParameterTest, tf.test.TestCase, kwargs = dict(name="rdft_kernel") shape = (3, 3, 1, 2) - # TODO(jonycgn): Find out why 3D RFFT gradients are not implemented in TF. + # TODO(jonarchist): Find out why 3D RFFT gradients are not implemented in TF. 
@parameterized.parameters((7, 3, 2), (5, 3, 1, 2)) def test_gradients_propagate(self, *shape): initial_value = tf.random.uniform(shape, dtype=tf.float32) diff --git a/tensorflow_compression/python/layers/signal_conv_test.py b/tensorflow_compression/python/layers/signal_conv_test.py index af47a96..cdc1fe7 100644 --- a/tensorflow_compression/python/layers/signal_conv_test.py +++ b/tensorflow_compression/python/layers/signal_conv_test.py @@ -162,7 +162,7 @@ def test_dtypes_are_correct_with_mixed_precision(self): self.assertEqual(variable.dtype, tf.float32) self.assertEqual(y.dtype, tf.float16) finally: - tf.keras.mixed_precision.set_global_policy(None) + tf.keras.mixed_precision.set_global_policy(tf.keras.backend.floatx()) class ConvolutionsTest(tf.test.TestCase): @@ -376,7 +376,7 @@ def run_or_fail(self, method, raise else: try: - with self.assertRaisesRegexp(NotImplementedError, "SignalConv"): + with self.assertRaisesRegex(NotImplementedError, "SignalConv"): method(**args) except: msg = [] diff --git a/tensorflow_compression/python/ops/BUILD b/tensorflow_compression/python/ops/BUILD index 4a0fc39..7f056eb 100644 --- a/tensorflow_compression/python/ops/BUILD +++ b/tensorflow_compression/python/ops/BUILD @@ -68,8 +68,3 @@ py_test( srcs = ["quantization_ops_test.py"], deps = [":gen_ops"], ) - -filegroup( - name = "py_src", - srcs = glob(["*.py"]), -) diff --git a/tensorflow_compression/python/ops/gen_ops.py b/tensorflow_compression/python/ops/gen_ops.py index 0a97c43..8b99462 100644 --- a/tensorflow_compression/python/ops/gen_ops.py +++ b/tensorflow_compression/python/ops/gen_ops.py @@ -32,8 +32,10 @@ "entropy_encode_finalize", "entropy_encode_index", "pmf_to_quantized_cdf", - "run_length_gamma_decode", - "run_length_gamma_encode", + "run_length_decode", + "run_length_encode", + "run_length_gamma_decode", # Deprecated. + "run_length_gamma_encode", # Deprecated. 
"stochastic_round", ] # pylint:enable=undefined-all-variable diff --git a/tensorflow_compression/python/ops/quantization_ops_test.py b/tensorflow_compression/python/ops/quantization_ops_test.py index dbafa4d..2d6fc27 100644 --- a/tensorflow_compression/python/ops/quantization_ops_test.py +++ b/tensorflow_compression/python/ops/quantization_ops_test.py @@ -71,7 +71,7 @@ def test_rounding_is_unbiased(self): rounded = gen_ops.stochastic_round(replicated, 1., ()) self.assertEqual(rounded.dtype, tf.int32) averaged = tf.reduce_mean(tf.cast(rounded, tf.float32), axis=0) - self.assertAllClose(values, averaged, atol=5e-3, rtol=0) + self.assertAllClose(values, averaged, atol=1e-2, rtol=0) if __name__ == "__main__": diff --git a/tensorflow_compression/python/util/BUILD b/tensorflow_compression/python/util/BUILD index 005ee42..98bfb11 100644 --- a/tensorflow_compression/python/util/BUILD +++ b/tensorflow_compression/python/util/BUILD @@ -22,8 +22,3 @@ py_test( srcs = ["packed_tensors_test.py"], deps = [":packed_tensors"], ) - -filegroup( - name = "py_src", - srcs = glob(["*.py"]), -) diff --git a/tensorflow_compression_ops/MANIFEST.in b/tensorflow_compression_ops/MANIFEST.in new file mode 100644 index 0000000..c672c3a --- /dev/null +++ b/tensorflow_compression_ops/MANIFEST.in @@ -0,0 +1,3 @@ +global-include LICENSE +global-include *.md +recursive-include tensorflow_compression_ops/cc/ *.so diff --git a/tensorflow_compression_ops/README.md b/tensorflow_compression_ops/README.md new file mode 100644 index 0000000..41f63cd --- /dev/null +++ b/tensorflow_compression_ops/README.md @@ -0,0 +1,221 @@ +# TensorFlow Compression Ops + +TensorFlow Compression Ops (TFC-ops) contains data compression ops for +TensorFlow. + +This is a subset package of TensorFlow Compression (TFC) that contains +C++-implemented TensorFlow operations only. For the full TFC package, please +refer to the [TFC homepage](https://github.com/tensorflow/compression/). + + +## Documentation & getting help + +Refer to [the TFC API +documentation](https://www.tensorflow.org/api_docs/python/tfc) for a complete +description of the functions this package implements. + +This subset pockage implements the following functions in the API: + + * `create_range_encoder` + * `create_range_decoder` + * `entropy_decode_channel` + * `entropy_decode_finalize` + * `entropy_decode_index` + * `entropy_encode_channel` + * `entropy_encode_finalize` + * `entropy_encode_index` + * `pmf_to_quantized_cdf` + * `range_decode` (deprecated) + * `range_encode` (deprecated) + * `run_length_decode` + * `run_length_encode` + * `run_length_gamma_decode` (deprecated) + * `run_length_gamma_encode` (deprecated) + * `stochastic_round` + +Please post all questions or comments on +[Discussions](https://github.com/tensorflow/compression/discussions). Only file +[Issues](https://github.com/tensorflow/compression/issues) for actual bugs or +feature requests. On Discussions, you may get a faster answer, and you help +other people find the question or answer more easily later. + + +## Installation + +***Note: Precompiled packages are currently only provided for Linux and +Darwin/Mac OS.*** + +Set up an environment in which you can install precompiled binary Python +packages using the `pip` command. Refer to the +[TensorFlow installation instructions](https://www.tensorflow.org/install/pip) +for more information on how to set up such a Python environment. + +The current version of TensorFlow Compression requires TensorFlow 2. 
+ +### pip + +To install TFC via `pip`, run the following command: + +```bash +python -m pip install tensorflow-compression-ops +``` + +To test that the installation works correctly, you can run the unit tests with: + +```bash +python -m tensorflow_compression_ops.tests.all +``` + +Once the command finishes, you should see a message ```OK (skipped=2)``` or +similar in the last line. + +### Colab + +You can try out TFC live in a [Colab](https://colab.research.google.com/). The +following command installs the latest version of TFC that is compatible with the +installed TensorFlow version. Run it in a cell before executing your Python +code: + +``` +%pip install tensorflow-compression-ops~=$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+\.\d+).*/\1.0/sg') +``` + +Note: The binary packages of TFC are tied to TF with the same minor version +(e.g., TFC 2.9.1 requires TF 2.9.x), and Colab sometimes lags behind a few days +in deploying the latest version of TensorFlow. As a result, using `%pip install +tensorflow-compression-ops` naively might attempt to upgrade TF, which may +create problems. + +### Docker + +To use a Docker container (e.g. on Windows), be sure to install Docker +(e.g., [Docker Desktop](https://www.docker.com/products/docker-desktop)), +use a [TensorFlow Docker image](https://www.tensorflow.org/install/docker), +and then run the `pip install` command inside the Docker container, not on the +host. For instance, you can use a command line like this: + +```bash +docker run tensorflow/tensorflow:latest bash -c \ + "python -m pip install tensorflow-compression-ops && + python -m tensorflow_compression_ops.tests.all" +``` + +This will fetch the TensorFlow Docker image if it's not already cached, install +the pip package and then run the unit tests to confirm that it works. + +### Anaconda + +It seems that [Anaconda](https://www.anaconda.com/distribution/) ships its own +binary version of TensorFlow which is incompatible with our pip package. To +solve this, always install TensorFlow via `pip` rather than `conda`. For +example, this creates an Anaconda environment with CUDA libraries, and then +installs TensorFlow and TensorFlow Compression Ops: + +```bash +conda create --name ENV_NAME python cudatoolkit cudnn +conda activate ENV_NAME +python -m pip install tensorflow-compression-ops +``` + +Depending on the requirements of the `tensorflow` pip package, you may need to +pin the CUDA libraries to specific versions. If you aren't using a GPU, CUDA is +of course not necessary. + + +## Usage + +We recommend importing the library from your Python code as follows: + +```python +import tensorflow as tf +import tensorflow_compression_ops as tfc +``` + + +## Building pip packages + +This section describes the necessary steps to build your own pip packages of +TensorFlow Compression Ops. This may be necessary to install it on platforms for +which we don't provide precompiled binaries (currently only Linux and Darwin). + +To be compatible with the official TensorFlow pip package, the TFC pip package +must be linked against a matching version of the C libraries. For this reason, +to build the official Linux pip packages, we use [these Docker +images](https://hub.docker.com/r/tensorflow/build) and use the same toolchain +that TensorFlow uses. + +Inside the Docker container, the following steps need to be taken: + +1. Clone the `tensorflow/compression` repo from GitHub. +2. Run `tensorflow_compression_ops/build_pip_pkg.sh` inside the cloned repo. 
+
+For example:
+
+```bash
+git clone https://github.com/tensorflow/compression.git /tensorflow_compression
+docker run -i --rm \
+    -v /tmp/tensorflow_compression_ops:/tmp/tensorflow_compression_ops \
+    -v /tensorflow_compression:/tensorflow_compression \
+    -w /tensorflow_compression \
+    -e "BAZEL_OPT=--config=manylinux_2_17_x86_64" \
+    tensorflow/build:latest-python3.10 \
+    bash tensorflow_compression_ops/build_pip_pkg.sh /tmp/tensorflow_compression_ops
+```
+
+For Darwin, the Docker image and specifying the Bazel config are not
+necessary. We just build the package like this (note that you may want to
+create a clean Python virtual environment to do this):
+
+```bash
+git clone https://github.com/tensorflow/compression.git /tensorflow_compression
+cd /tensorflow_compression
+BAZEL_OPT="--macos_minimum_os=10.14" bash \
+    tensorflow_compression_ops/build_pip_pkg.sh \
+    /tmp/tensorflow_compression_ops
+```
+
+In both cases, the wheel file is created inside
+`/tmp/tensorflow_compression_ops`.
+
+To test the created package, first install the resulting wheel file:
+
+```bash
+python -m pip install /tmp/tensorflow_compression_ops/tensorflow_compression_ops-*.whl
+```
+
+Then run the unit tests (do not run them in the workspace directory where the
+`WORKSPACE` file lives; there, the Python interpreter would attempt to import
+the `tensorflow_compression_ops` package from the source tree rather than from
+the installed package):
+
+```bash
+pushd /tmp
+python -m tensorflow_compression_ops.tests.all
+popd
+```
+
+When done, you can uninstall the pip package again:
+
+```bash
+python -m pip uninstall tensorflow-compression-ops
+```
+
+
+## Citation
+
+If you use this library for research purposes, please cite:
+
+```
+@software{tfc_github,
+  author = "Ballé, Jona and Hwang, Sung Jin and Agustsson, Eirikur",
+  title = "{T}ensor{F}low {C}ompression: Learned Data Compression",
+  url = "http://github.com/tensorflow/compression",
+  version = "2.14.1",
+  year = "2024",
+}
+```
+
+In the above BibTeX entry, names are those of the top contributors, sorted by
+number of commits. Please adjust the version number and year according to the
+version that was actually used.
+
+Note that this is not an officially supported Google product.
diff --git a/tensorflow_compression_ops/__init__.py b/tensorflow_compression_ops/__init__.py
new file mode 100644
index 0000000..3af1d5a
--- /dev/null
+++ b/tensorflow_compression_ops/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Data compression ops in TensorFlow."""
+
+from tensorflow.python.framework import load_library
+from tensorflow.python.platform import resource_loader
+
+ops = load_library.load_op_library(resource_loader.get_path_to_datafile(
+    "cc/libtensorflow_compression.so"))
+
+globals().update({n: getattr(ops, n) for n in [
+    "create_range_encoder",
+    "create_range_decoder",
+    "entropy_decode_channel",
+    "entropy_decode_finalize",
+    "entropy_decode_index",
+    "entropy_encode_channel",
+    "entropy_encode_finalize",
+    "entropy_encode_index",
+    "pmf_to_quantized_cdf",
+    "range_decode",  # Deprecated.
+    "range_encode",  # Deprecated.
+    "run_length_decode",
+    "run_length_encode",
+    "run_length_gamma_decode",  # Deprecated.
+    "run_length_gamma_encode",  # Deprecated.
+    "stochastic_round",
+]})
+
diff --git a/tensorflow_compression_ops/build_pip_pkg.py b/tensorflow_compression_ops/build_pip_pkg.py
new file mode 100644
index 0000000..7cfde72
--- /dev/null
+++ b/tensorflow_compression_ops/build_pip_pkg.py
@@ -0,0 +1,104 @@
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Setup for pip package."""
+
+import atexit
+import glob
+import os
+import shutil
+import sys
+import tempfile
+import setuptools
+
+# Version string should follow PEP440 rules.
+DEFAULT_VERSION = "0.dev0+build-from-source"
+
+
+class BinaryDistribution(setuptools.Distribution):
+  """This class is needed in order to create OS specific wheels."""
+
+  def has_ext_modules(self):
+    return True
+
+
+def main(srcdir: str, destdir: str, version: str = ""):
+  tempdir = tempfile.mkdtemp()
+  atexit.register(shutil.rmtree, tempdir)
+
+  pkgdir = os.path.join(tempdir, "tensorflow_compression_ops")
+  shutil.copytree(os.path.join(srcdir, "tensorflow_compression_ops"), pkgdir)
+  shutil.copy2(os.path.join(srcdir, "MANIFEST.in"), tempdir)
+  shutil.copy2(os.path.join(srcdir, "LICENSE"), pkgdir)
+  shutil.copy2(os.path.join(srcdir, "README.md"), pkgdir)
+
+  if not os.path.exists(
+      os.path.join(pkgdir, "cc/libtensorflow_compression.so")):
+    raise RuntimeError("libtensorflow_compression.so not found. "
+                       "Did you run 'bazel build'?")
+
+  with open(os.path.join(srcdir, "requirements.txt"), "r") as f:
+    install_requires = f.readlines()
+
+  print("=== Building wheel")
+  atexit.register(os.chdir, os.getcwd())
+  os.chdir(tempdir)
+  setuptools.setup(
+      name="tensorflow_compression_ops",
+      version=version or DEFAULT_VERSION,
+      description="Data compression ops for TensorFlow",
+      url="https://tensorflow.github.io/compression/",
+      author="Google LLC",
+      # Contained modules and scripts.
+      packages=setuptools.find_packages(),
+      install_requires=install_requires,
+      script_args=["sdist", "bdist_wheel"],
+      # Add in any packaged data.
+      include_package_data=True,
+      zip_safe=False,
+      distclass=BinaryDistribution,
+      # PyPI package information.
+ classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Software Development :: Libraries", + ], + project_urls={ + "Documentation": + "https://tensorflow.github.io/compression/docs/api_docs/python/tfc.html", + "Discussion": + "https://groups.google.com/forum/#!forum/tensorflow-compression", + "Source": "https://github.com/tensorflow/compression", + "Tracker": "https://github.com/tensorflow/compression/issues", + }, + license="Apache 2.0", + keywords=("compression data-compression tensorflow machine-learning " + "python deep-learning deep-neural-networks neural-network ml") + ) + + print("=== Copying wheel to:", destdir) + os.makedirs(destdir, exist_ok=True) + for path in glob.glob(os.path.join(tempdir, "dist", "*.whl")): + print("Copied into:", shutil.copy(path, destdir)) + + +if __name__ == "__main__": + main(*sys.argv[1:]) # pylint: disable=too-many-function-args + diff --git a/tensorflow_compression_ops/build_pip_pkg.sh b/tensorflow_compression_ops/build_pip_pkg.sh new file mode 100644 index 0000000..8845cfd --- /dev/null +++ b/tensorflow_compression_ops/build_pip_pkg.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script must run at the workspace root directory. + +set -ex # Fail if any command fails, echo commands. + +# Script configuration -------------------------------------------------------- +OUTPUT_DIR="${1-/tmp/tensorflow_compression_ops}" +WHEEL_VERSION=${2-0.dev0} + +# Optionally exported environment variables. 
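+# (`: ${VAR:=default}` is a shell no-op that assigns the default only if the
+# variable is unset, so callers can override it from the environment, e.g.
+# `BAZEL_OPT="--config=manylinux_2_17_x86_64" bash
+# tensorflow_compression_ops/build_pip_pkg.sh`.)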
+: ${BAZEL_OPT:=}
+# -----------------------------------------------------------------------------
+
+python -m pip install -U pip setuptools wheel
+python -m pip install -r tensorflow_compression_ops/requirements.txt
+bazel build ${BAZEL_OPT} -c opt --copt=-mavx tensorflow_compression/cc:libtensorflow_compression.so
+
+SRCDIR="$(mktemp -d)"
+trap 'rm -r -- "${SRCDIR}"' EXIT
+
+cp LICENSE "${SRCDIR}"
+cp tensorflow_compression_ops/README.md "${SRCDIR}"
+cp tensorflow_compression_ops/MANIFEST.in "${SRCDIR}"
+cp tensorflow_compression_ops/requirements.txt "${SRCDIR}"
+
+mkdir -p "${SRCDIR}/tensorflow_compression_ops"
+cp tensorflow_compression_ops/__init__.py "${SRCDIR}/tensorflow_compression_ops/__init__.py"
+
+mkdir -p "${SRCDIR}/tensorflow_compression_ops/cc"
+cp "$(bazel info -c opt bazel-genfiles)/tensorflow_compression/cc/libtensorflow_compression.so" \
+    "${SRCDIR}/tensorflow_compression_ops/cc"
+
+mkdir -p "${SRCDIR}/tensorflow_compression_ops/tests"
+touch "${SRCDIR}/tensorflow_compression_ops/tests/__init__.py"
+cp tensorflow_compression_ops/tests_all.py "${SRCDIR}/tensorflow_compression_ops/tests/all.py"
+sed -e "s/from tensorflow_compression.python.ops import gen_ops/import tensorflow_compression_ops as gen_ops/" \
+    tensorflow_compression/python/ops/quantization_ops_test.py \
+    > "${SRCDIR}/tensorflow_compression_ops/tests/quantization_ops_test.py"
+sed -e "s/from tensorflow_compression.python.ops import gen_ops/import tensorflow_compression_ops as gen_ops/" \
+    tensorflow_compression/python/ops/range_coding_ops_test.py \
+    > "${SRCDIR}/tensorflow_compression_ops/tests/range_coding_ops_test.py"
+
+python tensorflow_compression_ops/build_pip_pkg.py "${SRCDIR}" "${OUTPUT_DIR}" "${WHEEL_VERSION}"
diff --git a/tensorflow_compression_ops/requirements.txt b/tensorflow_compression_ops/requirements.txt
new file mode 100644
index 0000000..743a2c5
--- /dev/null
+++ b/tensorflow_compression_ops/requirements.txt
@@ -0,0 +1 @@
+tensorflow ~= 2.18.0
diff --git a/tensorflow_compression_ops/test_pip_pkg.sh b/tensorflow_compression_ops/test_pip_pkg.sh
new file mode 100644
index 0000000..ac49ae0
--- /dev/null
+++ b/tensorflow_compression_ops/test_pip_pkg.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Copyright 2024 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -ex  # Fail if any command fails, echo commands.
+
+WHEEL="${1}"
+
+# `import tensorflow_compression_ops` in the bazel root directory produces
+# cryptic error messages, because Python ends up looking for .so files under
+# the subdirectories in the src repo instead of Python module libraries.
+# Changing the current directory helps avoid running tests inside the bazel
+# root directory by accident.
+pushd /tmp
+
+python -m pip install -U pip setuptools wheel
+python -m pip install "${WHEEL}"
+python -m pip list -v
+
+# Logs elements of tfc namespace and runs unit tests.
+python -c "import tensorflow_compression_ops as tfc; print('\n'.join(sorted(dir(tfc))))" +python -m tensorflow_compression_ops.tests.all + +popd # /tmp diff --git a/tensorflow_compression_ops/tests_all.py b/tensorflow_compression_ops/tests_all.py new file mode 100644 index 0000000..5bc5e88 --- /dev/null +++ b/tensorflow_compression_ops/tests_all.py @@ -0,0 +1,29 @@ +# Copyright 2024 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All Python tests for tensorflow_compression_ops. + +This is a convenience file to be included in the pip package. +""" + +import tensorflow as tf + +# pylint: disable=wildcard-import +from tensorflow_compression_ops.tests.quantization_ops_test import * +from tensorflow_compression_ops.tests.range_coding_ops_test import * +# pylint: enable=wildcard-import + + +if __name__ == "__main__": + tf.test.main() diff --git a/build_pip_pkg.py b/tools/build_pip_pkg.py similarity index 100% rename from build_pip_pkg.py rename to tools/build_pip_pkg.py diff --git a/tools/build_pip_pkg.sh b/tools/build_pip_pkg.sh new file mode 100644 index 0000000..144e9ca --- /dev/null +++ b/tools/build_pip_pkg.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash +# Copyright 2023 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# This script must run at the workspace root directory. + +set -ex # Fail if any command fails, echo commands. + +# Script configuration -------------------------------------------------------- +OUTPUT_DIR="${1-/tmp/tensorflow_compression}" +WHEEL_VERSION=${2-0.dev0} + +# Optionally exported environment variables. 
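+# (For Linux builds inside the tensorflow/build Docker image, this is
+# typically set to "--config=manylinux_2_17_x86_64"; see the README.)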
+: ${BAZEL_OPT:=}
+# -----------------------------------------------------------------------------
+
+python -m pip install -U pip setuptools wheel
+python -m pip install -r requirements.txt
+bazel build ${BAZEL_OPT} -c opt --copt=-mavx tensorflow_compression/cc:libtensorflow_compression.so
+
+SRCDIR="$(mktemp -d)"
+trap 'rm -r -- "${SRCDIR}"' EXIT
+
+cp LICENSE README.md MANIFEST.in requirements.txt "${SRCDIR}"
+
+mkdir -p "${SRCDIR}/tensorflow_compression/cc"
+cp "$(bazel info -c opt bazel-genfiles)/tensorflow_compression/cc/libtensorflow_compression.so" \
+    "${SRCDIR}/tensorflow_compression/cc"
+
+
+copy_file() {
+  local FILENAME="${1#./}"
+  local DST="${SRCDIR%/}/$(dirname "${FILENAME}")"
+  mkdir -p "${DST}"
+  cp "${FILENAME}" "${DST}"
+}
+
+copy_file "tensorflow_compression/__init__.py"
+copy_file "tensorflow_compression/all_tests.py"
+
+# Assumes no irregular characters in the filenames.
+find tensorflow_compression/python -name "*.py" \
+  | while read filename; do copy_file "${filename}"; done
+
+python tools/build_pip_pkg.py "${SRCDIR}" "${OUTPUT_DIR}" "${WHEEL_VERSION}"
diff --git a/tools/test_pip_pkg.sh b/tools/test_pip_pkg.sh
new file mode 100644
index 0000000..8833a97
--- /dev/null
+++ b/tools/test_pip_pkg.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Copyright 2023 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -ex  # Fail if any command fails, echo commands.
+
+WHEEL="${1}"
+
+# `import tensorflow_compression` in the bazel root directory produces cryptic
+# error messages, because Python ends up looking for .so files under the
+# subdirectories in the src repo instead of Python module libraries. Changing
+# the current directory helps avoid running tests inside the bazel root
+# directory by accident.
+pushd /tmp
+
+python -m pip install -U pip setuptools wheel
+python -m pip install "${WHEEL}"
+python -m pip list -v
+
+# Logs elements of tfc namespace and runs unit tests.
+python -c "import tensorflow_compression as tfc; print('\n'.join(sorted(dir(tfc))))"
+python -m tensorflow_compression.all_tests
+
+popd  # /tmp
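+
+# Example invocation, assuming a wheel was previously built into
+# /tmp/tensorflow_compression by tools/build_pip_pkg.sh (the path is
+# illustrative):
+#
+#   bash tools/test_pip_pkg.sh /tmp/tensorflow_compression/tensorflow_compression-*.whl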