Skip to content
This repository was archived by the owner on Jul 24, 2022. It is now read-only.

Commit 5be8789

Browse files
committed
change sigmoid and tanh to torch from nn.functional
1 parent ae4a541 commit 5be8789

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

python/lltm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ def forward(self, input, state):
3131
# Split the combined gate weight matrix into its components.
3232
gates = gate_weights.chunk(3, dim=1)
3333

34-
input_gate = F.sigmoid(gates[0])
35-
output_gate = F.sigmoid(gates[1])
34+
input_gate = torch.sigmoid(gates[0])
35+
output_gate = torch.sigmoid(gates[1])
3636
# Here we use an ELU instead of the usual tanh.
3737
candidate_cell = F.elu(gates[2])
3838

3939
# Compute the new cell state.
4040
new_cell = old_cell + candidate_cell * input_gate
4141
# Compute the new hidden state and output.
42-
new_h = F.tanh(new_cell) * output_gate
42+
new_h = torch.tanh(new_cell) * output_gate
4343

4444
return new_h, new_cell

python/lltm_baseline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010

1111
def d_sigmoid(z):
12-
s = F.sigmoid(z)
12+
s = torch.sigmoid(z)
1313
return (1 - s) * s
1414

1515

1616
def d_tanh(z):
17-
t = F.tanh(z)
17+
t = torch.tanh(z)
1818
return 1 - (t * t)
1919

2020

@@ -32,12 +32,12 @@ def forward(ctx, input, weights, bias, old_h, old_cell):
3232
gate_weights = F.linear(X, weights, bias)
3333
gates = gate_weights.chunk(3, dim=1)
3434

35-
input_gate = F.sigmoid(gates[0])
36-
output_gate = F.sigmoid(gates[1])
35+
input_gate = torch.sigmoid(gates[0])
36+
output_gate = torch.sigmoid(gates[1])
3737
candidate_cell = F.elu(gates[2])
3838

3939
new_cell = old_cell + candidate_cell * input_gate
40-
new_h = F.tanh(new_cell) * output_gate
40+
new_h = torch.tanh(new_cell) * output_gate
4141

4242
ctx.save_for_backward(X, weights, input_gate, output_gate, old_cell,
4343
new_cell, candidate_cell, gate_weights)
@@ -51,7 +51,7 @@ def backward(ctx, grad_h, grad_cell):
5151

5252
d_input = d_weights = d_bias = d_old_h = d_old_cell = None
5353

54-
d_output_gate = F.tanh(new_cell) * grad_h
54+
d_output_gate = torch.tanh(new_cell) * grad_h
5555
d_tanh_new_cell = output_gate * grad_h
5656
d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell
5757

0 commit comments

Comments
 (0)