version 0.4.2
lambertw.hh
Go to the documentation of this file.
1// SPDX-FileCopyrightText: 2021-2025 The Ikarus Developers ikarus@ibb.uni-stuttgart.de
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
9#pragma once
#include <cmath>
#include <limits>
#include <type_traits>

#include <dune/common/exceptions.hh>
#include <dune/common/float_cmp.hh>
15
17
18namespace Ikarus::util {
19
namespace Impl {
  /**
   * \brief Computes log(1 + x).
   *
   * \details For built-in floating-point types std::log1p is used; it stays accurate for |x| << 1, where
   * forming 1 + x first would round away the information carried by x. Other scalar types (e.g. autodiff
   * types, which presumably do not provide log1p) fall back to log(1 + x), with log found via ADL.
   *
   * \tparam T Scalar type.
   * \param x Argument, must satisfy x > -1.
   * \return log(1 + x).
   */
  template <typename T>
  T log1p(T x) {
    if constexpr (std::is_floating_point_v<T>) {
      return std::log1p(x);
    } else {
      // using-declaration keeps std::log available for arithmetic types while ADL still finds
      // the overloads of user-defined scalar types.
      using std::log;
      return log(T(1.0) + x);
    }
  }
} // namespace Impl
26
41template <typename ST = double>
42ST lambertW0(ST z, int maxIterations = 20, ST eps = std::numeric_limits<ST>::epsilon()) {
43 if constexpr (not Concepts::AutodiffScalar<ST>) {
44 if (std::isnan(z))
45 return std::numeric_limits<ST>::quiet_NaN();
46 if (std::isinf(z))
47 return z;
48 }
49
50 const ST branchPoint = -1.0 / std::exp(1.0);
51
52 // If z equals -1/e then W(z) = -1.
53 if (Dune::FloatCmp::eq(z, branchPoint))
54 return -1.0;
55
56 // For branch 0 the domain is z >= -1/e.
57 if (z < branchPoint)
58 DUNE_THROW(Dune::InvalidStateException, "lambertW0: z must be >= -1/e for branch 0");
59
60 // Choose an initial guess x0. See https://en.wikipedia.org/wiki/Lambert_W_function.
61 ST x0;
62 if (Dune::FloatCmp::gt(z, ST(1.0))) {
63 ST lx = log(z);
64 ST llx = log(lx);
65 x0 = lx - llx - 0.5 * Impl::log1p<ST>(-llx / lx);
66 } else {
67 x0 = 0.567 * z;
68 }
69
70 // Begin Halley's iterative method. See https://en.wikipedia.org/wiki/Halley%27s_method.
71 ST x = x0;
72 ST lastDiff = 0.0;
73
74 for (int iter = 0; iter < maxIterations; ++iter) {
75 const ST ex = exp(x);
76 const ST f = x * ex - z;
77 const ST fPrime = ex * (x + 1.0);
78 const ST fPrimePrime = ex * (x + 2.0);
79 const ST denom = fPrime - ((1.0 / 2.0) * (fPrimePrime / fPrime) * f);
80
81 const ST newX = x - f / denom;
82 const ST diff = abs(newX - x);
83
84 // Check for convergence:
85 if (Dune::FloatCmp::le<ST>(diff, 3 * eps * abs(x)) || Dune::FloatCmp::eq<ST>(diff, lastDiff))
86 return newX;
87
88 lastDiff = diff;
89 x = newX;
90 }
91
92 DUNE_THROW(Dune::MathError, "lambertW0: failed to converge within the maximum number of iterations");
93}
94} // namespace Ikarus::util
Definition: lambertw.hh:18
ST lambertW0(ST z, int maxIterations=20, ST eps=std::numeric_limits< ST >::epsilon())
Implementation of the principal branch of the Lambert-W function (branch 0 in the domain $z \geq -1/e$),...
Definition: lambertw.hh:42
Concept to check if the underlying scalar type is a dual type.
Definition: utils/concepts.hh:625
Several concepts.