Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8edb37b

Browse files
add autodiff
1 parent 6c9f942 commit 8edb37b

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import math
2+
3+
class Var:
4+
def __init__(self, val, deriv=1.0):
5+
self.val = val
6+
self.deriv = deriv
7+
8+
def __add__(self, other):
9+
if isinstance(other, Var):
10+
val = self.val + other.val
11+
deriv = self.deriv + other.deriv
12+
else:
13+
val = self.val + other
14+
deriv = self.deriv
15+
return Var(val, deriv)
16+
17+
def __radd__(self, other):
18+
return self + other
19+
20+
def __sub__(self, other):
21+
if isinstance(other, Var):
22+
val = self.val - other.val
23+
deriv = self.deriv - other.deriv
24+
else:
25+
val = self.val - other
26+
deriv = self.deriv
27+
return Var(val, deriv)
28+
29+
def __rsub__(self, other):
30+
val = other - self.val
31+
deriv = - self.deriv
32+
return Var(val, deriv)
33+
34+
def __mul__(self, other):
35+
if isinstance(other, Var):
36+
val = self.val * other.val
37+
deriv = self.val * other.deriv + self.deriv * other.val
38+
else:
39+
val = self.val * other
40+
deriv = self.deriv * other
41+
return Var(val, deriv)
42+
43+
def __rmul__(self, other):
44+
return self * other
45+
46+
def __truediv__(self, other):
47+
if isinstance(other, Var):
48+
val = self.val / other.val
49+
deriv = (self.deriv * other.val - self.val * other.deriv)/other.val**2
50+
else:
51+
val = self.val / other
52+
deriv = self.deriv / other
53+
return Var(val, deriv)
54+
55+
def __rtruediv__(self, other):
56+
val = other / self.val
57+
deriv = other * 1/self.val**2
58+
return Var(val, deriv)
59+
60+
def __repr__(self):
61+
return "value: {}\t deriv: {}".format(self.val, self.deriv)
62+
63+
64+
def exp(f: Var):
65+
return Var(math.exp(f.val), math.exp(f.val) * f.deriv)
66+
67+
68+
fx = lambda x: exp(x*x - x)/x
69+
70+
df = fx(Var(2.0))
71+
print(df)
72+
73+
# value: 3.694528049465325 deriv: 9.236320123663312

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /