version 0.4.1
flatassemblermanipulator.hh
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2021-2024 The Ikarus Developers mueller@ibb.uni-stuttgart.de
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
9#pragma once
10
11#include <dune/python/common/typeregistry.hh>
12#include <dune/python/pybind11/eigen.h>
13#include <dune/python/pybind11/pybind11.h>
14#include <dune/python/pybind11/stl.h>
15#include <dune/python/pybind11/stl_bind.h>
16
21#include <ikarus/utils/basis.hh>
23
24namespace Ikarus::Python {
25
namespace Impl {

  /**
   * \brief Calls \p f with all arguments passed through unchanged except the last one, which is converted to NewType.
   *
   * The leading arguments keep their value category (lvalues stay lvalues, rvalues are moved), so by-value and
   * move-only parameters work without extra copies. The last argument is always handed to NewType as an lvalue,
   * because NewType is expected to wrap (refer to) the caller's object rather than take it over.
   *
   * \tparam NewType Type the last argument is explicitly converted to via NewType(lastArg).
   * \param f Callable to invoke.
   * \param Indices Index sequence 0, ..., sizeof...(args) - 2 selecting the leading arguments.
   * \param args Arguments; at least one is required.
   * \return Whatever \p f returns (reference-ness preserved).
   */
  template <typename NewType, typename F, typename... Args, std::size_t... Indices>
  decltype(auto) forward_last_as(F&& f, std::index_sequence<Indices...>, Args&&... args) {
    static_assert(sizeof...(Args) >= 1, "forward_last_as needs at least the argument that is converted to NewType");
    static_assert(sizeof...(Indices) + 1 == sizeof...(Args),
                  "the index sequence must select every argument except the last one");
    // Tuple of references that preserves each argument's value category.
    auto tup = std::forward_as_tuple(std::forward<Args>(args)...);
    // get on an rvalue tuple yields T&& for rvalue elements and T& for lvalue elements, i.e. perfect forwarding.
    // The last element is fetched from the lvalue tuple on purpose, so NewType always binds to an lvalue.
    return std::forward<F>(f)(std::get<Indices>(std::move(tup))...,
                              NewType(std::get<sizeof...(Args) - 1>(tup)));
  }

  /**
   * \brief Adapts a callable whose last parameter is NewType into one whose last parameter is the type NewType wraps.
   *
   * The returned generic lambda forwards every argument unchanged except the last one. The last one is wrapped as
   * NewType(lastArg) before \p f is called. \p f is stored inside the lambda (moved if passed as an rvalue).
   *
   * \tparam NewType Type the last argument is converted to before calling \p f.
   * \param f Callable taking NewType as its last parameter.
   * \return A callable accepting the unwrapped last argument.
   */
  template <typename NewType, typename F>
  auto wrapFunctionAndReplaceLastType(F&& f) {
    return [fn = std::forward<F>(f)](auto&&... args) -> decltype(auto) {
      static_assert(sizeof...(args) >= 1, "the wrapped function needs at least one argument to replace");
      return forward_last_as<NewType>(fn, std::make_index_sequence<sizeof...(args) - 1>{},
                                      std::forward<decltype(args)>(args)...);
    };
  }
} // namespace Impl
42
43// Since pybind11 creates a new scipy.sparse.csr_matrix (a copy) from an Eigen::SparseMatrix, we have to create our own
44// wrapper, which allows the modification of scalar entries of the sparse matrix in Python.
45template <typename T>
47{
48 SparseMatrixWrapper(Eigen::SparseMatrix<T>& matrix)
49 : matrixRef(matrix) {}
50 std::reference_wrapper<Eigen::SparseMatrix<T>> matrixRef;
51};
52
53template <typename T>
54void registerSparseMatrixWrapper(pybind11::handle scope) {
55 auto includes = Dune::Python::IncludeFiles{"ikarus/finiteelements/flatassemblermanipulator.hh"};
56 auto [lv, isNotRegistered] = Dune::Python::insertClass<SparseMatrixWrapper<T>>(
57 scope, "SparseMatrixWrapper", Dune::Python::GenerateTypeName(Dune::className<SparseMatrixWrapper<T>>()),
58 includes);
59 if (isNotRegistered) {
60 lv.def(pybind11::init<Eigen::SparseMatrix<T>&>())
61 .def("__setitem__", [](SparseMatrixWrapper<T>& self, std::array<int, 2> index,
62 double val) { self.matrixRef.get().coeffRef(index[0], index[1]) = val; })
63 .def("__getitem__", [](SparseMatrixWrapper<T>& self, std::array<int, 2> index) {
64 return self.matrixRef.get().coeffRef(index[0], index[1]);
65 });
66 }
67}
68
/**
 * \brief Registers the Python bindings of an AssemblerManipulator class.
 *
 * What it registers:
 *  - the flat-assembler bindings (registerFlatAssembler);
 *  - the SparseMatrixWrapper<double> helper class;
 *  - a constructor taking the wrapped (underlying) assembler;
 *  - one "add{Scalar,Vector,Matrix}CallBack" method per entry of AssemblerManipulator::CallBackTypes.
 *
 * Each add...CallBack method accepts a callback of type FMod. Its last parameter is the Python-facing type from
 * NewArgs. The method adapts the callback to the stored callback type F and passes the result to self.bind(...).
 *
 * \param scope Python scope the helper classes are registered in.
 * \param cls pybind11 class object of the AssemblerManipulator.
 *
 * NOTE(review): this rendered listing is missing original lines 81, 95 and 103. These are presumably the first
 * NewArgs entry (scalar callback), the `Traits` alias used below, and the right-hand side of `using FMod =`.
 * Verify against the real header.
 */
69template <class AssemblerManipulator, class... options>
70void registerAssemblerManipulator(pybind11::handle scope, pybind11::class_<AssemblerManipulator, options...> cls) {
71 using pybind11::operator""_a;
72
73 registerFlatAssembler(scope, cls);
74 registerSparseMatrixWrapper<double>(scope);
75
76 using UnderlyingAssembler = typename AssemblerManipulator::WrappedAssembler;
77
// Python constructor: builds the manipulator around an existing underlying assembler.
78 cls.def(pybind11::init([](const UnderlyingAssembler& as) { return new AssemblerManipulator(as); }));
79
// Python-facing replacement types for the last callback argument, indexed like CallBackTypes.
// Sparse matrices are handed out as SparseMatrixWrapper<double> so Python can modify them in place;
// dense matrices are handed out as Eigen::Ref<Eigen::MatrixXd>.
80 using NewArgs = std::tuple<
82 Eigen::Ref<typename AssemblerManipulator::VectorType>,
83 std::conditional_t<std::is_same_v<typename AssemblerManipulator::MatrixType, Eigen::SparseMatrix<double>>,
84 SparseMatrixWrapper<double>, Eigen::Ref<Eigen::MatrixXd>>>;
// Compile-time loop over all callback kinds; i is an index_constant, so it can be used as a template argument.
85 Dune::Hybrid::forEach(Dune::Hybrid::integralRange(
86 Dune::index_constant<std::tuple_size_v<typename AssemblerManipulator::CallBackTypes>>()),
87 [&](auto i) {
88 using F = std::tuple_element_t<i, typename AssemblerManipulator::CallBackTypes>;
89 using NewArg = std::tuple_element_t<i, NewArgs>;
// Method name: addScalarCallBack (i == 0), addVectorCallBack (i == 1), otherwise addMatrixCallBack.
90 std::string name = std::string("add") +
91 (i == 0 ? "Scalar"
92 : (i == 1) ? "Vector"
93 : "Matrix") +
94 std::string("CallBack");
96 constexpr int lastIndex = Traits::numberOfArguments - 1;
97 // From Python we need a callback that accepts the wrapped types, since otherwise Python
98 // creates copies and no modification is possible. Therefore, from Python we get a callback in
99 // the style of FMod = std::function<void(..., Wrapped<Type>)>, but in the assembler we can
100 // only store F = std::function<void(..., Type&)>. wrapFunctionAndReplaceLastType takes care of
101 // this and wraps the "FMod" call inside an "F" function.
102 using FMod =
104
105 cls.def(name.c_str(), [&](AssemblerManipulator& self, FMod f) {
106 F fOrig = Impl::wrapFunctionAndReplaceLastType<NewArg>(std::forward<FMod>(f));
107
108 self.bind(std::move(fOrig));
109 });
110 });
111}
112
113} // namespace Ikarus::Python
Contains stl-like type traits.
Python bindings for assemblers.
Provides a wrapper for scalar types to support passing by reference in Python bindings.
Definition of the LinearElastic class for finite element mechanics computations.
void registerFlatAssembler(pybind11::handle scope, pybind11::class_< Assembler, options... > cls)
Register Python bindings for an assembler class.
Definition: flatassembler.hh:41
void init(int argc, char **argv, bool enableFileLogger=true)
Initializes the Ikarus framework.
Definition: init.hh:82
Definition: flatassembler.hh:21
void registerSparseMatrixWrapper(pybind11::handle scope)
Definition: flatassemblermanipulator.hh:54
void registerAssemblerManipulator(pybind11::handle scope, pybind11::class_< AssemblerManipulator, options... > cls)
Definition: flatassemblermanipulator.hh:70
The AssemblerManipulator defines a decorator for the assemblers that helps to manipulate the assemble...
Definition: assemblermanipulatorfuser.hh:36
Definition: flatassemblermanipulator.hh:47
std::reference_wrapper< Eigen::SparseMatrix< T > > matrixRef
Definition: flatassemblermanipulator.hh:50
SparseMatrixWrapper(Eigen::SparseMatrix< T > &matrix)
Definition: flatassemblermanipulator.hh:48
A wrapper class for scalar types to facilitate reference passing in Python bindings.
Definition: scalarwrapper.hh:27
Type trait for extracting information about functions.
Definition: traits.hh:331
Main function to wrap the type at position pos in a std::function.
Definition: traits.hh:453
Wrapper around Dune-functions global basis.