Line data Source code
1 : /* Copyright (C) 2021 The PARI group.
2 :
3 : This file is part of the PARI/GP package.
4 :
5 : PARI/GP is free software; you can redistribute it and/or modify it under the
6 : terms of the GNU General Public License as published by the Free Software
7 : Foundation; either version 2 of the License, or (at your option) any later
8 : version. It is distributed in the hope that it will be useful, but WITHOUT
9 : ANY WARRANTY WHATSOEVER.
10 :
11 : Check the License for details. You should have received a copy of it, along
12 : with the package; see the file 'COPYING'. If not, write to the Free Software
13 : Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */
14 :
15 : #include "pari.h"
16 : #include "paripriv.h"
17 :
18 : /***********************************************************************/
19 : /** LAMBERT's W_K FUNCTIONS **/
20 : /***********************************************************************/
21 : /* roughly follows Veberic, https://arxiv.org/abs/1003.1628 */
22 :
23 : static double
24 21 : dblL1L2(double L1)
25 : {
26 21 : double L2 = log(-L1), LI = 1 / L1, N2, N3, N4, N5;
27 21 : N2 = (L2-2.)/2.; N3 = (6.+L2*(-9.+2.*L2))/6.;
28 21 : N4 = (-12.+L2*(36.+L2*(-22.+3*L2)))/12.;
29 21 : N5 = (60.+L2*(-300.+L2*(350.+L2*(-125.+12*L2))))/60.;
30 21 : return L1-L2+L2*LI*(1+LI*(N2+LI*(N3+LI*(N4+LI*N5))));
31 : }
32 :
33 : /* rough approximation to W0(a > -1/e), < 1% relative error */
34 : double
35 474466 : dbllambertW0(double a)
36 : {
37 474466 : if (a < -0.2583)
38 : {
39 0 : const double c2 = -1./3, c3 = 11./72, c4 = -43./540, c5 = 769./17280;
40 0 : double p = sqrt(2 * (M_E * a + 1));
41 0 : if (a < -0.3243) return -1+p*(1+p*(c2+p*c3));
42 0 : return -1+p*(1+p*(c2+p*(c3+p*(c4+p*c5))));
43 : }
44 : else
45 : {
46 474466 : double Wd = log(1.+a);
47 474466 : Wd *= (1.-log(Wd/a))/(1.+Wd);
48 474466 : if (a < 0.6482 && a > -0.1838) return Wd;
49 419401 : return Wd*(1.-log(Wd/a))/(1.+Wd);
50 : }
51 : }
52 : /* uniform approximation to W0, at least 15 bits. */
53 : static double
54 532 : dbllambertW0init(double a)
55 : {
56 532 : if (a < -0.323581)
57 : {
58 35 : const double c2 = 1./3., c3 = 11./72., c4 = 43./540., c5 = 769./17280.;
59 35 : const double c6 = 221./8505., c7 = 680863./43545600.;
60 35 : const double c8 = 1963./204120., c9 = 226287557./37623398400.;
61 35 : double p = M_E * a + 1;
62 35 : if (p <= 0) return -1;
63 23 : p = -sqrt(2 * p);
64 23 : return -(1.+p*(1.+p*(c2+p*(c3+p*(c4+p*(c5+p*(c6+p*(c7+p*(c8+p*c9)))))))));
65 : }
66 497 : if (a < 0.145469)
67 : {
68 112 : const double a1 = 5.931375, a2 = 11.392205, a3 = 7.338883, a4 = 0.653449;
69 112 : const double b1 = 6.931373, b2 = 16.823494, b3 = 16.430723, b4 = 5.115235;
70 112 : double n = 1.+a*(a1+a*(a2+a*(a3+a*a4)));
71 112 : double d = 1.+a*(b1+a*(b2+a*(b3+a*b4)));
72 112 : return a * n / d;
73 : }
74 385 : if (a < 8.706658)
75 : {
76 378 : const double a1 = 2.445053, a2 = 1.343664, a3 = 0.148440, a4 = 0.000804;
77 378 : const double b1 = 3.444708, b2 = 3.292489, b3 = 0.916460, b4 = 0.053068;
78 378 : double n = 1.+a*(a1+a*(a2+a*(a3+a*a4)));
79 378 : double d = 1.+a*(b1+a*(b2+a*(b3+a*b4)));
80 378 : return a * n / d;
81 : }
82 : else
83 : {
84 7 : double w = log(1.+a);
85 7 : w *= (1.-log(w/a)) / (1.+w);
86 7 : return w * (1.-log(w/a)) / (1.+w);
87 : }
88 : }
89 :
90 : /* rough approximation to W_{-1}(0 > a > -1/e), < 1% relative error */
91 : double
92 66962 : dbllambertW_1(double a)
93 : {
94 66962 : if (a < -0.2464)
95 : {
96 280 : const double c2 = -1./3, c3 = 11./72, c4 = -43./540, c5 = 769./17280;
97 280 : double p = -sqrt(2 * (M_E * a + 1));
98 280 : if (a < -0.3243) return -1+p*(1+p*(c2+p*c3));
99 175 : return -1+p*(1+p*(c2+p*(c3+p*(c4+p*c5))));
100 : }
101 : else
102 : {
103 : double Wd;
104 66682 : a = -a; Wd = -log(a);
105 66682 : Wd *= (1.-log(Wd/a))/(1.-Wd);
106 66682 : if (a < 0.0056) return -Wd;
107 658 : return -Wd*(1.-log(Wd/a))/(1.-Wd);
108 : }
109 : }
/* Uniform approximation to W_{-1}(a), -1/e <= a < 0, to at least 15 bits.
 * - a < -0.302985: branch-point series -1 - q - q^2/3 - ... in
 *   q = sqrt(2(1 + e a)), returning -1 when 1 + e a <= 0;
 * - -0.302985 <= a <= -0.051012: rational approximation;
 * - a closer to 0: asymptotic expansion dblL1L2(log(-a)). */
