Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions spec/grad/gates_arithmetic_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,114 @@ describe Num::Grad do
end
{% end %}

it "backpropogates for addition with broadcast" do
ctx = Num::Grad::Context(Float32Tensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
])
b = ctx.variable([
1_f32, 1_f32, 1_f32, 1_f32,
])

result = a + b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [4_f32, 4_f32, 4_f32, 4_f32].to_tensor

Num::Testing.tensor_equal(a.grad, expected_a).should be_true
Num::Testing.tensor_equal(b.grad, expected_b).should be_true
end

{% if flag?(:opencl) %}
it "backpropogates for addition with broadcast opencl", tags: "opencl" do
ctx = Num::Grad::Context(Float32ClTensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
].to_tensor(OCL))
b = ctx.variable([
1_f32, 1_f32, 1_f32, 1_f32,
].to_tensor(OCL))

result = a + b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [4_f32, 4_f32, 4_f32, 4_f32].to_tensor

Num::Testing.tensor_equal(a.grad.cpu, expected_a).should be_true
Num::Testing.tensor_equal(b.grad.cpu, expected_b).should be_true
end
{% end %}

it "backpropogates for addition with scalar broadcast" do
ctx = Num::Grad::Context(Float32Tensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
])
b = ctx.variable([
1_f32,
])

result = a + b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [16_f32].to_tensor

Num::Testing.tensor_equal(a.grad, expected_a).should be_true
Num::Testing.tensor_equal(b.grad, expected_b).should be_true
end

{% if flag?(:opencl) %}
it "backpropogates for addition with scalar broadcast opencl", tags: "opencl" do
ctx = Num::Grad::Context(Float32ClTensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
].to_tensor(OCL))
b = ctx.variable([
1_f32,
].to_tensor(OCL))

result = a + b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [16_f32].to_tensor

Num::Testing.tensor_equal(a.grad.cpu, expected_a).should be_true
Num::Testing.tensor_equal(b.grad.cpu, expected_b).should be_true
end
{% end %}

it "backpropogates for subtraction" do
ctx = Num::Grad::Context(Float32Tensor).new

Expand Down Expand Up @@ -112,6 +220,114 @@ describe Num::Grad do
end
{% end %}

it "backpropogates for subtraction with broadcast" do
ctx = Num::Grad::Context(Float32Tensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
])
b = ctx.variable([
1_f32, 1_f32, 1_f32, 1_f32,
])

result = a - b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [-4_f32, -4_f32, -4_f32, -4_f32].to_tensor

Num::Testing.tensor_equal(a.grad, expected_a).should be_true
Num::Testing.tensor_equal(b.grad, expected_b).should be_true
end

{% if flag?(:opencl) %}
it "backpropogates for subtraction with broadcast opencl", tags: "opencl" do
ctx = Num::Grad::Context(Float32ClTensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
].to_tensor(OCL))
b = ctx.variable([
1_f32, 1_f32, 1_f32, 1_f32,
].to_tensor(OCL))

result = a - b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [-4_f32, -4_f32, -4_f32, -4_f32].to_tensor

Num::Testing.tensor_equal(a.grad.cpu, expected_a).should be_true
Num::Testing.tensor_equal(b.grad.cpu, expected_b).should be_true
end
{% end %}

it "backpropogates for subtraction with scalar broadcast" do
ctx = Num::Grad::Context(Float32Tensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
])
b = ctx.variable([
1_f32,
])

result = a - b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [-16_f32].to_tensor

Num::Testing.tensor_equal(a.grad, expected_a).should be_true
Num::Testing.tensor_equal(b.grad, expected_b).should be_true
end

{% if flag?(:opencl) %}
it "backpropogates for subtraction with scalar broadcast opencl", tags: "opencl" do
ctx = Num::Grad::Context(Float32ClTensor).new

a = ctx.variable([
[1_f32, 2_f32, 3_f32, 4_f32],
[5_f32, 6_f32, 7_f32, 8_f32],
[9_f32, 10_f32, 11_f32, 12_f32],
[13_f32, 14_f32, 15_f32, 16_f32],
].to_tensor(OCL))
b = ctx.variable([
1_f32,
].to_tensor(OCL))

result = a - b
result.backprop

expected_a = [[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32],
[1_f32, 1_f32, 1_f32, 1_f32]].to_tensor
expected_b = [-16_f32].to_tensor

Num::Testing.tensor_equal(a.grad.cpu, expected_a).should be_true
Num::Testing.tensor_equal(b.grad.cpu, expected_b).should be_true
end
{% end %}

