template class xf::data_analytics::regression::internal::ridgeRegressionSGDTrainer
#include "linearRegression.hpp"
Overview
Ridge regression training using the SGD framework.
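For reference, the standard ridge-regression objective that an SGD trainer of this kind minimizes is written out below, along with its gradients. This is a sketch under assumptions: `regVal` is taken to play the role of λ, the squared loss is scaled by 1/(2n) as in Spark MLlib's least-squares gradient, and the intercept b is left unpenalized. The exact scaling used by this class is not stated in this reference.

```latex
\begin{aligned}
J(w, b) &= \frac{1}{2n}\sum_{i=1}^{n}\bigl(w^{\top}x_i + b - y_i\bigr)^2 + \frac{\lambda}{2}\lVert w\rVert_2^2 \\
\nabla_w J &= \frac{1}{n}\sum_{i=1}^{n}\bigl(w^{\top}x_i + b - y_i\bigr)\,x_i + \lambda w \\
\frac{\partial J}{\partial b} &= \frac{1}{n}\sum_{i=1}^{n}\bigl(w^{\top}x_i + b - y_i\bigr)
\end{aligned}
```

SGD replaces the full sums with estimates computed on a sampled subset of rows, so the λw term is what distinguishes a ridge update from a plain least-squares update.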
Parameters:
MType | Data type used for regression; double and float are supported. |
WAxi | Width of the AXI interface used to load training data. |
WData | Width of the feature data type. |
BurstLen | Length of each burst read. |
D | Number of features processed per cycle. |
DDepth | DDepth * D is the maximum number of features supported. |
RAMWeight | RAM type used to store the weights: LUTRAM, BRAM or URAM. |
RAMIntercept | RAM type used to store the intercept: LUTRAM, BRAM or URAM. |
RAMAvgWeight | RAM type used to store the average of the weights: LUTRAM, BRAM or URAM. |
RAMAvgIntercept | RAM type used to store the average of the intercept: LUTRAM, BRAM or URAM. |
template <
    typename MType,
    int WAxi,
    int WData,
    int BurstLen,
    int D,
    int DDepth,
    RAMType RAMWeight,
    RAMType RAMIntercept,
    RAMType RAMAvgWeight,
    RAMType RAMAvgIntercept
    >
class ridgeRegressionSGDTrainer: public xf::data_analytics::common::SGDFramework

// fields

MType regVal
Inherited Members
// typedefs

typedef Gradient::DataType MType

// fields

static const int WAxi
static const int D
static const int Depth
ap_uint <32> offset
ap_uint <32> rows
ap_uint <32> cols
ap_uint <32> bucketSize
float fraction
bool ifJump
MType stepSize
MType tolerance
bool withIntercept
ap_uint <32> maxIter
Gradient gradProcessor
Methods
setTrainingConfigs
void setTrainingConfigs ( MType inputStepSize, MType inputTolerance, MType inputRegVal, bool inputWithIntercept, ap_uint <32> inputMaxIter )
Sets up the configuration for the SGD iterations.
Parameters:
inputStepSize | Step size of each SGD iteration. |
inputTolerance | Convergence tolerance. |
inputRegVal | Regularization value for ridge regression. |
inputWithIntercept | Whether training includes an intercept. |
inputMaxIter | Maximum number of SGD iterations. |
updateParams
bool updateParams (ap_uint <32> iterationIndex)
Updates the weight and intercept based on the gradient.
Parameters:
iterationIndex | Iteration index. |