Skip to content

Commit fc9eacc

Browse files
authored
Merge pull request #305 from yfnaji/ridge-reg
Ridge Regression
2 parents 7e5eb5c + 538abce commit fc9eacc

File tree

2 files changed

+260
-0
lines changed

2 files changed

+260
-0
lines changed

crates/RustQuant_ml/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ pub use linear_regression::*;
3434
pub mod logistic_regression;
3535
pub use logistic_regression::*;
3636

37+
/// Ridge regression.
38+
pub mod ridge_regression;
39+
pub use ridge_regression::*;
40+
3741
/// lasso regression.
3842
pub mod lasso;
3943
pub use lasso::*;
Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2+
// RustQuant: A Rust library for quantitative finance tools.
3+
// Copyright (C) 2023 https://github.com/avhz
4+
// Dual licensed under Apache 2.0 and MIT.
5+
// See:
6+
// - LICENSE-APACHE.md
7+
// - LICENSE-MIT.md
8+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9+
10+
//! Module for ridge regression algorithms.
11+
12+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13+
// IMPORTS
14+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15+
16+
use nalgebra::{DMatrix, DVector};
17+
18+
use RustQuant_error::RustQuantError;
19+
20+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21+
// STRUCTS, ENUMS, AND TRAITS
22+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23+
24+
/// Struct to hold the input data for a ridge regression.
25+
#[allow(clippy::module_name_repetitions)]
26+
#[derive(Clone, Debug)]
27+
pub struct RidgeRegressionInput<T> {
28+
/// The features matrix.
29+
pub x: DMatrix<T>,
30+
/// The output data vector, also known as the response vector.
31+
pub y: DVector<T>,
32+
/// The regularization parameter.
33+
pub lambda: T,
34+
/// Include the intercept.
35+
pub fit_intercept: bool,
36+
/// The maximum number of iterations for training.
37+
pub max_iter: usize,
38+
/// The tolerance for the convergence.
39+
pub tolerance: T,
40+
}
41+
42+
/// Struct to hold the output data for a ridge regression.
43+
#[allow(clippy::module_name_repetitions)]
44+
#[derive(Clone, Debug)]
45+
pub struct RidgeRegressionOutput<T> {
46+
/// The intercept of the ridge regression,
47+
pub intercept: T,
48+
/// The coefficients of the ridge regression,
49+
pub coefficients: DVector<T>,
50+
}
51+
52+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
53+
// IMPLEMENTATIONS
54+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
55+
56+
impl RidgeRegressionInput<f64> {
57+
/// Create a new `RidgeRegressionInput` struct.
58+
#[must_use]
59+
pub fn new(
60+
x: DMatrix<f64>,
61+
y: DVector<f64>,
62+
lambda: f64,
63+
fit_intercept: bool,
64+
max_iter: usize,
65+
tolerance: f64,
66+
) -> Self {
67+
Self { x, y, lambda, fit_intercept, max_iter, tolerance }
68+
}
69+
70+
/// Fits a ridge regression to the input data.
71+
/// Returns the intercept and coefficients.
72+
/// The intercept is the first value of the coefficients.
73+
pub fn fit(&self) -> Result<RidgeRegressionOutput<f64>, RustQuantError> {
74+
75+
let features_matrix = if self.fit_intercept {
76+
self.x.clone().insert_column(0, 1.)
77+
} else {
78+
self.x.clone()
79+
};
80+
81+
let n_col: usize = features_matrix.ncols();
82+
let features_matrix_transpose = features_matrix.transpose();
83+
let mut regularisation_matrix = DMatrix::<f64>::identity(n_col, n_col);
84+
85+
if self.fit_intercept { regularisation_matrix[(0,0)] = 0.0; }
86+
87+
let ridge_matrix = (&features_matrix_transpose * features_matrix) + self.lambda * regularisation_matrix;
88+
89+
let ridge_matrix_inv = ridge_matrix
90+
.try_inverse()
91+
.ok_or(RustQuantError::MatrixInversionFailed)?;
92+
93+
let mut coefficients = ridge_matrix_inv * &features_matrix_transpose * &self.y;
94+
let intercept: f64 = if self.fit_intercept {
95+
coefficients[0]
96+
} else {
97+
coefficients = coefficients.insert_row(0, 0.0);
98+
0.0
99+
};
100+
101+
Ok(RidgeRegressionOutput {
102+
intercept,
103+
coefficients,
104+
})
105+
}
106+
}
107+
108+
impl RidgeRegressionOutput<f64> {
109+
/// Predicts the output for the given input data.
110+
pub fn predict(&self, input: DMatrix<f64>) -> Result<DVector<f64>, RustQuantError> {
111+
let intercept = DVector::from_element(input.nrows(), self.intercept);
112+
let coefficients = self.coefficients.clone().remove_row(0);
113+
let predictions = input * coefficients + intercept;
114+
Ok(predictions)
115+
}
116+
}
117+
118+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
119+
// UNIT TESTS
120+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
121+
122+
#[cfg(test)]
123+
mod tests_ridge_regression {
124+
use super::*;
125+
use RustQuant_utils::assert_approx_equal;
126+
127+
struct DataForTests {
128+
training_set: DMatrix<f64>,
129+
testing_set: DMatrix<f64>,
130+
response: DVector<f64>,
131+
}
132+
133+
fn setup_test() -> DataForTests {
134+
DataForTests {
135+
training_set: DMatrix::from_row_slice(
136+
4,
137+
3,
138+
&[
139+
-0.083_784_355, -0.633_485_70, -0.399_266_60,
140+
-0.982_943_745, 1.090_797_46, -0.468_123_05,
141+
-1.875_067_321, -0.913_727_27, 0.326_962_08,
142+
-0.186_144_661, 1.001_639_71, -0.412_746_90],
143+
),
144+
145+
testing_set: DMatrix::from_row_slice(
146+
4,
147+
3,
148+
&[
149+
0.562_036_47, 0.595_846_45, -0.411_653_01,
150+
0.663_358_26, 0.452_091_83, -0.294_327_15,
151+
-0.602_897_28, 0.896_743_96, 1.218_573_96,
152+
0.698_377_69, 0.572_216_51, 0.244_111_43],
153+
),
154+
155+
response: DVector::from_row_slice(
156+
&[
157+
-0.445_151_96,
158+
-1.847_803_64,
159+
-0.628_825_31,
160+
-0.861_080_69
161+
]
162+
),
163+
}
164+
}
165+
166+
#[test]
167+
fn test_ridge_regression_without_intercept() -> Result<(), RustQuantError> {
168+
169+
let data: DataForTests = setup_test();
170+
171+
let input: RidgeRegressionInput<f64> = RidgeRegressionInput {
172+
x: data.training_set,
173+
y: data.response,
174+
lambda: 1.0,
175+
fit_intercept: false,
176+
max_iter: 1000,
177+
tolerance: 1e-4,
178+
};
179+
180+
let output = input.fit()?;
181+
182+
for (i, coefficient) in output.coefficients.iter().enumerate() {
183+
assert_approx_equal!(
184+
coefficient,
185+
&[
186+
0.0,
187+
0.620_453_495_948_496_1,
188+
-0.420_204_780_485_896_43,
189+
0.490_065_457_911_238_96
190+
][i],
191+
f64::EPSILON
192+
);
193+
}
194+
195+
let predictions = output.predict(data.testing_set)?;
196+
for (i, pred) in predictions.iter().enumerate() {
197+
assert_approx_equal!(
198+
pred,
199+
&[
200+
-0.103_396_954_909_688_48,
201+
0.077_372_233_758_234_32,
202+
-0.153_704_818_231_581,
203+
0.312_493_346_002_296_7
204+
][i],
205+
f64::EPSILON
206+
);
207+
}
208+
Ok(())
209+
}
210+
211+
#[test]
212+
fn test_ridge_regression_with_intercept() -> Result<(), RustQuantError> {
213+
214+
let data: DataForTests = setup_test();
215+
216+
let input: RidgeRegressionInput<f64> = RidgeRegressionInput {
217+
x: data.training_set,
218+
y: data.response,
219+
lambda: 1.0,
220+
fit_intercept: true,
221+
max_iter: 1000,
222+
tolerance: 1e-4,
223+
};
224+
225+
let output = input.fit()?;
226+
227+
for (i, coefficient) in output.coefficients.iter().enumerate() {
228+
assert_approx_equal!(
229+
coefficient,
230+
&[
231+
-0.701_404_539_262_792_8,
232+
0.215_855_099_335_031_66,
233+
-0.371_997_155_606_467_07,
234+
0.104_115_015_026_450_71,
235+
][i],
236+
f64::EPSILON
237+
);
238+
}
239+
240+
let predictions = output.predict(data.testing_set)?;
241+
242+
for (i, pred) in predictions.iter().enumerate() {
243+
assert_approx_equal!(
244+
pred,
245+
&[
246+
-0.844_598_545_101_076_9,
247+
-0.757_036_026_633_643_9,
248+
-1.038_257_347_797_051_1,
249+
-0.738_103_402_522_953_9,
250+
][i],
251+
f64::EPSILON
252+
);
253+
}
254+
Ok(())
255+
}
256+
}

0 commit comments

Comments
 (0)