version 0.4.1
autodifffe.hh
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2021-2024 The Ikarus Developers mueller@ibb.uni-stuttgart.de
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
9#pragma once
10
11#include <autodiff/forward/dual/dual.hpp>
12#include <autodiff/forward/dual/eigen.hpp>
13
16
17namespace Ikarus {
18
26template <typename FEImpl, bool forceAutoDiff = false>
27class AutoDiffFE : public FEImpl
28{
29public:
30 using RealFE = FEImpl;
31 using BasisHandler = typename RealFE::BasisHandler;
32 using Traits = typename RealFE::Traits;
33 using LocalView = typename Traits::LocalView;
34 using Element = typename Traits::Element;
35 using FERequirementType = typename Traits::FERequirementType;
36private:
37 using Mixin = FEImpl::Mixin;
38
39public:
47 friend void calculateMatrix(const AutoDiffFE& self, const FERequirementType& par,
48 typename Traits::template MatrixType<> h) {
49 self.calculateMatrix(par, h);
50 }
51
59 friend void calculateVector(const AutoDiffFE& self, const FERequirementType& par,
60 typename Traits::template VectorType<double> g) {
61 self.calculateVector(par, g);
62 }
63
72 friend void calculateLocalSystem(const AutoDiffFE& self, const FERequirementType& par,
73 typename Traits::template MatrixType<> h, typename Traits::template VectorType<> g) {
74 self.calculateLocalSystem(par, h, g);
75 }
76
84 friend auto calculateScalar(const AutoDiffFE& self, const FERequirementType& par) {
85 return self.calculateScalar(par);
86 }
87
93 const RealFE& realFE() const { return *this; }
94
102 template <typename... Args>
103 explicit AutoDiffFE(Args&&... args)
104 : RealFE{std::forward<Args>(args)...} {}
105
106private:
107 void calculateMatrix(const FERequirementType& req, typename Traits::template MatrixType<> h) const {
108 // real element implements calculateMatrix by itself, then we simply forward the call
109
110 if constexpr (requires(Eigen::VectorXd v) {
111 static_cast<const Mixin&>(std::declval<AutoDiffFE>())
112 .template calculateMatrixImpl<double>(req, h, v);
113 } and not forceAutoDiff) {
114 return Mixin::template calculateMatrixImpl<double>(req, h);
115 } else if constexpr (requires(Eigen::VectorXdual v) {
116 static_cast<const Mixin&>(std::declval<AutoDiffFE>())
117 .template calculateVectorImpl<autodiff::dual>(
118 req, std::declval<typename Traits::template VectorType<autodiff::dual>>(), v);
119 }) {
120 // real element implements calculateVector by itself, therefore we only need first order derivatives
121 Eigen::VectorXdual dx(this->localView().size());
122 Eigen::VectorXdual g(this->localView().size());
123 dx.setZero();
124 auto f = [this, &req, &g](auto& x) -> auto& {
125 // Since req is const as a function argument, we can not make this lambda capture by mutable reference
126 // But we have to do this since for efficiency reason we reuse the g vector
127 // therefore, the only remaining option is to cast the const away from g
128 Eigen::VectorXdual& gref = const_cast<Eigen::VectorXdual&>(g);
129 gref.setZero();
130 Mixin::template calculateVectorImpl<autodiff::dual>(req, gref, x);
131 return g;
132 };
133 jacobian(f, autodiff::wrt(dx), at(dx), g, h);
134 } else if constexpr (requires(typename Traits::template VectorType<autodiff::dual2nd> v) {
135 static_cast<const Mixin&>(std::declval<AutoDiffFE>())
136 .template calculateScalarImpl<autodiff::dual2nd>(req, v);
137 }) {
138 // real element implements calculateScalar by itself, therefore we need second order derivatives
139 Eigen::VectorXdual2nd dx(this->localView().size());
140 Eigen::VectorXd g;
141 autodiff::dual2nd e;
142 dx.setZero();
143 auto f = [this, &req](auto& x) { return Mixin::template calculateScalarImpl<autodiff::dual2nd>(req, x); };
144 hessian(f, autodiff::wrt(dx), at(dx), e, g, h);
145 } else
146 static_assert(Dune::AlwaysFalse<AutoDiffFE>::value,
147 "Appropriate calculateScalarImpl or calculateVectorImpl functions are not implemented for the "
148 "chosen element.");
149 }
150
151 void calculateVector(const FERequirementType& req, typename Traits::template VectorType<> g) const {
152 // real element implements calculateVector by itself, then we simply forward the call
153 if constexpr (requires {
154 static_cast<const Mixin&>(std::declval<AutoDiffFE>())
155 .template calculateVectorImpl<double>(
156 req, std::declval<typename Traits::template VectorType<double>>(),
157 std::declval<const Eigen::VectorXd&>());
158 } and not forceAutoDiff) {
159 return Mixin::template calculateVectorImpl<double>(req, g);
160 } else if constexpr (requires {
161 static_cast<const Mixin&>(std::declval<AutoDiffFE>())
162 .template calculateScalarImpl<autodiff::dual>(req,
163 std::declval<const Eigen::VectorXdual&>());
164 }) {
165 // real element implements calculateScalar by itself but no calculateVectorImpl, therefore we need first order
166 // derivatives
167 Eigen::VectorXdual dx(this->localView().size());
168 dx.setZero();
169 autodiff::dual e;
170 auto f = [this, &req](auto& x) { return Mixin::template calculateScalarImpl<autodiff::dual>(req, x); };
171 gradient(f, autodiff::wrt(dx), at(dx), e, g);
172 } else
173 static_assert(Dune::AlwaysFalse<AutoDiffFE>::value,
174 "Appropriate calculateScalarImpl function is not implemented for the "
175 "chosen element.");
176 }
177
178 [[nodiscard]] double calculateScalar(const FERequirementType& par) const {
179 // real element implements calculateScalar by itself, then we simply forward the call
180 if constexpr (requires {
181 static_cast<const Mixin&>(std::declval<AutoDiffFE>()).template calculateScalarImpl<double>(par);
182 }) {
183 Mixin::template calculateScalarImpl<double>(par);
184 // real element only implements the protected calculateScalarImpl by itself, thus we call that one.
185 return Mixin::template calculateScalarImpl<double>(par);
186 } else {
187 static_assert(Dune::AlwaysFalse<AutoDiffFE>::value,
188 "Appropriate calculateScalar and calculateScalarImpl functions are not implemented for the "
189 "chosen element.");
190 }
191 }
192
193 void calculateLocalSystem(const FERequirementType& req, typename Traits::template MatrixType<> h,
194 typename Traits::template VectorType<> g) const {
195 Eigen::VectorXdual2nd dx(this->localView().size());
196 dx.setZero();
197 auto f = [&](auto& x) { return Mixin::calculateScalarImpl(req, x); };
198 hessian(f, autodiff::wrt(dx), at(dx), g, h);
199 }
200};
201} // namespace Ikarus
Contains stl-like type traits.
Definition of the LinearElastic class for finite element mechanics computations.
Definition: simpleassemblers.hh:22
AutoDiffFE class, an automatic differentiation wrapper for finite elements.
Definition: autodifffe.hh:28
AutoDiffFE(Args &&... args)
Constructor for the AutoDiffFE class. Forwards the construction to the underlying element.
Definition: autodifffe.hh:103
typename Traits::FERequirementType FERequirementType
Type of the Finite Element Requirements.
Definition: autodifffe.hh:35
typename RealFE::Traits Traits
Type traits of the underlying finite element.
Definition: autodifffe.hh:32
friend void calculateMatrix(const AutoDiffFE &self, const FERequirementType &par, typename Traits::template MatrixType<> h)
Calculate the matrix associated with the finite element.
Definition: autodifffe.hh:47
friend auto calculateScalar(const AutoDiffFE &self, const FERequirementType &par)
Calculate the scalar value associated with the finite element.
Definition: autodifffe.hh:84
const RealFE & realFE() const
Get the reference to the base finite element.
Definition: autodifffe.hh:93
friend void calculateLocalSystem(const AutoDiffFE &self, const FERequirementType &par, typename Traits::template MatrixType<> h, typename Traits::template VectorType<> g)
Calculate the local system associated with the finite element.
Definition: autodifffe.hh:72
FEImpl RealFE
Type of the base finite element.
Definition: autodifffe.hh:30
typename RealFE::BasisHandler BasisHandler
Type of the basis handler.
Definition: autodifffe.hh:31
friend void calculateVector(const AutoDiffFE &self, const FERequirementType &par, typename Traits::template VectorType< double > g)
Calculate the vector associated with the finite element.
Definition: autodifffe.hh:59
typename Traits::LocalView LocalView
Type of the local view.
Definition: autodifffe.hh:33
typename Traits::Element Element
Type of the element.
Definition: autodifffe.hh:34