static double
dbllambertW_1init(double a)
{
  if (a < -0.302985)
  {
    const double c2 = 1./3., c3 = 11./72., c4 = 43./540., c5 = 769./17280.;
    const double c6 = 221./8505., c7 = 680863./43545600.;
    const double c8 = 1963./204120., c9 = 226287557./37623398400.;
    double q = M_E * a + 1;
    if (q <= 0) return -1; /* at (or, by rounding, beyond) -1/e */
    q = sqrt(2 * q);
    return -(1.+q*(1.+q*(c2+q*(c3+q*(c4+q*(c5+q*(c6+q*(c7+q*(c8+q*c9)))))))));
  }
  else if (a <= -0.051012)
  {
    const double A0 = -7.814176, A1 = 253.888101, A2 = 657.949317;
    const double B1 = -60.439587, B2 = 99.985670, B3 = 682.607399;
    const double B4 = 962.178439, B5 = 1477.934128;
    double num = A0+a*(A1+a*A2);
    double den = 1+a*(B1+a*(B2+a*(B3+a*(B4+a*B5))));
    return num / den;
  }
  else
    return dblL1L2(log(-a));
}
135 :
136 : /* uniform approximation to more than 46 bits, 50 bits away from -1/e;
137 : * branch = -1 or 0 */
138 : static double
139 651 : dbllambertWfritsch(GEN ga, int branch)
140 : {
141 : double a, z, w1, q, w;
142 651 : if (expo(ga) >= 0x3fe)
143 : { /* branch = 0 */
144 7 : double w = dbllog2(ga) * M_LN2; /* ~ log(1+a) ~ log a */
145 7 : return w * (1.+w-log(w)) / (1.+w);
146 : }
147 644 : a = rtodbl(ga);
148 644 : w = branch? dbllambertW_1init(a): dbllambertW0init(a);
149 644 : if (w == -1.|| w == 0.) return w;
150 632 : z = log(a / w) - w; w1 = 1. + w;
151 632 : q = 2. * w1 * (w1 + (2./3.) * z);
152 632 : return w * (1 + (z / w1) * (q - z) / (q - 2 * z));
153 : }
154 :
/* W_{-1}(a) in double precision for tiny negative a, given loga = log(-a).
 * lambertW uses it when expo(a) < -512, i.e. loga < about -355.
 * Starts from the asymptotic expansion dblL1L2(loga), then applies Halley
 * steps to f(w) = w + log(-w) - loga, for which f'(w) = (1+w)/w and
 * f''(w) = -1/w^2. In relative form one step reads
 *   w <- w (1 - r),  r = f / (D + f/(2D)),  D = 1 + w.
 * (The previous code used D = 1 - w and f/D, which steps away from the
 * root, and accepted any negative r as "converged".)
 * Stops when |r| < 2e-15, or after a bounded number of steps so that a NaN
 * cannot cause an infinite loop. */
