Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 9a2f7bb

Browse files
Remove chainrule and test for SqMahalanobis (#539)
1 parent ec19a94 commit 9a2f7bb

File tree

3 files changed

+1
-32
lines changed

3 files changed

+1
-32
lines changed

‎Project.toml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.10.59"
3+
version = "0.10.60"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

‎src/chainrules.jl‎

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -121,28 +121,6 @@ function ChainRulesCore.rrule(s::Sinus, x::AbstractVector, y::AbstractVector)
121121
return val, evaluate_pullback
122122
end
123123

124-
## Reverse Rules SqMahalanobis
125-
126-
function ChainRulesCore.rrule(
127-
dist::Distances.SqMahalanobis, a::AbstractVector, b::AbstractVector
128-
)
129-
d = dist(a, b)
130-
function SqMahalanobis_pullback::Real)
131-
a_b = a - b
132-
∂qmat = InplaceableThunk(
133-
-> mul!(X̄, a_b, a_b', true, Δ), @thunk((a_b * a_b') * Δ)
134-
)
135-
∂a = InplaceableThunk(
136-
-> mul!(X̄, dist.qmat, a_b, true, 2 * Δ), @thunk((2 * Δ) * dist.qmat * a_b)
137-
)
138-
∂b = InplaceableThunk(
139-
-> mul!(X̄, dist.qmat, a_b, true, -2 * Δ), @thunk((-2 * Δ) * dist.qmat * a_b)
140-
)
141-
return Tangent{typeof(dist)}(; qmat=∂qmat), ∂a, ∂b
142-
end
143-
return d, SqMahalanobis_pullback
144-
end
145-
146124
## Reverse Rules for matrix wrappers
147125

148126
function ChainRulesCore.rrule(::Type{<:ColVecs}, X::AbstractMatrix)

‎test/chainrules.jl‎

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
x = rand(rng, 5)
44
y = rand(rng, 5)
55
r = rand(rng, 5)
6-
Q = Matrix(Cholesky(rand(rng, 5, 5), 'U', 0))
7-
@assert isposdef(Q)
86

97
compare_gradient(:Zygote, [x, y]) do xy
108
Euclidean()(xy[1], xy[2])
@@ -21,11 +19,4 @@
2119
compare_gradient(:Zygote, [x, y]) do xy
2220
KernelFunctions.Sinus(r)(xy[1], xy[2])
2321
end
24-
if VERSION < v"1.6"
25-
@test_broken "Chain rule of SqMahalanobis is broken in Julia pre-1.6"
26-
else
27-
compare_gradient(:Zygote, [Q, x, y]) do Qxy
28-
SqMahalanobis(Qxy[1])(Qxy[2], Qxy[3])
29-
end
30-
end
3122
end

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /