Skip to content

Commit 6f8e941

Browse files
committed
add example of defining new autograd function
1 parent d28debf commit 6f8e941

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

new_function.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
from torch.autograd import Variable
3+
4+
5+
class MulTwo(torch.autograd.Function):
6+
def forward(self, input):
7+
return 2.0 * input
8+
9+
def backward(self, grad_output):
10+
return 2.0 * grad_output
11+
12+
13+
x = Variable(torch.randn(3, 4), requires_grad=True)
14+
y = MulTwo()(x)
15+
s = y.sum()
16+
17+
s.backward()
18+
print(x.grad.data)
19+

0 commit comments

Comments
 (0)