static double
dbllambertWhalleyspec(double loga)
{
  double w = dblL1L2(loga);
  int i;
  for (i = 0; i < 50; i++)
  {
    double n = w + log(-w) - loga, d = 1 + w, r = n / (d + n / (2 * d));
    w *= 1 - r;
    if (fabs(r) < 2.e-15) break;
  }
  return w;
}
165 : /* k = 0 or -1. */
166 : static GEN
167 658 : lambertW(GEN z, long k, long prec)
168 : {
169 658 : pari_sp av = avma;
170 658 : long bit = prec2nbits(prec), L = -(bit / 3 + 10), ct = 0, p, pb;
171 : double wd;
172 : GEN w, vp;
173 :
174 658 : if (gequal0(z) && !k) return real_0(prec);
175 658 : z = gtofp(z, prec);
176 658 : if (k == -1)
177 : {
178 119 : long e = expo(z);
179 119 : if (signe(z) >= 0) pari_err_DOMAIN("lambertw", "z", ">", gen_0, z);
180 119 : wd = e < -512? dbllambertWhalleyspec(dbllog2(z) * M_LN2)
181 119 : : dbllambertWfritsch(z, -1);
182 : }
183 : else
184 539 : wd = dbllambertWfritsch(z, 0);
185 658 : if (fabs(wd + 1) < 1e-5)
186 : {
187 14 : long prec2 = prec + EXTRAPREC64;
188 14 : GEN Z = rtor(z, prec2);
189 14 : GEN t = addrs(mulrr(Z, gexp(gen_1, prec2)), 1);
190 14 : if (signe(t) <= 0) { set_avma(av); return real_m1(prec); }
191 0 : if (realprec(t) < prec)
192 : {
193 0 : prec2 += prec - realprec(t);
194 0 : Z = rtor(z, prec2);
195 0 : t = addrs(mulrr(Z, gexp(gen_1, prec2)), 1);
196 : }
197 0 : t = sqrtr(shiftr(t, 1));
198 0 : w = gprec_w(k == -1? subsr(-1, t) : subrs(t, 1), prec);
199 0 : p = prec - 2; vp = NULL;
200 : }
201 : else
202 : { /* away from -1/e: can reduce accuracy and self-correct */
203 644 : w = wd == 0.? z: dbltor(wd);
204 644 : vp = cgetg(30, t_VECSMALL); ct = 0; pb = bit;
205 1425 : while (pb > BITS_IN_LONG * 3/4)
206 781 : { vp[++ct] = (pb + BITS_IN_LONG-1) >> TWOPOTBITS_IN_LONG; pb = (pb + 2) / 3; }
207 644 : p = vp[ct]; w = gprec_w(w, p + 2);
208 : }
209 644 : if ((k == -1 && (bit < 192 || bit > 640)) || (k == 0 && bit > 1024))
210 : {
211 : for(;;)
212 13 : {
213 : GEN t, ew, n, d;
214 104 : ew = mplog(divrr(w, z)); n = addrr(w, ew); d = addrs(w, 1);
215 104 : t = divrr(n, shiftr(d, 1));
216 104 : w = mulrr(w, subsr(1, divrr(n, addrr(d, t))));
217 104 : if (p >= prec-2 && expo(n) - expo(d) - expo(w) <= L) break;
218 13 : if (vp) { if (--ct) p = vp[ct]; w = gprec_w(w, ct? p + 2: prec); }
219 : }
220 : }
221 : else
222 : {
223 : for(;;)
224 185 : {
225 : GEN t, ew, wew, n, d;
226 738 : ew = mpexp(w); wew = mulrr(w, ew); n = subrr(wew, z); d = addrr(ew, wew);
227 738 : t = divrr(mulrr(addrs(w, 2), n), shiftr(addrs(w, 1), 1));
228 738 : w = subrr(w, divrr(n, subrr(d, t)));
229 738 : if (p >= prec-2 && expo(n) - expo(d) - expo(w) <= L) break;
230 185 : if (vp) { if (--ct) p = vp[ct]; w = gprec_w(w, ct? p + 2: prec); }
231 : }
232 : }
233 644 : return gerepileupto(av, w);
234 : }
235 :
236 : /*********************************************************************/
237 : /* Complex branches */
238 : /*********************************************************************/
239 :
240 : /* x *= (1 - (x + log(x) - L) / (x + 1)); L = log(z) + 2IPi * k */
241 : static GEN
242 632752 : lamaux(GEN x, GEN L, long *pe, long prec)
243 : {
244 632752 : GEN n = gsub(gadd(x, glog(x, prec)), L);
245 632752 : if (pe) *pe = maxss(4, -gexpo(n));
246 632752 : if (gequal0(imag_i(n))) n = real_i(n);
247 632752 : return gmul(x, gsubsg(1, gdiv(n, gaddsg(1, x))));
248 : }
249 :
250 : /* Complex branches, experimental */
251 : static GEN
252 78428 : lambertWC(GEN z, long branch, long prec)
253 : {
254 78428 : pari_sp av = avma;
255 : GEN w, pii2k, zl, lzl, L, Lz;
256 78428 : long bit0, si, j, fl = 0, lim = 6, lp = DEFAULTPREC, bit = prec2nbits(prec);
257 :
258 78428 : si = gsigne(imag_i(z)); if (!si) z = real_i(z);
259 78428 : pii2k = gmulsg(branch, PiI2(lp));
260 78428 : zl = gtofp(z, lp); lzl = glog(zl, lp);
261 : /* From here */
262 78428 : if (branch == 0 || branch * si < 0
263 26579 : || (si == 0 && gsigne(z) < 0 && branch == -1))
264 : {
265 51989 : GEN lnzl1 = gaddsg(1, glog(gneg(zl), lp));
266 51989 : if (si == 0) si = gsigne(lnzl1);
267 51989 : if ((branch == 0 || branch * si < 0) && gexpo(lnzl1) < -1)
268 : { /* close to -1/e */
269 2408 : w = gaddsg(1, gmul(z, gexp(gen_1, prec)));
270 2408 : w = gprec_wtrunc(w, lp);
271 2408 : w = gsqrt(gmul2n(w, 1), lp);
272 2408 : w = branch * si < 0? gsubsg(-1, w): gaddsg(-1, w);
273 2408 : lim = 10; fl = 1;
274 : }
275 51989 : if (branch == 0 && !fl && gexpo(lzl) < 0) { w = zl; fl = 1; }
276 : }
277 78428 : if (!fl)
278 : {
279 68152 : if (branch)
280 : {
281 51212 : GEN lr = glog(pii2k, lp);
282 51212 : w = gadd(gsub(gadd(pii2k, lzl), lr), gdiv(gsub(lr, lzl), pii2k));
283 : }
284 : else
285 : {
286 16940 : GEN p = gaddsg(1, gmul(z, gexp(gen_1, lp)));
287 16940 : w = gexpo(p) > 0? lzl: gaddgs(gsqrt(p, lp), -1);
288 : }
289 : }
290 : /* to here: heuristic */
291 78428 : L = gadd(lzl, pii2k);
292 480200 : for (j = 1; j < lim; j++) w = lamaux(w, L, NULL, lp);
293 78428 : Lz = NULL;
294 78428 : if (branch == 0 || branch == -1)
295 : {
296 52199 : Lz = glog(z, prec);
297 52199 : if (branch == -1)
298 : {
299 25970 : long flag = 1;
300 25970 : if (!si && signe(z) <= 0 && signe(addrs(Lz, 1))) flag = 0;
301 25970 : if (flag) Lz = gsub(Lz, PiI2(prec));
302 : }
303 : }
304 78428 : w = lamaux(w, L, &bit0, lp);
305 230980 : while (bit0 < bit || (Lz && gexpo(gsub(gadd(w, glog(w, prec)), Lz)) > 16-bit))
306 : {
307 152552 : long p = nbits2prec(bit0 <<= 1);
308 152552 : L = gadd(gmulsg(branch, PiI2(p)), glog(gprec_w(z, p), p));
309 152552 : w = lamaux(gprec_w(w, p), L, NULL, p);
310 : }
311 78428 : return gerepilecopy(av, gprec_w(w, nbits2prec(bit)));
312 : }
313 :
314 : /* exp(t (1 + O(t^n))), n >= 0 */
315 : static GEN
316 154 : serexp0(long v, long n)
317 : {
318 154 : GEN y = cgetg(n+3, t_SER), t;
319 : long i;
320 154 : y[1] = evalsigne(1) | evalvarn(v) | evalvalser(0);
321 154 : gel(y,2) = gen_1; if (!n) return y;
322 147 : gel(y,3) = gen_1; if (n == 1) return y;
323 1295 : for (i=2, t = gen_2; i < n; i++, t = muliu(t,i)) gel(y,i+2) = mkfrac(gen_1,t);
324 119 : gel(y,i+2) = mkfrac(gen_1,t); return y;
325 : }
326 :
327 : /* series expansion of W at -1/e */
328 : static GEN
329 7 : Wbra(long N)
330 : {
331 7 : GEN v = cgetg(N + 2, t_VEC);
332 : long n;
333 7 : gel(v, 1) = gen_m1;
334 7 : gel(v, 2) = gen_1;
335 56 : for (n = 2; n <= N; n++)
336 : {
337 49 : GEN t = gel(v,n), a = gen_0;
338 49 : long k, K = (n - 1) >> 1;
339 133 : for (k = 1; k <= K; k++) t = gadd(t, gmul2n(gel(v,n-2*k), -k));
340 196 : for (k = 2; k < n; k++) a = gadd(a, gmul(gel(v,k+1), gel(v,n+2-k)));
341 49 : gel(v,n+1) = gsub(gdivgs(t, -n-1), gmul2n(a, -1));
342 : }
343 7 : return RgV_to_RgX(v, 0);
344 : }
345 :
346 : static GEN
347 154 : reverse(GEN y)
348 : {
349 154 : GEN z = ser2rfrac_i(y);
350 154 : long l = lg(z);
351 154 : return RgX_to_ser(RgXn_reverse(z, l-2), l-1);
352 : }
353 : static GEN
354 182 : serlambertW(GEN y, long branch, long prec)
355 : {
356 : long n, vy, val, v;
357 182 : GEN t = NULL;
358 :
359 182 : if (!signe(y)) return gcopy(y);
360 182 : v = valser(y);
361 182 : if (v < 0) pari_err_DOMAIN("lambertw","valuation", "<", gen_0, y);
362 175 : if (v > 0 && branch)
363 0 : pari_err_DOMAIN("lambertw [k != 0]", "x", "~", gen_0, y);
364 175 : vy = varn(y); n = lg(y)-3;
365 483 : for (val = 1; val < n; val++)
366 434 : if (!gequal0(polcoef_i(y, val, vy))) break;
367 175 : if (v)
368 : {
369 70 : t = serexp0(vy, n / val);
370 70 : setvalser(t, 1); t = reverse(t); /* rev(x exp(x)) */
371 : }
372 : else
373 : {
374 105 : GEN y0 = gel(y,2), x = glambertW(y0, branch, prec);
375 105 : if (val > n) return scalarser(x, vy, n+1);
376 98 : y = serchop0(y);
377 98 : if (gequalm1(x))
378 : { /* y0 ~ -1/e, branch = 0 or -1 */
379 14 : GEN p = gmul(shiftr(gexp(gen_1,prec), 1), y);
380 14 : if (odd(val)) pari_err(e_MISC, "odd valuation at branch point");
381 7 : p = gsqrt(p, prec); if (odd(branch)) p = gneg(p);
382 7 : n -= val >> 1;
383 7 : t = RgXn_eval(Wbra(n), ser2rfrac_i(p), n);
384 7 : return gtoser(t, varn(t), lg(p));
385 : }
386 84 : t = serexp0(vy, n / val);
387 : /* (x + t) exp(x + t) = (y0 + t y0/x) * exp(t) */
388 84 : t = gmul(deg1pol_shallow(gdiv(y0,x), y0, vy), t);
389 84 : t = gadd(x, reverse(serchop0(t)));
390 : }
391 154 : return normalizeser(gsubst(t, vy, y));
392 : }
393 :
394 : static GEN
395 14 : lambertp(GEN x)
396 : {
397 14 : pari_sp av = avma;
398 : long k;
399 : GEN y;
400 :
401 14 : if (gequal0(x)) return gcopy(x);
402 14 : if (!valp(x)) { x = leafcopy(x); setvalp(x, 1); }
403 14 : k = Qp_exp_prec(x);
404 14 : if (k < 0) return NULL;
405 14 : y = gpowgs(cvstop2(k, x), k - 1);
406 266 : for (k--; k; k--)
407 252 : y = gsub(gpowgs(cvstop2(k, x), k - 1), gdivgu(gmul(x, y), k + 1));
408 14 : return gerepileupto(av, gmul(x, y));
409 : }
410 :
411 : /* y a t_REAL */
412 : static int
413 2219 : useC(GEN y, long k)
414 : {
415 2219 : if (signe(y) > 0 || (k && k != -1)) return k ? 1: 0;
416 980 : return gsigne(addsr(1, logr_abs(y))) > 0;
417 : }
418 : static GEN
419 80808 : glambertW_i(void *E, GEN y, long prec)
420 : {
421 : pari_sp av;
422 80808 : long k = (long)E, p;
423 : GEN z;
424 80808 : if (gequal0(y))
425 : {
426 21 : if (k) pari_err_DOMAIN("glambertW","argument","",gen_0,y);
427 14 : return gcopy(y);
428 : }
429 80787 : switch(typ(y))
430 : {
431 2219 : case t_REAL:
432 2219 : p = minss(prec, realprec(y));
433 2219 : return useC(y, k)? lambertWC(y, k, p): lambertW(y, k, p);
434 14 : case t_PADIC: z = lambertp(y);
435 14 : if (!z) pari_err_DOMAIN("glambertW(t_PADIC)","argument","",gen_0,y);
436 14 : return z;
437 76867 : case t_COMPLEX:
438 76867 : p = precision(y);
439 76867 : return lambertWC(y, k, p? p: prec);
440 1687 : default:
441 1687 : av = avma; if (!(z = toser_i(y))) break;
442 182 : return gerepileupto(av, serlambertW(z, k, prec));
443 : }
444 1505 : return trans_evalgen("lambert", E, glambertW_i, y, prec);
445 : }
446 :
447 : GEN
448 79303 : glambertW(GEN y, long k, long prec)
449 : {
450 79303 : return glambertW_i((void*)k, y, prec);
451 : }
452 : GEN
453 0 : mplambertW(GEN y, long prec) { return lambertW(y, 0, prec); }
454 :
455 : /*********************************************************************/
456 : /* Application */
457 : /*********************************************************************/
458 : /* Solve x - a * log(x) = b with a > 0 and b >= a * (1 - log(a)). */
459 : GEN
460 0 : mplambertx_logx(GEN a, GEN b, long bit)
461 : {
462 0 : pari_sp av = avma;
463 0 : GEN e = gexp(gneg(gdiv(b, a)), nbits2prec(bit));
464 0 : return gerepileupto(av, gmul(gneg(a), lambertW(gneg(gdiv(e, a)), -1, bit)));
465 : }
466 : /* Special case a = 1, b = log(y): solve e^x / x = y with y >= exp(1). */
467 : GEN
468 0 : mplambertX(GEN y, long bit)
469 : {
470 0 : pari_sp av = avma;
471 0 : return gerepileupto(av, gneg(lambertW(gneg(ginv(y)), -1, bit)));
472 : }
473 :
474 : /* Solve x * log(x) - a * x = b; if b < 0, assume a >= 1 + log |b|. */
475 : GEN
476 0 : mplambertxlogx_x(GEN a, GEN b, long bit)
477 : {
478 0 : pari_sp av = avma;
479 0 : long s = gsigne(b);
480 : GEN e;
481 0 : if (!s) return gen_0;
482 0 : e = gexp(gneg(a), nbits2prec(bit));
483 0 : return gerepileupto(av, gdiv(b, lambertW(gmul(b, e), s > 0? 0: -1, bit)));
484 : }
|