it "backpropogates for multiplication" do
ctx = Num::Grad::Context(Float32Tensor).new

Expand Down
32 changes: 28 additions & 4 deletions src/grad/backends/agnostic.cr
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,38 @@
module Num::Grad
extend self

#
# This returns the appropriate backward gradient processing for
# addition and subtraction based on the
# size and rank of the two variables
#
private def sum_grad_backward(gradient : U, a : U, b : U) : Array(U) forall U
if a.rank != b.rank
# broadcast along an axis, so sum dwn by axis
swap = a.rank > b.rank
gless = gradient
(b.rank - a.rank).abs.times do
gless = gless.sum(0)
end
if a.size == 1 || b.size == 1
gless = gless.sum(0)
end
swap ? [gradient, gless] : [gless, gradient]
else
[gradient, gradient]
end
end

# :nodoc:
def add_backward(gradient : U) : Array(U) forall U
[gradient, gradient]
def add_backward(gradient : U, a : Variable(U), b : Variable(U)) : Array(U) forall U
sum_grad_backward(gradient, a.value, b.value)
end

# :nodoc:
def subtract_backward(gradient : U) : Array(U) forall U
[gradient, gradient * -1]
def subtract_backward(gradient : U, a : Variable(U), b : Variable(U)) : Array(U) forall U
r = sum_grad_backward(gradient, a.value, b.value)
r[1] = -r[1]
r
end

# :nodoc:
Expand Down
49 changes: 16 additions & 33 deletions src/grad/gates/arithmetic.cr
Original file line number Diff line number Diff line change
Expand Up @@ -22,61 +22,44 @@
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# :nodoc:
class Num::Grad::AddGate(T) < Num::Grad::Gate(T)
abstract class Num::Grad::TwoOpGate(T) < Num::Grad::Gate(T)
getter a : Num::Grad::Variable(T)
getter b : Num::Grad::Variable(T)
@@name = "TwoOp"

# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
Num::Grad.add_backward(payload.variable.grad)
def initialize(@a, @b)
end

abstract def backward(payload : Num::Grad::Payload(T)) : Array(T)

# :nodoc:
def cache(result : Num::Grad::Variable(T), *args)
a, b = args

result.grad = T.zeros_like(result.value)
result.requires_grad = true

Num::Grad.register("Add", self, result, a, b)
Num::Grad.register(@@name, self, result, a, b)
end
end

# :nodoc:
class Num::Grad::SubtractGate(T) < Num::Grad::Gate(T)
# :nodoc:
def backward(payload : Num::Grad::Payload(T)) : Array(T)
Num::Grad.subtract_backward(payload.variable.grad)
end
class Num::Grad::AddGate(T) < Num::Grad::TwoOpGate(T)
@@name = "Add"

# :nodoc:
def cache(result : Num::Grad::Variable(T), *args)
a, b = args
result.grad = T.zeros_like(result.value)
result.requires_grad = true

Num::Grad.register("Sub", self, result, a, b)
def backward(payload : Num::Grad::Payload(T)) : Array(T)
Num::Grad.add_backward(payload.variable.grad, a, b)
end
end

# :nodoc:
class Num::Grad::TwoOpGate(T) < Num::Grad::Gate(T)
getter a : Num::Grad::Variable(T)
getter b : Num::Grad::Variable(T)
@@name = "TwoOp"
class Num::Grad::SubtractGate(T) < Num::Grad::TwoOpGate(T)
@@name = "Sub"

# :nodoc:
def initialize(@a : Num::Grad::Variable(T), @b : Num::Grad::Variable(T))
end

def backward(payload : Num::Grad::Payload(T)) : Array(T)
[] of T
end

# :nodoc:
def cache(result : Num::Grad::Variable(T), *args)
a, b = args
result.grad = T.zeros_like(result.value)
result.requires_grad = true

Num::Grad.register(@@name, self, result, a, b)
Num::Grad.subtract_backward(payload.variable.grad, a, b)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/grad/variable.cr
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Num::Grad::Variable(T)
# f = a + b # => [5.0]
# f.backprop
# ```
operator_op :+, Num::Grad::AddGate(T)
operator_op :+, Num::Grad::AddGate(T), self, other

# Subtracts a variable from another variable and stores
# the derivative of the operation in the computational
Expand All @@ -83,7 +83,7 @@ class Num::Grad::Variable(T)
# f = a - b # => [-1.0]
# f.backprop
# ```
operator_op :-, Num::Grad::SubtractGate(T)
operator_op :-, Num::Grad::SubtractGate(T), self, other

# Multiples a variable to another variable and stores
# the derivative of the operation in the computational
Expand Down