diff --git a/README.md b/README.md index b7df65e..a0aae0c 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,24 @@ # C++/CUDA Extensions in PyTorch -An example of writing a C++ extension for PyTorch. See +An example of writing a C++/CUDA extension for PyTorch. See [here](http://pytorch.org/tutorials/advanced/cpp_extension.html) for the accompanying tutorial. +This repo demonstrates how to write an example `extension_cpp.ops.lltm` +custom op that has both custom CPU and CUDA kernels. -There are a few "sights" you can metaphorically visit in this repository: +To build: +``` +pip install . +``` -- Inspect the C++ and CUDA extensions in the `cpp/` and `cuda/` folders, -- Build C++ and/or CUDA extensions by going into the `cpp/` or `cuda/` folder and executing `python setup.py install`, -- JIT-compile C++ and/or CUDA extensions by going into the `cpp/` or `cuda/` folder and calling `python jit.py`, which will JIT-compile the extension and load it, -- Benchmark Python vs. C++ vs. CUDA by running `python benchmark.py {py, cpp, cuda} [--cuda]`, -- Run gradient checks on the code by running `python grad_check.py {py, cpp, cuda} [--cuda]`. -- Run output checks on the code by running `python check.py {forward, backward} [--cuda]`. +To test: +``` +python test/test_extension.py +``` + +To benchmark Python vs. C++ vs. CUDA: +``` +python test/benchmark.py {py, cpp, cuda} +``` ## Authors diff --git a/benchmark.py b/benchmark.py deleted file mode 100644 index 212da08..0000000 --- a/benchmark.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import division -from __future__ import print_function - -import argparse -import math -import time - -import torch - -TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000} - -parser = argparse.ArgumentParser() -parser.add_argument('example', choices=['py', 'cpp', 'cuda']) -parser.add_argument('-b', '--batch-size', type=int, default=16) -parser.add_argument('-f', '--features', type=int, default=32) -parser.add_argument('-s', '--state-size', type=int, default=128) -parser.add_argument('-r', '--runs', type=int, default=100) -parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us') -parser.add_argument('-c', '--cuda', action='store_true') -parser.add_argument('-d', '--double', action='store_true') -options = parser.parse_args() - -if options.example == 'py': - from python.lltm import LLTM -elif options.example == 'cpp': - from cpp.lltm import LLTM -else: - from cuda.lltm import LLTM - options.cuda = True - -device = torch.device("cuda") if options.cuda else torch.device("cpu") -dtype = torch.float64 if options.double else torch.float32 - -kwargs = {'dtype': dtype, - 'device': device, - 'requires_grad': True} -X = torch.randn(options.batch_size, options.features, **kwargs) -h = torch.randn(options.batch_size, options.state_size, **kwargs) -C = torch.randn(options.batch_size, options.state_size, **kwargs) -rnn = LLTM(options.features, options.state_size).to(device, dtype) - -# Force CUDA initialization -new_h, new_C = rnn(X, (h, C)) -(new_h.sum() + new_C.sum()).backward() - -forward_min = math.inf -forward_time = 0 -backward_min = math.inf -backward_time = 0 -for _ in range(options.runs): - rnn.zero_grad() - - start = time.time() - new_h, new_C = rnn(X, (h, C)) - elapsed = time.time() - start - forward_min = min(forward_min, elapsed) - forward_time += elapsed - - start = time.time() - (new_h.sum() + new_C.sum()).backward() - elapsed = time.time() - start - backward_min = min(backward_min, elapsed) - backward_time += elapsed - -scale = TIME_SCALES[options.scale] -forward_min *= 
scale -backward_min *= scale -forward_average = forward_time / options.runs * scale -backward_average = backward_time / options.runs * scale - -print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format( - forward_min, forward_average, backward_min, backward_average, - options.scale)) diff --git a/check.py b/check.py deleted file mode 100644 index 8fad6d1..0000000 --- a/check.py +++ /dev/null @@ -1,107 +0,0 @@ -from __future__ import division -from __future__ import print_function - -import argparse -import numpy as np -import torch - -import python.lltm_baseline -import cpp.lltm - - -def check_equal(first, second, verbose): - if verbose: - print() - for i, (x, y) in enumerate(zip(first, second)): - x = x.cpu().detach().numpy() - y = y.cpu().detach().numpy() - if verbose: - print("x = {}".format(x.flatten())) - print("y = {}".format(y.flatten())) - print('-' * 80) - np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i)) - - -def zero_grad(variables): - for variable in variables: - variable.grad.zero_() - - -def get_grads(variables): - return [var.grad.clone() for var in variables] - - -def check_forward(variables, with_cuda, verbose): - baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables) - cpp_values = cpp.lltm.LLTMFunction.apply(*variables) - - print('Forward: Baseline (Python) vs. C++ ... ', end='') - check_equal(baseline_values, cpp_values, verbose) - print('Ok') - - if with_cuda: - cuda_values = cuda.lltm.LLTMFunction.apply(*variables) - print('Forward: Baseline (Python) vs. CUDA ... ', end='') - check_equal(baseline_values, cuda_values, verbose) - print('Ok') - - -def check_backward(variables, with_cuda, verbose): - baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables) - (baseline_values[0] + baseline_values[1]).sum().backward() - grad_baseline = get_grads(variables) - - zero_grad(variables) - - cpp_values = cpp.lltm.LLTMFunction.apply(*variables) - (cpp_values[0] + cpp_values[1]).sum().backward() - grad_cpp = get_grads(variables) - - print('Backward: Baseline (Python) vs. C++ ... ', end='') - check_equal(grad_baseline, grad_cpp, verbose) - print('Ok') - - if with_cuda: - zero_grad(variables) - cuda_values = cuda.lltm.LLTMFunction.apply(*variables) - (cuda_values[0] + cuda_values[1]).sum().backward() - grad_cuda = get_grads(variables) - - print('Backward: Baseline (Python) vs. CUDA ... 
', end='') - check_equal(grad_baseline, grad_cuda, verbose) - print('Ok') - - -parser = argparse.ArgumentParser() -parser.add_argument('direction', choices=['forward', 'backward'], nargs='+') -parser.add_argument('-b', '--batch-size', type=int, default=3) -parser.add_argument('-f', '--features', type=int, default=17) -parser.add_argument('-s', '--state-size', type=int, default=5) -parser.add_argument('-c', '--cuda', action='store_true') -parser.add_argument('-v', '--verbose', action='store_true') -options = parser.parse_args() - -if options.cuda: - import cuda.lltm - device = torch.device("cuda") -else: - device = torch.device("cpu") - -kwargs = {'dtype': torch.float64, - 'device': device, - 'requires_grad': True} -X = torch.randn(options.batch_size, - options.features, - **kwargs) -h = torch.randn(options.batch_size, options.state_size, **kwargs) -C = torch.randn(options.batch_size, options.state_size, **kwargs) -W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs) -b = torch.randn(1, 3 * options.state_size, **kwargs) - -variables = [X, W, b, h, C] - -if 'forward' in options.direction: - check_forward(variables, options.cuda, options.verbose) - -if 'backward' in options.direction: - check_backward(variables, options.cuda, options.verbose) diff --git a/cpp/__init__.py b/cpp/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/cpp/jit.py b/cpp/jit.py deleted file mode 100644 index d1c9c5d..0000000 --- a/cpp/jit.py +++ /dev/null @@ -1,3 +0,0 @@ -from torch.utils.cpp_extension import load -lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True) -help(lltm_cpp) diff --git a/cpp/lltm.py b/cpp/lltm.py deleted file mode 100644 index 24cf82d..0000000 --- a/cpp/lltm.py +++ /dev/null @@ -1,44 +0,0 @@ -import math -from torch import nn -from torch.autograd import Function -import torch - -import lltm_cpp - -torch.manual_seed(42) - - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell) - new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights] - ctx.save_for_backward(*variables) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward( - grad_h, grad_cell, *ctx.saved_variables) - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - -class LLTM(nn.Module): - def __init__(self, input_features, state_size): - super(LLTM, self).__init__() - self.input_features = input_features - self.state_size = state_size - self.weights = nn.Parameter( - torch.Tensor(3 * state_size, input_features + state_size)) - self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.state_size) - for weight in self.parameters(): - weight.data.uniform_(-stdv, +stdv) - - def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/cpp/setup.py b/cpp/setup.py deleted file mode 100644 index 7a4c164..0000000 --- a/cpp/setup.py +++ /dev/null @@ -1,11 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CppExtension - -setup( - name='lltm_cpp', - ext_modules=[ - CppExtension('lltm_cpp', ['lltm.cpp']), - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/cuda/__init__.py b/cuda/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git 
a/cuda/jit.py b/cuda/jit.py deleted file mode 100644 index 6c52eff..0000000 --- a/cuda/jit.py +++ /dev/null @@ -1,4 +0,0 @@ -from torch.utils.cpp_extension import load -lltm_cuda = load( - 'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True) -help(lltm_cuda) diff --git a/cuda/lltm.py b/cuda/lltm.py deleted file mode 100644 index c740b88..0000000 --- a/cuda/lltm.py +++ /dev/null @@ -1,45 +0,0 @@ -import math -from torch import nn -from torch.autograd import Function -import torch - -import lltm_cuda - -torch.manual_seed(42) - - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell) - new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights] - ctx.save_for_backward(*variables) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - outputs = lltm_cuda.backward( - grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables) - d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - -class LLTM(nn.Module): - def __init__(self, input_features, state_size): - super(LLTM, self).__init__() - self.input_features = input_features - self.state_size = state_size - self.weights = nn.Parameter( - torch.Tensor(3 * state_size, input_features + state_size)) - self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.state_size) - for weight in self.parameters(): - weight.data.uniform_(-stdv, +stdv) - - def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp deleted file mode 100644 index 2434776..0000000 --- a/cuda/lltm_cuda.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include <torch/extension.h> - -#include <vector> - -// CUDA forward declarations - -std::vector<torch::Tensor> lltm_cuda_forward( - torch::Tensor input, - torch::Tensor weights, - torch::Tensor bias, - torch::Tensor old_h, - torch::Tensor old_cell); - -std::vector<torch::Tensor> lltm_cuda_backward( - torch::Tensor grad_h, - torch::Tensor grad_cell, - torch::Tensor new_cell, - torch::Tensor input_gate, - torch::Tensor output_gate, - torch::Tensor candidate_cell, - torch::Tensor X, - torch::Tensor gate_weights, - torch::Tensor weights); - -// C++ interface - -// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector<torch::Tensor> lltm_forward( - torch::Tensor input, - torch::Tensor weights, - torch::Tensor bias, - torch::Tensor old_h, - torch::Tensor old_cell) { - CHECK_INPUT(input); - CHECK_INPUT(weights); - CHECK_INPUT(bias); - CHECK_INPUT(old_h); - CHECK_INPUT(old_cell); - - return lltm_cuda_forward(input, weights, bias, old_h, old_cell); -} - -std::vector<torch::Tensor> lltm_backward( - torch::Tensor grad_h, - torch::Tensor grad_cell, - torch::Tensor new_cell, - torch::Tensor input_gate, - torch::Tensor output_gate, - torch::Tensor candidate_cell, - torch::Tensor X, - torch::Tensor gate_weights, - torch::Tensor weights) { - CHECK_INPUT(grad_h); - CHECK_INPUT(grad_cell); - CHECK_INPUT(input_gate); - CHECK_INPUT(output_gate); - CHECK_INPUT(candidate_cell); - CHECK_INPUT(X); - CHECK_INPUT(gate_weights); - CHECK_INPUT(weights); - - return lltm_cuda_backward( - grad_h, - grad_cell, - new_cell, - input_gate, - output_gate, - candidate_cell, - X, - gate_weights, - weights); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lltm_forward, "LLTM forward (CUDA)"); - m.def("backward", &lltm_backward, "LLTM backward (CUDA)"); -} diff --git a/cuda/setup.py b/cuda/setup.py deleted file mode 100644 index 670b3c8..0000000 --- a/cuda/setup.py +++ /dev/null @@ -1,14 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name='lltm_cuda', - ext_modules=[ - CUDAExtension('lltm_cuda', [ - 'lltm_cuda.cpp', - 'lltm_cuda_kernel.cu', - ]), - ], - cmdclass={ - 'build_ext': BuildExtension - }) diff --git a/extension_cpp/__init__.py b/extension_cpp/__init__.py new file mode 100644 index 0000000..769c697 --- /dev/null +++ b/extension_cpp/__init__.py @@ -0,0 +1,2 @@ +import torch +from . 
import _C, ops diff --git a/cuda/lltm_cuda_kernel.cu b/extension_cpp/csrc/cuda/lltm_cuda.cu similarity index 89% rename from cuda/lltm_cuda_kernel.cu rename to extension_cpp/csrc/cuda/lltm_cuda.cu index 02bb9ad..7612b84 100644 --- a/cuda/lltm_cuda_kernel.cu +++ b/extension_cpp/csrc/cuda/lltm_cuda.cu @@ -94,7 +94,7 @@ __global__ void lltm_cuda_backward_kernel( } } // namespace -std::vector<torch::Tensor> lltm_cuda_forward( +std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> lltm_cuda_forward( torch::Tensor input, torch::Tensor weights, torch::Tensor bias, @@ -130,7 +130,7 @@ std::vector<torch::Tensor> lltm_cuda_forward( return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates}; } -std::vector<torch::Tensor> lltm_cuda_backward( +std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> lltm_cuda_backward( torch::Tensor grad_h, torch::Tensor grad_cell, torch::Tensor new_cell, @@ -143,6 +143,9 @@ std::vector<torch::Tensor> lltm_cuda_backward( auto d_old_cell = torch::zeros_like(new_cell); auto d_gates = torch::zeros_like(gates); + auto grad_h_contig = grad_h.contiguous(); + auto grad_cell_contig = grad_cell.contiguous(); + const auto batch_size = new_cell.size(0); const auto state_size = new_cell.size(1); @@ -153,8 +156,8 @@ std::vector<torch::Tensor> lltm_cuda_backward( lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>( d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(), - grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), - grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), + grad_h_contig.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), + grad_cell_contig.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(), @@ -170,5 +173,11 @@ std::vector<torch::Tensor> lltm_cuda_backward( auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); auto d_input = d_X.slice(/*dim=*/1, state_size); - return {d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates}; + return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; +} + +// Registers CUDA implementations for lltm_forward, lltm_backward +TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) { + m.impl("lltm_forward", &lltm_cuda_forward); + m.impl("lltm_backward", &lltm_cuda_backward); } diff --git a/cpp/lltm.cpp b/extension_cpp/csrc/lltm.cpp similarity index 71% rename from cpp/lltm.cpp rename to extension_cpp/csrc/lltm.cpp index 9bdfe0c..c3bb5bc 100644 --- a/cpp/lltm.cpp +++ b/extension_cpp/csrc/lltm.cpp @@ -20,7 +20,7 @@ torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) { return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e); } -std::vector<torch::Tensor> lltm_forward( +std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> lltm_forward( torch::Tensor input, torch::Tensor weights, torch::Tensor bias, @@ -47,7 +47,7 @@ std::vector<torch::Tensor> lltm_forward( gate_weights}; } -std::vector<torch::Tensor> lltm_backward( +std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> lltm_backward( torch::Tensor grad_h, torch::Tensor grad_cell, torch::Tensor new_cell, @@ -84,7 +84,17 @@ std::vector<torch::Tensor> lltm_backward( return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &lltm_forward, "LLTM forward"); - m.def("backward", &lltm_backward, "LLTM backward"); +// Registers _C as an extension module. 
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {} + +// Defines the operators +TORCH_LIBRARY(extension_cpp, m) { + m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"); + m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); +} + +// Registers CPU implementations for lltm_forward, lltm_backward +TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) { + m.impl("lltm_forward", &lltm_forward); + m.impl("lltm_backward", &lltm_backward); } diff --git a/extension_cpp/ops.py b/extension_cpp/ops.py new file mode 100644 index 0000000..d0a2a9f --- /dev/null +++ b/extension_cpp/ops.py @@ -0,0 +1,61 @@ +from typing import Tuple +import torch +from torch import Tensor + +__all__ = ["lltm", "reference_lltm"] + + +def lltm( + input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor +) -> Tuple[Tensor, Tensor]: + return LLTMFunction.apply(input, weights, bias, old_h, old_cell) + + +class LLTMFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weights, bias, old_h, old_cell): + outputs = torch.ops.extension_cpp.lltm_forward.default( + input, weights, bias, old_h, old_cell + ) + new_h, new_cell = outputs[:2] + variables = list(outputs[1:]) + [weights] + ctx.save_for_backward(*variables) + + return new_h, new_cell + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_h, grad_cell): + ( + d_old_h, + d_input, + d_weights, + d_bias, + d_old_cell, + ) = torch.ops.extension_cpp.lltm_backward.default( + grad_h, grad_cell, *ctx.saved_tensors + ) + return d_input, d_weights, d_bias, d_old_h, d_old_cell + + +def reference_lltm( + input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor +) -> Tuple[Tensor, Tensor]: + X = torch.cat([old_h, input], dim=1) + + # Compute the input, output and candidate cell gates with one MM. + gate_weights = torch.nn.functional.linear(X, weights, bias) + # Split the combined gate weight matrix into its components. + gates = gate_weights.chunk(3, dim=1) + + input_gate = torch.sigmoid(gates[0]) + output_gate = torch.sigmoid(gates[1]) + # Here we use an ELU instead of the usual tanh. + candidate_cell = torch.nn.functional.elu(gates[2]) + + # Compute the new cell state. + new_cell = old_cell + candidate_cell * input_gate + # Compute the new hidden state and output. 
+ new_h = torch.tanh(new_cell) * output_gate + + return new_h, new_cell diff --git a/grad_check.py b/grad_check.py deleted file mode 100644 index caf3b36..0000000 --- a/grad_check.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import division -from __future__ import print_function - -import argparse -import torch -from torch.autograd import gradcheck - -parser = argparse.ArgumentParser() -parser.add_argument('example', choices=['py', 'cpp', 'cuda']) -parser.add_argument('-b', '--batch-size', type=int, default=3) -parser.add_argument('-f', '--features', type=int, default=17) -parser.add_argument('-s', '--state-size', type=int, default=5) -parser.add_argument('-c', '--cuda', action='store_true') -options = parser.parse_args() - -if options.example == 'py': - from python.lltm_baseline import LLTMFunction -elif options.example == 'cpp': - from cpp.lltm import LLTMFunction -else: - from cuda.lltm import LLTMFunction - options.cuda = True - -device = torch.device("cuda") if options.cuda else torch.device("cpu") - -kwargs = {'dtype': torch.float64, - 'device': device, - 'requires_grad': True} - -X = torch.randn(options.batch_size, options.features, **kwargs) -h = torch.randn(options.batch_size, options.state_size, **kwargs) -C = torch.randn(options.batch_size, options.state_size, **kwargs) -W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs) -b = torch.randn(1, 3 * options.state_size, **kwargs) - -variables = [X, W, b, h, C] - - -if gradcheck(LLTMFunction.apply, variables): - print('Ok') diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..918072e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools", + "torch", +] +build-backend = "setuptools.build_meta" diff --git a/python/__init__.py b/python/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/python/lltm.py b/python/lltm.py deleted file mode 100644 index 6766d08..0000000 --- a/python/lltm.py +++ /dev/null @@ -1,44 +0,0 @@ -import math -import torch -import torch.nn.functional as F - -torch.manual_seed(42) - - -class LLTM(torch.nn.Module): - def __init__(self, input_features, state_size): - super(LLTM, self).__init__() - self.input_features = input_features - self.state_size = state_size - # 3 * state_size for input gate, output gate and candidate cell gate. - # input_features + state_size because we will multiply with [input, h]. - self.weights = torch.nn.Parameter( - torch.Tensor(3 * state_size, input_features + state_size)) - self.bias = torch.nn.Parameter(torch.Tensor(1, 3 * state_size)) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.state_size) - for weight in self.parameters(): - weight.data.uniform_(-stdv, +stdv) - - def forward(self, input, state): - old_h, old_cell = state - X = torch.cat([old_h, input], dim=1) - - # Compute the input, output and candidate cell gates with one MM. - gate_weights = F.linear(X, self.weights, self.bias) - # Split the combined gate weight matrix into its components. - gates = gate_weights.chunk(3, dim=1) - - input_gate = torch.sigmoid(gates[0]) - output_gate = torch.sigmoid(gates[1]) - # Here we use an ELU instead of the usual tanh. - candidate_cell = F.elu(gates[2]) - - # Compute the new cell state. - new_cell = old_cell + candidate_cell * input_gate - # Compute the new hidden state and output. 
- new_h = torch.tanh(new_cell) * output_gate - - return new_h, new_cell diff --git a/python/lltm_baseline.py b/python/lltm_baseline.py deleted file mode 100644 index 61bf328..0000000 --- a/python/lltm_baseline.py +++ /dev/null @@ -1,98 +0,0 @@ -import math - -from torch import nn -from torch.autograd import Function -import torch -import torch.nn.functional as F - -torch.manual_seed(42) - - -def d_sigmoid(z): - s = torch.sigmoid(z) - return (1 - s) * s - - -def d_tanh(z): - t = torch.tanh(z) - return 1 - (t * t) - - -def d_elu(z, alpha=1.0): - e = z.exp() - mask = (alpha * (e - 1)) < 0 - return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e) - - -class LLTMFunction(Function): - @staticmethod - def forward(ctx, input, weights, bias, old_h, old_cell): - X = torch.cat([old_h, input], dim=1) - - gate_weights = F.linear(X, weights, bias) - gates = gate_weights.chunk(3, dim=1) - - input_gate = torch.sigmoid(gates[0]) - output_gate = torch.sigmoid(gates[1]) - candidate_cell = F.elu(gates[2]) - - new_cell = old_cell + candidate_cell * input_gate - new_h = torch.tanh(new_cell) * output_gate - - ctx.save_for_backward(X, weights, input_gate, output_gate, old_cell, - new_cell, candidate_cell, gate_weights) - - return new_h, new_cell - - @staticmethod - def backward(ctx, grad_h, grad_cell): - X, weights, input_gate, output_gate, old_cell = ctx.saved_variables[:5] - new_cell, candidate_cell, gate_weights = ctx.saved_variables[5:] - - d_input = d_weights = d_bias = d_old_h = d_old_cell = None - - d_output_gate = torch.tanh(new_cell) * grad_h - d_tanh_new_cell = output_gate * grad_h - d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell - - d_old_cell = d_new_cell - d_candidate_cell = input_gate * d_new_cell - d_input_gate = candidate_cell * d_new_cell - - gates = gate_weights.chunk(3, dim=1) - d_input_gate *= d_sigmoid(gates[0]) - d_output_gate *= d_sigmoid(gates[1]) - d_candidate_cell *= d_elu(gates[2]) - - d_gates = torch.cat( - [d_input_gate, d_output_gate, d_candidate_cell], dim=1) - - if ctx.needs_input_grad[1]: - d_weights = d_gates.t().mm(X) - if ctx.needs_input_grad[2]: - d_bias = d_gates.sum(dim=0, keepdim=True) - if ctx.needs_input_grad[3] or ctx.needs_input_grad[4]: - d_X = d_gates.mm(weights) - state_size = grad_h.shape[1] - d_old_h, d_input = d_X[:, :state_size], d_X[:, state_size:] - - return d_input, d_weights, d_bias, d_old_h, d_old_cell - - -class LLTM(nn.Module): - def __init__(self, input_features, state_size): - super(LLTM, self).__init__() - self.input_features = input_features - self.state_size = state_size - self.weights = nn.Parameter( - torch.Tensor(3 * state_size, input_features + state_size)) - self.bias = nn.Parameter(torch.Tensor(1, 3 * state_size)) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.state_size) - for weight in self.parameters(): - weight.data.uniform_(-stdv, +stdv) - - def forward(self, input, state): - return LLTMFunction.apply(input, self.weights, self.bias, *state) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..1408cb3 --- /dev/null +++ b/setup.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import torch +import glob + +from setuptools import find_packages, setup + +from torch.utils.cpp_extension import ( + CppExtension, + CUDAExtension, + BuildExtension, + CUDA_HOME, +) + +library_name = "extension_cpp" + + +def get_extensions(): + debug_mode = os.getenv("DEBUG", "0") == "1" + use_cuda = os.getenv("USE_CUDA", "1") == "1" + if debug_mode: + print("Compiling in debug mode") + + use_cuda = use_cuda and torch.cuda.is_available() and CUDA_HOME is not None + extension = CUDAExtension if use_cuda else CppExtension + + extra_link_args = [] + extra_compile_args = { + "cxx": [ + "-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + ], + "nvcc": [ + "-O3" if not debug_mode else "-O0", + ], + } + if debug_mode: + extra_compile_args["cxx"].append("-g") + extra_compile_args["nvcc"].append("-g") + extra_link_args.extend(["-O0", "-g"]) + + this_dir = os.path.dirname(os.path.curdir) + extensions_dir = os.path.join(this_dir, library_name, "csrc") + sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) + + extensions_cuda_dir = os.path.join(extensions_dir, "cuda") + cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) + + if use_cuda: + sources += cuda_sources + + ext_modules = [ + extension( + f"{library_name}._C", + sources, + extra_compile_args=extra_compile_args, + extra_link_args=extra_link_args, + ) + ] + + return ext_modules + + +setup( + name=library_name, + version="0.0.1", + packages=find_packages(), + ext_modules=get_extensions(), + install_requires=["torch"], + description="Example of PyTorch cpp and CUDA extensions", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/pytorch/extension-cpp", + cmdclass={"build_ext": BuildExtension}, +) diff --git a/test/benchmark.py b/test/benchmark.py new file mode 100644 index 0000000..e9f4799 --- /dev/null +++ b/test/benchmark.py @@ -0,0 +1,83 @@ +from __future__ import division +from __future__ import print_function + +import argparse +import math +import time + +import torch + +TIME_SCALES = {"s": 1, "ms": 1000, "us": 1000000} + +parser = argparse.ArgumentParser() +parser.add_argument("example", choices=["py", "cpp", "cuda"]) +parser.add_argument("-b", "--batch-size", type=int, default=16) +parser.add_argument("-f", "--features", type=int, default=32) +parser.add_argument("-s", "--state-size", type=int, default=128) +parser.add_argument("-r", "--runs", type=int, default=100) +parser.add_argument("--scale", choices=["s", "ms", "us"], default="us") +parser.add_argument("-c", "--cuda", action="store_true") +parser.add_argument("-d", "--double", action="store_true") +options = parser.parse_args() + +if options.example == "py": + from extension_cpp.ops import reference_lltm as LLTM +else: + from extension_cpp.ops import lltm as LLTM +if options.example == "cuda": + options.cuda = True + +device = torch.device("cuda") if options.cuda else torch.device("cpu") +dtype = torch.float64 if options.double else torch.float32 + +kwargs = {"dtype": dtype, "device": device, "requires_grad": True} +batch_size = options.batch_size +features = options.features +state_size = options.state_size +X = torch.randn( + batch_size, # E: No overload variant of "randn" matches argument + features, + **kwargs +) +h = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia +C = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia +W = torch.randn(3 * state_size, features + state_size, **kwargs) +b = torch.randn(1, 3 * 
state_size, **kwargs) # E: No overload variant of "randn" + +# Force CUDA initialization +new_h, new_C = LLTM(X, W, b, h, C) +(new_h.sum() + new_C.sum()).backward() + +forward_min = math.inf +forward_time = 0 +backward_min = math.inf +backward_time = 0 +for _ in range(options.runs): + X.grad = None + h.grad = None + C.grad = None + W.grad = None + b.grad = None + start = time.time() + new_h, new_C = LLTM(X, W, b, h, C) + elapsed = time.time() - start + forward_min = min(forward_min, elapsed) + forward_time += elapsed + + start = time.time() + (new_h.sum() + new_C.sum()).backward() + elapsed = time.time() - start + backward_min = min(backward_min, elapsed) + backward_time += elapsed + +scale = TIME_SCALES[options.scale] +forward_min *= scale +backward_min *= scale +forward_average = forward_time / options.runs * scale +backward_average = backward_time / options.runs * scale + +print( + "Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}".format( + forward_min, forward_average, backward_min, backward_average, options.scale + ) +) diff --git a/test/test_extension.py b/test/test_extension.py new file mode 100644 index 0000000..3d4c81f --- /dev/null +++ b/test/test_extension.py @@ -0,0 +1,58 @@ +import torch +from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.optests import opcheck +import unittest +import extension_cpp +from torch import Tensor +from typing import Tuple +import torch.nn.functional as F + + +def sample_inputs(device): + batch_size = 3 + features = 17 + state_size = 5 + kwargs = {"dtype": torch.float64, "device": device, "requires_grad": True} + X = torch.randn( + batch_size, # E: No overload variant of "randn" matches argument + features, + **kwargs + ) + h = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia + C = torch.randn(batch_size, state_size, **kwargs) # E: No overload varia + W = torch.randn(3 * state_size, features + state_size, **kwargs) + b = torch.randn(1, 3 * state_size, **kwargs) # E: No overload variant of "randn" + return X, W, b, h, C + + +class TestLLTM(TestCase): + def _test_correctness(self, device): + args = sample_inputs(device) + result = extension_cpp.ops.lltm(*args) + expected = extension_cpp.ops.reference_lltm(*args) + self.assertEqual(len(result), len(expected)) + torch.testing.assert_close(result, expected) + + def test_correctness_cpu(self): + self._test_correctness("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_correctness_cuda(self): + self._test_correctness("cuda") + + def _test_gradients(self, device): + args = sample_inputs(device) + torch.autograd.gradcheck(extension_cpp.ops.lltm, args) + + def test_gradients_cpu(self): + self._test_gradients("cpu") + + # This is supposed to succeed, there's probably a bug in the CUDA kernel. + @unittest.expectedFailure + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_gradients_cuda(self): + self._test_gradients("cuda") + + +if __name__ == "__main__": + unittest.main()
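
For reference, here is a minimal usage sketch of the `extension_cpp.ops.lltm` surface added in this diff. It is a sketch only: the shapes mirror `sample_inputs()` in `test/test_extension.py`, and it assumes the package has already been built with `pip install .`.

```
import torch
import extension_cpp

# Shapes follow sample_inputs() in test/test_extension.py.
batch_size, features, state_size = 3, 17, 5
kwargs = {"dtype": torch.float64, "requires_grad": True}
X = torch.randn(batch_size, features, **kwargs)
h = torch.randn(batch_size, state_size, **kwargs)
C = torch.randn(batch_size, state_size, **kwargs)
W = torch.randn(3 * state_size, features + state_size, **kwargs)
b = torch.randn(1, 3 * state_size, **kwargs)

# CPU tensors dispatch to the kernels registered in extension_cpp/csrc/lltm.cpp;
# CUDA tensors would dispatch to extension_cpp/csrc/cuda/lltm_cuda.cu.
new_h, new_C = extension_cpp.ops.lltm(X, W, b, h, C)
(new_h.sum() + new_C.sum()).backward()
```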