machine_learning/rbm.cpp
/*******************************************************
* Copyright (c) 2014, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* https://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <math.h>
#include <stdio.h>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include <arrayfire.h>
#include <af/util.h>
#include "mnist_common.h"
using namespace af;
using std::vector;
float accuracy(const array &predicted, const array &target) {
array val, plabels, tlabels;
max(val, tlabels, target, 1);
max(val, plabels, predicted, 1);
return 100 * count<float>(plabels == tlabels) / tlabels.elements();
}
// Derivative of the activation function
array deriv(const array &out) { return out * (1 - out); }
// Cost function: Euclidean (L2) norm of (out - pred). The squared sum is
// reduced as float and the square root is taken in double precision.
double error(const array &out, const array &pred) {
    const array diff   = out - pred;
    const float sq_sum = sum<float>(diff * diff);
    return sqrt(static_cast<double>(sq_sum));
}
/// Draw a 0/1 sample for every element of `in`.
/// An element becomes 1 where `in` exceeds a fresh uniform random draw in
/// [0, 1), i.e. with probability `in` for values in [0, 1]. Within this file
/// the callers pass sigmoid activations, which already are probabilities, so
/// no further squashing is applied here.
/// @param in per-element probabilities
/// @return f32 array of 0s and 1s with the same dims as `in`
array binary(const array &in) {
    return (in > randu(in.dims())).as(f32);
}
/// Restricted Boltzmann Machine with binary visible and hidden units, trained
/// with k-step contrastive divergence (CD-k).
///
/// Parameter shapes, as set up by the constructor:
///   weights : h_size x v_size
///   h_bias  : 1 x h_size
///   v_bias  : 1 x v_size
/// Batches are laid out one sample per row: visible batches are
/// num x v_size and hidden batches are num x h_size.
class rbm {
   private:
    array weights;
    array h_bias;
    array v_bias;

    // Sample binary hidden states from the hidden probabilities given v.
    array vtoh(const array &v) { return binary(prop_up(v)); }
    // Sample binary visible states from the visible probabilities given h.
    array htov(const array &h) { return binary(prop_down(h)); }

   public:
    rbm() {}

    /// @param v_size number of visible units (features per sample)
    /// @param h_size number of hidden units
    rbm(int v_size, int h_size)
        // randu is uniform in [0, 1), so this yields small weights uniformly
        // distributed in [-0.05, 0.05), centred on zero.
        : weights(randu(h_size, v_size) / 10 - 0.05)
        , h_bias(constant(0, 1, h_size))
        , v_bias(constant(0, 1, v_size)) {}

    /// Hidden activation probabilities: sigmoid(h_bias + v * weights^T).
    /// @param v visible batch (num x v_size)
    /// @return num x h_size probabilities
    array prop_up(const array &v) {
        array h_bias_tile = tile(h_bias, v.dims(0));
        return sigmoid(h_bias_tile + matmulNT(v, weights));
    }

    /// Visible activation probabilities: sigmoid(v_bias + h * weights).
    /// @param h hidden batch (num x h_size)
    /// @return num x v_size probabilities
    array prop_down(const array &h) {
        array v_bias_tile = tile(v_bias, h.dims(0));
        return sigmoid(v_bias_tile + matmul(h, weights));
    }

    /// Run k Gibbs steps v -> h -> v starting from visible states `v`.
    /// On return `vt`/`ht` hold the last sampled visible/hidden states
    /// (`ht` is left untouched when k <= 0).
    void gibbs_vhv(array &vt, array &ht, const array &v, int k = 1) {
        vt = v;
        for (int i = 0; i < k; i++) {
            ht = vtoh(vt);
            vt = htov(ht);
        }
    }

    /// Run k Gibbs steps h -> v -> h starting from hidden states `h`.
    /// On return `vt`/`ht` hold the last sampled visible/hidden states
    /// (`vt` is left untouched when k <= 0).
    void gibbs_hvh(array &vt, array &ht, const array &h, int k = 1) {
        ht = h;
        for (int i = 0; i < k; i++) {
            vt = htov(ht);
            ht = vtoh(vt);
        }
    }

    /// Train with CD-k using mini-batch updates.
    /// @param in         training data, one sample per row
    ///                   (num_samples x v_size)
    /// @param lr         learning rate
    /// @param num_epochs number of passes over the data
    /// @param batch_size rows per mini-batch; the last batch may be smaller
    /// @param k          Gibbs steps per update
    /// @param verbose    print the mean reconstruction error of each epoch
    void train(const array &in, double lr = 0.1, int num_epochs = 15,
               int batch_size = 100, int k = 1, bool verbose = false) {
        const int num_samples = in.dims(0);
        // Nothing to train on; also avoids dividing by a non-positive
        // batch size below.
        if (num_samples <= 0 || batch_size <= 0) return;

        // Round up so every sample, including a trailing partial batch, is
        // visited each epoch.
        const int num_batches = (num_samples + batch_size - 1) / batch_size;

        for (int i = 0; i < num_epochs; i++) {
            double err = 0;
            for (int j = 0; j < num_batches; j++) {
                const int st  = j * batch_size;
                const int en  = std::min(num_samples - 1, st + batch_size - 1);
                const int num = en - st + 1;

                // Positive phase: data batch and hidden states sampled from it.
                array v_pos = in(seq(st, en), span);
                array h_pos = vtoh(v_pos);

                // Negative phase: k Gibbs steps starting from h_pos.
                array v_neg, h_neg;
                gibbs_hvh(v_neg, h_neg, h_pos, k);

                // Update weights
                array c_pos   = matmulTN(h_pos, v_pos);
                array c_neg   = matmulTN(h_neg, v_neg);
                array delta_w = lr * (c_pos - c_neg) / num;
                // Reduce along dim 0 (the batch axis) explicitly. The default
                // reduces the first non-singleton dimension, so a one-row batch
                // would be summed across features into a 1x1 result.
                array delta_vb = lr * sum(v_pos - v_neg, 0) / num;
                array delta_hb = lr * sum(h_pos - h_neg, 0) / num;

                weights += delta_w;
                v_bias += delta_vb;
                h_bias += delta_hb;

                if (verbose) { err += error(v_pos, v_neg); }
            }
            if (verbose) {
                // Every batch contributed to err, so this is a true mean.
                printf("Epoch %d: Reconstruction error: %0.4f\n", i + 1,
                       err / num_batches);
            }
        }
        if (verbose) printf("\n");
    }
};
/// Train an RBM on (a fraction of) MNIST and print the reconstruction error
/// for the first few test images.
/// @param console unused by this demo
/// @param perc    percentage of the data set to load; forwarded to
///                setup_mnist as the fraction perc / 100
/// @return 0
int rbm_demo(bool /*console*/, int perc) {
    printf("** ArrayFire RBM Demo **\n\n");

    array train_images, test_images;
    array train_target, test_target;
    int num_classes, num_train, num_test;

    // Load mnist data
    const float frac = static_cast<float>(perc / 100.0);
    setup_mnist<true>(&num_classes, &num_train, &num_test, train_images,
                      test_images, train_target, test_target, frac);

    // NOTE(review): presumably train_images is (rows x cols x num_train);
    // dims[0] and dims[1] are reused below to reshape single images.
    dim4 dims        = train_images.dims();
    int feature_size = train_images.elements() / num_train;

    // Reshape images into feature vectors, one image per row
    array train_feats = moddims(train_images, feature_size, num_train).T();
    array test_feats  = moddims(test_images, feature_size, num_test).T();
    train_target      = train_target.T();
    test_target       = test_target.T();

    rbm network(train_feats.dims(1), 2000);
    network.train(train_feats,
                  0.1,  // learning rate
                  15,   // num epochs
                  100,  // batch size
                  1,    // k
                  true);

    // Test reconstructed images. Clamp the count so a small test set is not
    // indexed past its last row.
    const int num_show = std::min(5, num_test);
    for (int ii = 0; ii < num_show; ii++) {
        array in = test_feats(ii, span);
        array res, tmp;
        network.gibbs_vhv(res, tmp, in);

        in  = moddims(in, dims[0], dims[1]);
        res = moddims(res, dims[0], dims[1]);

        // Round both to the nearest integer before the per-pixel comparison.
        in  = round(in);
        res = round(res);

        printf("Reconstructed Error for image %2d: %.4f\n", ii,
               sum<float>(abs(in - res)) / feature_size);
    }
    return 0;
}
int main(int argc, char **argv) {
int device = argc > 1 ? atoi(argv[1]) : 0;
bool console = argc > 2 ? argv[2][0] == '-' : false;
int perc = argc > 3 ? atoi(argv[3]) : 60;
try {
af::setDevice(device);
return rbm_demo(console, perc);
} catch (af::exception &ae) { std::cerr << ae.what() << std::endl; }
return 0;
}
af::dim4
Generic object that represents size and shape.
Definition: dim4.hpp:33
af::round
AFAPI array round(const array &in)
C++ Interface for rounding an array of numbers.
af::matmul
AFAPI array matmul(const array &lhs, const array &rhs, const matProp optLhs=AF_MAT_NONE, const matProp optRhs=AF_MAT_NONE)
Matrix multiply of two arrays.
af::seq
seq is used to create sequences for indexing af::array
Definition: seq.h:46
util.h
af::info
AFAPI void info()
af::array::as
const array as(dtype type) const
Converts the array into another type.
af::constant
array constant(T val, const dim4 &dims, const dtype ty=(af_dtype) dtype_traits< T >::ctype)
af::moddims
AFAPI array moddims(const array &in, const unsigned ndims, const dim_t *const dims)
af::setDevice
AFAPI void setDevice(const int device)
Sets the current device.
af::abs
AFAPI array abs(const array &in)
C++ Interface for absolute value.
af::array
A multi dimensional data container.
Definition: array.h:35
af
Definition: algorithm.h:15
af::matmulTN
AFAPI array matmulTN(const array &lhs, const array &rhs)
Matrix multiply of two arrays.
af::max
AFAPI array max(const array &in, const int dim=-1)
C++ Interface for maximum values in an array.
af::array::elements
dim_t elements() const
Get the total number of elements across all dimensions of the array.
af::matmulNT
AFAPI array matmulNT(const array &lhs, const array &rhs)
Matrix multiply of two arrays.
af::sqrt
AFAPI array sqrt(const array &in)
C++ Interface for square root of input.
af::randu
AFAPI array randu(const dim4 &dims, const dtype ty, randomEngine &r)
af::exception
An ArrayFire exception class.
Definition: exception.h:29
af::tile
AFAPI array tile(const array &in, const unsigned x, const unsigned y=1, const unsigned z=1, const unsigned w=1)
af::span
AFAPI seq span
A special value representing the entire axis of an af::array.
af::array::dims
dim4 dims() const
Get dimensions of the array.
af::sigmoid
AFAPI array sigmoid(const array &in)
C++ Interface for calculating sigmoid function of an array.
af::sum
AFAPI array sum(const array &in, const int dim=-1)
C++ Interface for sum of elements in an array.
arrayfire.h
af::exception::what
virtual const char * what() const
Returns an error message for the exception in a string format.
Definition: exception.h:60
af::array::T
array T() const
Get the transposed array.
af::min
AFAPI array min(const array &in, const int dim=-1)
C++ Interface for minimum values in an array.
f32
@ f32
32-bit floating point values
Definition: defines.h:211