6 #ifndef TENSORTWOD_HPP_INCLUDE
7 #define TENSORTWOD_HPP_INCLUDE
28 TensorTwoD(
size_t rows,
size_t cols, std::shared_ptr<TensorMath> math);
41 TensorTwoD(
const std::vector<std::vector<double> > &input);
49 TensorTwoD(
const std::vector<std::vector<double> > &input,
50 std::shared_ptr<TensorMath> math);
57 TensorTwoD(
const std::vector<TensorOneD> &input);
65 TensorTwoD(
const std::vector<TensorOneD> &input, std::shared_ptr<TensorMath> math);
86 void set_dim(
size_t rows,
size_t cols);
121 const std::vector<TensorOneD> &
get_data()
const;
124 void set_data(
const std::vector<TensorOneD> &);
128 std::vector<TensorOneD> m_mat;
129 std::shared_ptr<TensorMath> m_math;
Class to store and perform operations on 1D Tensors, aka vectors, suitable for use in feed-forward ne...
Definition: TensorOneD.hpp:21
Class to manage data and operations related to 2D Tensors required for neural net inference.
Definition: TensorTwoD.hpp:22
size_t get_rows() const
Get the number of rows in the 2D tensor.
Definition: TensorTwoD.cpp:80
TensorTwoD()
Definition: TensorTwoD.cpp:16
void set_data(const std::vector< TensorOneD > &)
Set the contents as a vector of tensors.
Definition: TensorTwoD.cpp:150
void set_dim(size_t rows, size_t cols)
Set dimensions of 2D tensor.
Definition: TensorTwoD.cpp:93
bool operator==(const TensorTwoD &other) const
Overload the == operator to compare the underlying contents of two 2D tensors.
Definition: TensorTwoD.cpp:140
TensorOneD & operator[](size_t idx)
Return a reference to the 1D tensor at index idx of the 2D tensor.
Definition: TensorTwoD.cpp:112
virtual ~TensorTwoD()=default
TensorTwoD & operator=(const TensorTwoD &other)
Overload the = operator to perform an in-place deep copy.
Definition: TensorTwoD.cpp:122
TensorOneD operator*(const TensorOneD &) const
Multiply a 2D tensor by a 1D tensor.
Definition: TensorTwoD.cpp:107
const std::vector< TensorOneD > & get_data() const
Return the tensor as a vector of tensors.
Definition: TensorTwoD.cpp:145
size_t get_cols() const
Get the number of columns in the 2D tensor.
Definition: TensorTwoD.cpp:85
Definition: Accumulator.cpp:12