Skip to content

Commit e4e6713

Browse files
committed
Add check.py to check forward and backward outputs
1 parent 91617a4 commit e4e6713

File tree

2 files changed

+95
-2
lines changed

2 files changed

+95
-2
lines changed

check.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from __future__ import division
2+
from __future__ import print_function
3+
4+
import argparse
5+
import numpy as np
6+
import torch
7+
8+
from torch.autograd import Variable
9+
10+
import python.lltm_baseline
11+
import cpp.lltm
12+
13+
14+
def check_equal(first, second, verbose):
    """Assert that each tensor in *first* is numerically close to its counterpart in *second*.

    Both arguments are sequences of tensors. Every pair is moved to the CPU,
    detached from the autograd graph, and compared element-wise via
    ``np.testing.assert_allclose``; an ``AssertionError`` naming the failing
    index is raised on mismatch.
    """
    if verbose:
        print()
    for index, (lhs, rhs) in enumerate(zip(first, second)):
        lhs = lhs.cpu().detach().numpy()
        rhs = rhs.cpu().detach().numpy()
        if verbose:
            print("x = {}".format(lhs.flatten()))
            print("y = {}".format(rhs.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(lhs, rhs, err_msg="Index: {}".format(index))
25+
26+
27+
def check_forward(variables, with_cuda, verbose):
    """Run the forward pass of each implementation on the same inputs and
    assert the outputs match the Python baseline.

    The C++ extension is always checked; the CUDA extension only when
    *with_cuda* is true (the ``cuda`` module is imported at script level
    in that case).
    """
    reference = python.lltm_baseline.LLTMFunction.apply(*variables)
    candidate = cpp.lltm.LLTMFunction.apply(*variables)

    print('Forward: Baseline (Python) vs. C++ ... ', end='')
    check_equal(reference, candidate, verbose)
    print('Ok')

    if not with_cuda:
        return

    candidate = cuda.lltm.LLTMFunction.apply(*variables)
    print('Forward: Baseline (Python) vs. CUDA ... ', end='')
    check_equal(reference, candidate, verbose)
    print('Ok')
40+
41+
42+
def _zero_grads(variables):
    # Reset accumulated gradients so the next backward pass starts clean.
    for var in variables:
        if var.grad is not None:
            var.grad.zero_()


def check_backward(variables, with_cuda, verbose):
    """Run the backward pass of each implementation and assert the input
    gradients match the Python baseline.

    Fix vs. the original: ``backward()`` *accumulates* into ``var.grad``,
    so running several backward passes over the same leaf variables made
    ``grad_cpp`` equal to baseline + C++ gradients rather than the C++
    gradients alone. Each gradient snapshot is now cloned, and the
    accumulated grads are zeroed before the next implementation runs.
    """
    baseline_values = python.lltm_baseline.LLTMFunction.apply(*variables)
    (baseline_values[0] + baseline_values[1]).sum().backward()
    # Clone so the snapshot survives the zeroing below.
    grad_baseline = [var.grad.clone() for var in variables]
    _zero_grads(variables)

    cpp_values = cpp.lltm.LLTMFunction.apply(*variables)
    (cpp_values[0] + cpp_values[1]).sum().backward()
    grad_cpp = [var.grad.clone() for var in variables]
    _zero_grads(variables)

    print('Backward: Baseline (Python) vs. C++ ... ', end='')
    check_equal(grad_baseline, grad_cpp, verbose)
    print('Ok')

    if with_cuda:
        cuda_values = cuda.lltm.LLTMFunction.apply(*variables)
        (cuda_values[0] + cuda_values[1]).sum().backward()
        grad_cuda = [var.grad.clone() for var in variables]
        _zero_grads(variables)

        print('Backward: Baseline (Python) vs. CUDA ... ', end='')
        check_equal(grad_baseline, grad_cuda, verbose)
        print('Ok')
63+
64+
65+
# Command-line interface: which pass(es) to check and the problem sizes.
parser = argparse.ArgumentParser()
parser.add_argument('direction', choices=['forward', 'backward'], nargs='+')
parser.add_argument('-b', '--batch-size', type=int, default=3)
parser.add_argument('-f', '--features', type=int, default=17)
parser.add_argument('-s', '--state-size', type=int, default=5)
parser.add_argument('-c', '--cuda', action='store_true')
parser.add_argument('-v', '--verbose', action='store_true')
options = parser.parse_args()

if options.cuda:
    # Imported lazily so CPU-only runs don't require the CUDA extension to
    # be built. (The original also re-set ``options.cuda = True`` here,
    # which was redundant inside this branch and has been removed.)
    import cuda.lltm

# Random LLTM inputs: input batch, hidden state, cell state, weights, bias.
X = torch.randn(options.batch_size, options.features)
h = torch.randn(options.batch_size, options.state_size)
C = torch.randn(options.batch_size, options.state_size)
W = torch.randn(3 * options.state_size, options.features + options.state_size)
b = torch.randn(1, 3 * options.state_size)

variables = [X, W, b, h, C]

for i, var in enumerate(variables):
    if options.cuda:
        var = var.cuda()
    # Double precision tightens the allclose comparison; requires_grad makes
    # each tensor a leaf that records gradients for the backward check.
    variables[i] = Variable(var.double(), requires_grad=True)

if 'forward' in options.direction:
    check_forward(variables, options.cuda, options.verbose)

if 'backward' in options.direction:
    check_backward(variables, options.cuda, options.verbose)

grad_check.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,5 @@
3535
var = var.cuda()
3636
variables[i] = Variable(var.double(), requires_grad=True)
3737

38-
print(LLTMFunction.apply(*variables))
39-
4038
if gradcheck(LLTMFunction.apply, variables):
4139
print('Ok')

0 commit comments

Comments
 (0)