Function amsgrad
Creates a delegate that can be used to perform a step using the AMSGrad update rule.
{null} amsgrad
();
This function relies on automatic differentiation, so the objective (which must have a volume of 1, i.e. contain exactly one element) must be differentiable with respect to every element of wrt. The returned delegate performs minimisation.
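
For orientation, the AMSGrad update rule is usually written as below. This is the standard textbook formulation, not a transcription of this library's source: the placement of ε, the absence of bias correction, and the symbol names (g_t for the gradient of the objective, θ for an element of wrt) are assumptions. Any projection supplied via projs would presumably be applied to θ after the step.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{v}_t &= \max(\hat{v}_{t-1},\, v_t) \\
\theta_{t+1} &= \theta_t - \alpha \, \frac{m_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
```

Here α, β₁, β₂ and ε correspond to the alpha, beta1, beta2 and eps parameters. The running maximum in the third line is what distinguishes AMSGrad from Adam: the effective step size for a parameter can never grow from one step to the next.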
Parameters
Name | Description |
---|---|
outputs | An array of outputs. The first element of this array is the objective function to be minimised. |
wrt | An array of Operations with respect to which the objective is differentiated. These are the values that get updated. |
projs | Projection functions that can be applied when updating the values of elements in wrt. |
alpha | The step size. |
beta1 | Fading factor for the first moment of the gradient. |
beta2 | Fading factor for the second moment of the gradient. |
eps | A small constant used to prevent division by zero. |
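
To make the roles of alpha, beta1, beta2 and eps concrete, here is a minimal plain-D sketch of one AMSGrad step for a single scalar parameter. It does not use this library: AmsgradState and amsgradStep are hypothetical names, the default values are illustrative rather than the library's defaults, the gradient is supplied by hand instead of by automatic differentiation, and projections are left out.

```d
import std.algorithm : max;
import std.math : sqrt;

// Hypothetical helper, not part of the library: optimiser state for one scalar.
struct AmsgradState
{
    float m = 0.0f;    // first-moment (mean) estimate of the gradient
    float v = 0.0f;    // second-moment estimate of the gradient
    float vhat = 0.0f; // running maximum of v
}

// Performs one AMSGrad step and returns the updated parameter value.
// Default values are illustrative assumptions, not the library's defaults.
float amsgradStep(ref AmsgradState s, float theta, float g,
    float alpha = 0.001f, float beta1 = 0.9f, float beta2 = 0.999f,
    float eps = 1e-8f)
{
    s.m = beta1 * s.m + (1.0f - beta1) * g;     // fade the first moment
    s.v = beta2 * s.v + (1.0f - beta2) * g * g; // fade the second moment
    s.vhat = max(s.vhat, s.v);                  // never let the denominator shrink
    return theta - alpha * s.m / (sqrt(s.vhat) + eps);
}

void main()
{
    import std.stdio : writeln;

    // Minimise (theta - 3)^2, whose gradient is 2 * (theta - 3).
    AmsgradState s;
    float theta = 5.0f;

    foreach(i; 0 .. 2000)
    {
        theta = amsgradStep(s, theta, 2.0f * (theta - 3.0f), 0.1f);
    }

    writeln("theta = ", theta, " (expected close to 3)");
}
```

A larger beta1 smooths the update direction over more past gradients, while beta2 controls how quickly the per-parameter scaling reacts to large gradients.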
Returns
A delegate that is used to actually perform the update steps. The optimised values are stored in the value properties of the elements of wrt. The delegate returns the values computed for each element of the outputs array. This can be useful for keeping track of several different performance metrics in a prequential manner.
Example
import std.random : uniform;

//Generate some points on the line y = 3x + 2
auto xdata = new float[100];
auto ydata = new float[100];

foreach(i; 0 .. 100)
{
    xdata[i] = uniform(-10.0f, 10.0f);
    ydata[i] = 3.0f * xdata[i] + 2.0f;
}

//Create the model
auto x = float32([]);
auto m = float32([]);
auto c = float32([]);
auto yhat = m * x + c;
auto y = float32([]);

//Create an AMSGrad updater that minimises the squared error, with no projections and a step size of 0.1
auto updater = amsgrad([(yhat - y) * (yhat - y)], [m, c], null, float32([], [0.1f]));

//Run 300 updates, cycling through the data one point at a time
float loss;

for(size_t i = 0; i < 300; i++)
{
    size_t j = i % 100;

    loss = updater([
        x: Buffer(xdata[j .. j + 1]),
        y: Buffer(ydata[j .. j + 1])
    ])[0].as!float[0];
}

//Print the loss after 300 iterations. Let the user decide whether it's good enough to be considered a pass.
import std.stdio : writeln;
writeln(
    "AMSGrad loss: ", loss, " ",
    "m=", m.value.as!float[0], ", ",
    "c=", c.value.as!float[0], " ",
    "(expected m=3, c=2)");