version 0.4.1
truncatedconjugategradient.hh
Go to the documentation of this file.
1// Original File: https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h
2// SPDX-FileCopyrightText: 2011-2014 Copyright (C) Gael Guennebaud <gael.guennebaud@inria.fr>
3// SPDX-License-Identifier: MPL-2.0
4// Modifications:
5// SPDX-FileCopyrightText: 2021-2024 The Ikarus Developers mueller@ibb.uni-stuttgart.de
6// SPDX-License-Identifier: LGPL-3.0-or-later
7
19#pragma once
20#include <Eigen/Core>
21#include <Eigen/Dense>
22#include <Eigen/Sparse>
23
24namespace Eigen {
// Stop reason of the truncated conjugate gradient; this is the type of TCGInfo::stop_tCG.
// NOTE(review): the enumerators (source lines 27-32) are missing from this extracted listing.
25enum class TCGStopReason : int
26{
33};
// TCGInfo: parameters and per-solve output of internal::truncated_conjugate_gradient.
// NOTE(review): source line 37, the member `TCGStopReason stop_tCG` (per the definition index),
// is missing from this extracted listing.
34template <typename Scalar>
35struct TCGInfo
36{
// Trust-region radius: the solver does not let x go beyond it (it steps onto the boundary).
38 Scalar Delta = 100000;
// Residual-test parameters. truncated_conjugate_gradient reads kappa; theta is not read there.
39 Scalar kappa = 0.1;
40 Scalar theta = 1.0;
// Minimum iteration count before the kappa residual test may stop the solve.
41 Eigen::Index mininner = 1;
// NOTE(review): not read by truncated_conjugate_gradient, which takes its limit from `iters`.
42 Eigen::Index maxinner = 1000;
// Output: iteration counter at the end of the last solve's main loop.
43 Eigen::Index numInnerIter = 0;
44
// Sets maxinner to twice the number of solved DOFs and resets Delta to its default value.
// NOTE(review): the product is computed in int and overflows for DOF counts above INT_MAX / 2.
45 void initRuntimeOptions(int _num_dof_solve) {
46 maxinner = _num_dof_solve * 2;
47 Delta = 100000; // typical distance of manifold
48 // Delta0 = Delta_bar/8.0;
49 }
50};
51namespace internal {
52
// truncated_conjugate_gradient: preconditioned conjugate gradient for mat * x = rhs,
// truncated for use inside a trust-region method.
//
// Starting from the guess in x, the loop stops when
//   - the curvature p . (mat * p) is <= 0, or the next step would take x outside the trust
//     region of radius _info.Delta; x is then moved along p onto the boundary (tau is the
//     positive root of e_Pe + 2 tau e_Pd + tau^2 d_Pd = Delta^2),
//   - at least _info.mininner iterations were done and the kappa residual test passes,
//   - the residual test against tol passes, or
//   - the iteration counter reaches the limit.
//
// Parameters:
//   mat, rhs   system operator (used as mat * vector) and right-hand side.
//   x          in: initial guess; out: approximate solution (set to zero when ||rhs|| is ~0).
//   precond    preconditioner; precond.solve(r) is applied to every residual.
//   iters      in: iteration limit; out: iteration counter at exit (0 on the early returns).
//   tol_error  in: relative tolerance; out: ||residual|| / ||rhs||.
//   _info      trust-region parameters (Delta, kappa, mininner); numInnerIter is written
//              when the main loop ends (not on the early returns).
//
// NOTE(review): source lines 71, 110, 128, 130, 145 and 147 are missing from this extracted
// listing. Per the definition index, line 71 carries the last parameter
// `TCGInfo<typename Dest::RealScalar>& _info` and the opening brace.
68 template <typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
69 void truncated_conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, const Preconditioner& precond,
70 Eigen::Index& iters, typename Dest::RealScalar& tol_error,
72 using std::abs;
73 using std::sqrt;
74 typedef typename Dest::RealScalar RealScalar;
75 typedef typename Dest::Scalar Scalar;
76 typedef Matrix<Scalar, Dynamic, 1> VectorType;
77
78 RealScalar tol = tol_error;
79 Index maxIters = iters;
// NOTE(review): the limit comes from `iters`; _info.maxinner is never read in this function.
80
81 Index n = mat.cols();
82
83 VectorType residual = rhs - mat * x; // initial residual
84
85 RealScalar rhsNorm2 = rhs.norm();
86 const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
87
88 if (rhsNorm2 <= considerAsZero) {
89 x.setZero();
90 iters = 0;
91 tol_error = 0;
92 return;
93 }
// threshold is a squared quantity: tol^2 * ||rhs||^2 (bounded below by considerAsZero).
94 RealScalar threshold = numext::maxi(tol * tol * rhsNorm2 * rhsNorm2, considerAsZero);
95 RealScalar residualNorm2 = residual.norm();
96 if (residualNorm2 * residualNorm2 < threshold) {
97 iters = 0;
98 tol_error = (residualNorm2 / rhsNorm2);
99 return;
100 }
101
// Trust-region bookkeeping, as implied by the recurrences below, with <a, b>_M the inner
// product whose inverse precond.solve applies:
//   e_Pe = <x, x>_M, e_Pd = <x, p>_M, d_Pd = <p, p>_M, d_Hd = p . (mat * p).
// NOTE(review): d_Pd is seeded with <r, precond.solve(r)>, but e_Pe is seeded with the
// Euclidean x.squaredNorm() and e_Pd with 0; this is only consistent when x starts at zero.
// These scalars are also `double` rather than RealScalar (mixed precision for float systems).
102 double e_Pd = 0.0;
103 double e_Pe_new = 0.0;
104 double e_Pe = x.squaredNorm();
105 double d_Pd;
106 double d_Hd;
107 VectorType p(n);
108 p = precond.solve(residual); // initial search direction
109 // bool coutflag=true;
111 VectorType z(n), tmp(n);
112 RealScalar absNew = numext::real(residual.dot(p)); // the square of the absolute value of r scaled by invM
113 d_Pd = absNew;
114 Index i = 1;
// NOTE(review): i starts at 1 and the loop ends once i == maxIters, so at most maxIters - 1
// directions are taken (none when maxIters <= 1) — confirm this is intended.
115 while (i < maxIters) {
116 tmp.noalias() = mat * p; // the bottleneck of the algorithm
117 d_Hd = p.dot(tmp);
118 Scalar alpha = absNew / d_Hd; // the amount we travel on dir
119
120 e_Pe_new = e_Pe + 2.0 * alpha * e_Pd + alpha * alpha * d_Pd;
121
122 if (d_Hd <= 0 || e_Pe_new >= _info.Delta * _info.Delta) // negative curvature, or the step would exceed the trust region
123 {
// Step along p only up to the trust-region boundary.
124 double tau = (-e_Pd + sqrt(e_Pd * e_Pd + d_Pd * (_info.Delta * _info.Delta - e_Pe))) / d_Pd;
125
126 x += tau * p;
// NOTE(review): the branch bodies (source lines 128 and 130, presumably recording the stop
// reason in _info.stop_tCG) are missing from this listing. `residual` is not updated for this
// boundary step, so tol_error reports the residual of the previous iterate.
127 if (d_Hd <= 0)
129 else
131
132 break;
133 }
134 e_Pe = e_Pe_new;
135 x += alpha * p; // update solution
136 residual -= alpha * tmp; // update residual
137
138 residualNorm2 = residual.norm();
139
// Trust-region residual test. rhsNorm2 is ||rhs||, which equals the initial residual norm
// only for a zero initial guess.
140 if (i >= _info.mininner &&
141 residualNorm2 <= rhsNorm2 * std::min(rhsNorm2, _info.kappa)) // NOTE(review): _info.theta is ignored; presumably std::min(pow(rhsNorm2, _info.theta), _info.kappa) was intended (same result for the default theta = 1)
142 {
143 // Residual is small enough to quit
// NOTE(review): likewise presumably `_info.kappa < pow(rhsNorm2, _info.theta)`; the branch
// bodies (source lines 145 and 147) are missing from this listing.
144 if (_info.kappa < rhsNorm2)
146 else
148 break;
149 }
// NOTE(review): `threshold` is squared (tol^2 * ||rhs||^2) but `residualNorm2` is the plain
// norm here, unlike the pre-loop check at source line 96. This most likely should read
// `residualNorm2 * residualNorm2 < threshold`.
150 if (residualNorm2 < threshold)
151 break;
152
153 z = precond.solve(residual); // approximately solve for "A z = residual"
154
155 RealScalar absOld = absNew;
156 absNew = numext::real(residual.dot(z)); // update the absolute value of r
157 RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction
158
// Update the M-inner products for x_{k+1} = x_k + alpha p_k and p_{k+1} = z + beta p_k.
159 e_Pd = beta * (e_Pd + alpha * d_Pd);
160 d_Pd = absNew + beta * beta * d_Pd;
161
162 p = z + beta * p; // update search direction
163 i++;
164 }
// Report the relative residual ||r|| / ||rhs|| and the iteration counter.
165 tol_error = (residualNorm2 / rhsNorm2);
166 iters = i;
167 _info.numInnerIter = i;
168 }
169
170} // namespace internal
171
172template <typename MatrixType, int UpLo = Lower,
173 typename Preconditioner = DiagonalPreconditioner<typename MatrixType::Scalar> >
174class TruncatedConjugateGradient;
175
176namespace internal {
177
178 template <typename MatrixType_, int UpLo, typename Preconditioner_>
179 struct traits<TruncatedConjugateGradient<MatrixType_, UpLo, Preconditioner_> >
180 {
181 typedef MatrixType_ MatrixType;
182 typedef Preconditioner_ Preconditioner;
183 };
184
185} // namespace internal
186
// TruncatedConjugateGradient: Eigen IterativeSolverBase front end for
// internal::truncated_conjugate_gradient. The truncation parameters (trust-region radius,
// kappa, mininner, ...) live in a TCGInfo that is replaced with setInfo() and written by every
// solve (numInnerIter).
// NOTE(review): this extracted listing is missing source lines 187-193, 199, 223-227, 229-232,
// 234-235, 238-247, 249, 252 and 254. Per the definition index these include the move-constructor
// signature (199), getInfo() (227), the default constructor (235), the EigenBase constructor
// (249) and the destructor (252).
194template <typename M, int upLo, typename PC>
195class TruncatedConjugateGradient : public IterativeSolverBase<TruncatedConjugateGradient<M, upLo, PC> >
196{
197public:
198 typedef IterativeSolverBase<TruncatedConjugateGradient> Base;
// Move constructor (signature on missing source line 199). algInfo_ is read from `other` after
// `other` was passed to Base's move constructor. Base can only move from its own subobject, but
// linters flag this as use-after-move.
200 : Base(std::move(other)),
201 algInfo_{other.algInfo_} {}
202
203private:
204 using Base::m_error;
205 using Base::m_info;
206 using Base::m_isInitialized;
207 using Base::m_iterations;
208 using Base::matrix;
// Parameters and per-solve output. It is mutable because the const
// _solve_vector_with_guess_impl hands it to truncated_conjugate_gradient, which writes to it.
209 mutable TCGInfo<typename M::RealScalar> algInfo_;
210
211public:
212 using MatrixType = M;
213 using Scalar = typename MatrixType::Scalar;
214 using RealScalar = typename MatrixType::RealScalar;
215 using Preconditioner = PC;
216
217 enum
218 {
219 UpLo = upLo
220 };
221
222public:
228
// Replaces the whole TCGInfo, including its output fields.
233 void setInfo(TCGInfo<typename MatrixType::RealScalar> alginfo) { this->algInfo_ = alginfo; }
// Default constructor (signature on missing source line 235).
236 : Base() {}
237
// Constructor from a matrix expression A (signature on missing source line 249).
248 template <typename MatrixDerived>
250 : Base(A.derived()) {}
251
253
// Per-vector solve hook of Eigen's IterativeSolverBase: solves for one right-hand side b,
// using x as the initial guess. The matrix wrapping is adapted from Eigen's ConjugateGradient.
255 template <typename Rhs, typename Dest>
256 void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const {
257 typedef typename Base::MatrixWrapper MatrixWrapper;
258 typedef typename Base::ActualMatrixType ActualMatrixType;
// A full (Lower|Upper), non-matrix-free, column-major real matrix is wrapped in a Transpose, as
// in Eigen's ConjugateGradient. This presumably gives row-major products; the transpose equals
// the matrix when it is symmetric.
259 enum
260 {
261 TransposeInput = (!MatrixWrapper::MatrixFree) && (UpLo == (Lower | Upper)) && (!MatrixType::IsRowMajor) &&
262 (!NumTraits<Scalar>::IsComplex)
263 };
264 typedef std::conditional_t<TransposeInput, Transpose<const ActualMatrixType>, const ActualMatrixType&>
265 RowMajorWrapper;
266 EIGEN_STATIC_ASSERT(internal::check_implication(MatrixWrapper::MatrixFree, UpLo == (Lower | Upper)),
267 MATRIX_FREE_CONJUGATE_GRADIENT_IS_COMPATIBLE_WITH_UPPER_UNION_LOWER_MODE_ONLY);
268 typedef std::conditional_t<UpLo == (Lower | Upper), RowMajorWrapper,
269 typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type>
270 SelfAdjointWrapper;
271
// In/out arguments: iteration limit and tolerance go in; iteration count and relative
// residual come back.
272 m_iterations = Base::maxIterations();
273 m_error = Base::m_tolerance;
274
275 RowMajorWrapper row_mat(matrix());
276 internal::truncated_conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations,
277 m_error, algInfo_);
// NOTE(review): Success only when the final relative residual is within tolerance. Stops caused
// by negative curvature, the trust-region boundary or the kappa test can leave m_error above it
// and are then reported as NoConvergence — confirm callers expect this.
278 m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
279 }
280};
281
282} // end namespace Eigen
Eigen
Definition: truncatedconjugategradient.hh:24
TCGStopReason
Definition: truncatedconjugategradient.hh:26
void truncated_conjugate_gradient(const MatrixType &mat, const Rhs &rhs, Dest &x, const Preconditioner &precond, Eigen::Index &iters, typename Dest::RealScalar &tol_error, TCGInfo< typename Dest::RealScalar > &_info)
Definition: truncatedconjugategradient.hh:69
TCGInfo
Definition: truncatedconjugategradient.hh:36
Scalar kappa
Definition: truncatedconjugategradient.hh:39
Eigen::Index mininner
Definition: truncatedconjugategradient.hh:41
Eigen::Index maxinner
Definition: truncatedconjugategradient.hh:42
Scalar Delta
Definition: truncatedconjugategradient.hh:38
void initRuntimeOptions(int _num_dof_solve)
Definition: truncatedconjugategradient.hh:45
TCGStopReason stop_tCG
Definition: truncatedconjugategradient.hh:37
Scalar theta
Definition: truncatedconjugategradient.hh:40
Eigen::Index numInnerIter
Definition: truncatedconjugategradient.hh:43
Iterative solver for solving linear systems using the truncated conjugate gradient method.
Definition: truncatedconjugategradient.hh:196
IterativeSolverBase< TruncatedConjugateGradient > Base
Definition: truncatedconjugategradient.hh:198
@ UpLo
Definition: truncatedconjugategradient.hh:219
TruncatedConjugateGradient(TruncatedConjugateGradient &&other) noexcept
Definition: truncatedconjugategradient.hh:199
TruncatedConjugateGradient()
Definition: truncatedconjugategradient.hh:235
typename MatrixType::RealScalar RealScalar
Definition: truncatedconjugategradient.hh:214
PC Preconditioner
Definition: truncatedconjugategradient.hh:215
void _solve_vector_with_guess_impl(const Rhs &b, Dest &x) const
Definition: truncatedconjugategradient.hh:256
TCGInfo< typename MatrixType::RealScalar > getInfo()
Get information about the truncated conjugate gradient algorithm.
Definition: truncatedconjugategradient.hh:227
typename MatrixType::Scalar Scalar
Definition: truncatedconjugategradient.hh:213
~TruncatedConjugateGradient()
Definition: truncatedconjugategradient.hh:252
void setInfo(TCGInfo< typename MatrixType::RealScalar > alginfo)
Set information about the truncated conjugate gradient algorithm.
Definition: truncatedconjugategradient.hh:233
M MatrixType
Definition: truncatedconjugategradient.hh:212
TruncatedConjugateGradient(const EigenBase< MatrixDerived > &A)
Definition: truncatedconjugategradient.hh:249
Preconditioner_ Preconditioner
Definition: truncatedconjugategradient.hh:182
Definition: concepts.hh:30