Skip to content

Commit 36bcd52

Browse files
authored
Merge pull request #1226 from LourensVeen/issue_1221_protoplanetary_disk_broken
Fix protoplanetary disk parameter rename
2 parents 230ca21 + 09d8e90 commit 36bcd52

3 files changed

Lines changed: 220 additions & 133 deletions

File tree

src/amuse/ext/protodisk.py

Lines changed: 125 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy
22
import warnings
3+
from amuse.support.helpers import rename_fn_par
34
from amuse.ext.evrard_test import body_centered_grid_unit_cube
45
from amuse.ext.evrard_test import regular_grid_unit_cube
56
from amuse.ext.evrard_test import uniform_random_unit_cube
@@ -8,157 +9,148 @@
89

910
from amuse.datamodel import Particles
1011
from amuse.datamodel import ParticlesWithUnitsConverted
12+
13+
1114
def approximate_inverse_error_function(x):
12-
a=8*(numpy.pi-3)/3*numpy.pi*(4-numpy.pi)
13-
return numpy.sign(x)*numpy.sqrt(
14-
numpy.sqrt((2/numpy.pi/a+numpy.log(1-x**2)/2)**2-numpy.log(1-x**2)/a)-(2/numpy.pi/a+numpy.log(1-x**2)/2)
15-
)
16-
17-
class uniform_unit_cylinder(object):
18-
def __init__(self,targetN, base_grid=None):
19-
cube_cylinder_ratio=numpy.pi*0.5**2
20-
self.targetN=targetN
21-
self.estimatedN=targetN/cube_cylinder_ratio
15+
a = 8 * (numpy.pi - 3) / 3 * numpy.pi * (4 - numpy.pi)
16+
17+
return numpy.sign(x) * numpy.sqrt(
18+
numpy.sqrt(
19+
(2 / numpy.pi / a + numpy.log(1 - x**2) / 2)**2 - numpy.log(1 - x**2) / a
20+
) - (2 / numpy.pi / a + numpy.log(1 - x**2) / 2)
21+
)
22+
23+
class uniform_unit_cylinder:
24+
def __init__(self, targetN, base_grid=None):
25+
cube_cylinder_ratio = numpy.pi*0.5**2
26+
self.targetN = targetN
27+
self.estimatedN = targetN / cube_cylinder_ratio
28+
2229
if base_grid is None:
23-
self.base_grid=uniform_random_unit_cube
30+
self.base_grid = uniform_random_unit_cube
2431
else:
25-
self.base_grid=base_grid
26-
27-
def cutout_cylinder(self,x,y,z):
28-
r=x**2+y**2
29-
selection=r < numpy.ones_like(r)
30-
x=x.compress(selection)
31-
y=y.compress(selection)
32-
z=z.compress(selection)
33-
return x,y,z
32+
self.base_grid = base_grid
33+
34+
def cutout_cylinder(self, x, y, z):
35+
r = x**2 + y**2
36+
selection = r < numpy.ones_like(r)
37+
x = x.compress(selection)
38+
y = y.compress(selection)
39+
z = z.compress(selection)
40+
return x, y, z
3441

3542
def make_xyz(self):
36-
if(self.base_grid==uniform_random_unit_cube):
37-
estimatedN=self.estimatedN
38-
x=[]
43+
if (self.base_grid == uniform_random_unit_cube):
44+
estimatedN = self.estimatedN
45+
x = []
3946
while len(x) < self.targetN:
40-
estimadedN=estimatedN*1.1+1
41-
x,y,z=self.cutout_cylinder(*(self.base_grid(estimatedN)).make_xyz())
42-
return x[0:self.targetN],y[0:self.targetN],z[0:self.targetN]
47+
estimadedN = estimatedN * 1.1 + 1
48+
x, y, z = self.cutout_cylinder(*(self.base_grid(estimatedN)).make_xyz())
49+
50+
return x[0:self.targetN], y[0:self.targetN], z[0:self.targetN]
4351
else:
4452
return self.cutout_cylinder(*(self.base_grid(self.estimatedN)).make_xyz())
4553

4654

4755
class ProtoPlanetaryDisk:
48-
4956
def __init__(
5057
self, targetN, convert_nbody=None, discfraction=0.1,
51-
densitypower=1., thermalpower=0.5, radius_min=1, radius_max=100,
52-
gamma=1.,q_out=2.,base_grid=None, Rmin=None, Rmax=None,
58+
densitypower=1., thermalpower=0.5, radius_min=None, radius_max=None,
59+
gamma=1., q_out=2., base_grid=None, Rmin=None, Rmax=None,
5360
):
54-
if Rmin is not None:
55-
warnings.warn(
56-
"Rmin is deprecated, use radius_min instead",
57-
category=FutureWarning,
58-
)
59-
if radius_min is not None and radius_min != Rmin:
60-
raise ValueError(
61-
"Rmin and radius_min have different values, "
62-
"this is only allowed if one of them is None"
63-
)
64-
radius_min = Rmin
65-
if radius_min is None:
66-
raise ValueError("radius_min must be set")
67-
if Rmax is not None:
68-
warnings.warn(
69-
"Rmax is deprecated, use radius_max instead",
70-
category=FutureWarning,
71-
)
72-
if radius_max is not None and radius_max != Rmax:
73-
raise ValueError(
74-
"Rmax and radius_max have different values, "
75-
"this is only allowed if one of them is None"
76-
)
77-
radius_max = Rmax
78-
if radius_max is None:
79-
raise ValueError("radius_max must be set")
80-
81-
self.targetN=targetN
82-
self.convert_nbody=convert_nbody
83-
self.densitypower=densitypower
84-
self.thermalpower=thermalpower
85-
self.Rmin=radius_min
86-
self.Rmax=radius_max
87-
self.gamma=gamma
88-
self.q_out=q_out
89-
self.discfraction=discfraction
90-
91-
self.a=self.thermalpower
92-
self.a2=self.thermalpower/2
93-
self.g=densitypower
94-
self.g2=2-densitypower
95-
self.k_out=((1+discfraction)/Rmax**3)**0.5
96-
self.sigma_out=self.g2*discfraction/(2*numpy.pi*Rmax**self.g*(Rmax**self.g2-Rmin**self.g2))
97-
self.cs_out=self.q_out*numpy.pi*self.sigma_out/self.k_out
98-
99-
self.base_cylinder=uniform_unit_cylinder(targetN,base_grid)
100-
101-
102-
def sigma(self,r):
103-
return self.sigma_out*(self.Rmax/r)**self.g
104-
105-
def csound(self,r):
106-
return self.cs_out*(self.Rmax/r)**self.a2
107-
108-
def cmass(self,r):
109-
return self.discfraction*(r**self.g2-self.Rmin**self.g2)/(self.Rmax**self.g2-self.Rmin**self.g2)
110-
111-
def mass_encl(self,r):
112-
return 1+self.cmass(r)
113-
114-
def kappa(self,r):
115-
return (self.mass_encl(r)/r**3)**0.5
116-
117-
def toomreQ(self,r):
118-
return self.csound(r)*self.kappa(r)/numpy.pi/self.sigma(r)
119-
120-
def getradius(self,f):
121-
return ((self.Rmax**self.g2-self.Rmin**self.g2)*f+self.Rmin**self.g2)**(1./self.g2)
122-
123-
def zscale(self,r):
124-
return self.csound(r)/self.kappa(r)
125-
126-
def u(self,r):
127-
if self.gamma ==1.:
61+
self.targetN = targetN
62+
self.convert_nbody = convert_nbody
63+
self.densitypower = densitypower
64+
self.thermalpower = thermalpower
65+
self.Rmin = rename_fn_par("radius_min", radius_min, "Rmin", Rmin, 1)
66+
self.Rmax = rename_fn_par("radius_max", radius_max, "Rmax", Rmax, 100)
67+
self.gamma = gamma
68+
self.q_out = q_out
69+
self.discfraction = discfraction
70+
71+
self.a = self.thermalpower
72+
self.a2 = self.thermalpower/2
73+
self.g = densitypower
74+
self.g2 = 2 - densitypower
75+
self.k_out = ((1 + discfraction) / self.Rmax**3)**0.5
76+
self.sigma_out = self.g2 * discfraction / (
77+
2 * numpy.pi * self.Rmax**self.g *
78+
(self.Rmax**self.g2 - self.Rmin**self.g2))
79+
self.cs_out = self.q_out * numpy.pi * self.sigma_out / self.k_out
80+
81+
self.base_cylinder = uniform_unit_cylinder(targetN, base_grid)
82+
83+
@property
84+
def radius_min(self):
85+
return self.Rmin
86+
87+
@property
88+
def radius_max(self):
89+
return self.Rmax
90+
91+
def sigma(self, r):
92+
return self.sigma_out * (self.Rmax / r)**self.g
93+
94+
def csound(self, r):
95+
return self.cs_out * (self.Rmax / r)**self.a2
96+
97+
def cmass(self, r):
98+
return self.discfraction * (r**self.g2 - self.Rmin**self.g2) / (
99+
self.Rmax**self.g2 - self.Rmin**self.g2)
100+
101+
def mass_encl(self, r):
102+
return 1 + self.cmass(r)
103+
104+
def kappa(self, r):
105+
return (self.mass_encl(r) / r**3)**0.5
106+
107+
def toomreQ(self, r):
108+
return self.csound(r) * self.kappa(r) / numpy.pi / self.sigma(r)
109+
110+
def getradius(self, f):
111+
return (
112+
(self.Rmax**self.g2 - self.Rmin**self.g2) * f + self.Rmin**self.g2)**(
113+
1./self.g2)
114+
115+
def zscale(self, r):
116+
return self.csound(r) / self.kappa(r)
117+
118+
def u(self, r):
119+
if self.gamma == 1.:
128120
return self.csound(r)**2
129121
else:
130-
return self.csound(r)**2/(self.gamma-1)
122+
return self.csound(r)**2 / (self.gamma - 1)
123+
124+
def vcirc(self, r):
125+
return (self.mass_encl(r) / r)**0.5
131126

132-
def vcirc(self,r):
133-
return (self.mass_encl(r)/r)**0.5
134-
135127
def new_model(self):
136-
x,y,z=self.base_cylinder.make_xyz()
137-
self.actualN=len(x)
138-
f=x**2+y**2
139-
r=f**0.5
140-
rtarget=self.getradius(f)
141-
142-
mass=self.discfraction*numpy.ones_like(x)/self.actualN
143-
internal_energy=self.u(rtarget)
144-
zscale=self.zscale(rtarget)
145-
r=r.clip(1.e-8,2.)
146-
x=x/r
147-
y=y/r
148-
149-
vx=-y*self.vcirc(rtarget)
150-
vy=x*self.vcirc(rtarget)
151-
vz=numpy.zeros_like(x)
152-
153-
x=rtarget*x
154-
y=rtarget*y
155-
z=approximate_inverse_error_function(z)*zscale*2.**0.5
128+
x, y, z = self.base_cylinder.make_xyz()
129+
self.actualN = len(x)
130+
f = x**2 + y**2
131+
r = f**0.5
132+
rtarget = self.getradius(f)
133+
134+
mass = self.discfraction * numpy.ones_like(x) / self.actualN
135+
internal_energy = self.u(rtarget)
136+
zscale = self.zscale(rtarget)
137+
r = r.clip(1.e-8, 2.)
138+
x= x / r
139+
y= y / r
140+
141+
vx = -y * self.vcirc(rtarget)
142+
vy = x * self.vcirc(rtarget)
143+
vz = numpy.zeros_like(x)
144+
145+
x = rtarget * x
146+
y = rtarget * y
147+
z = approximate_inverse_error_function(z) * zscale * 2.**0.5
156148

157149
return (mass, x, y, z, vx, vy, vz, internal_energy)
158-
150+
159151
@property
160152
def result(self):
161-
masses, x,y,z, vx,vy,vz, internal_energies = self.new_model()
153+
masses, x, y, z, vx, vy, vz, internal_energies = self.new_model()
162154
result = Particles(self.actualN)
163155
result.mass = nbody_system.mass.new_quantity(masses)
164156
result.x = nbody_system.length.new_quantity(x)
@@ -168,10 +160,10 @@ def result(self):
168160
result.vy = nbody_system.speed.new_quantity(vy)
169161
result.vz = nbody_system.speed.new_quantity(vz)
170162
result.u = nbody_system.specific_energy.new_quantity(internal_energies)
171-
163+
172164
if not self.convert_nbody is None:
173-
result = ParticlesWithUnitsConverted(result, self.convert_nbody.as_converter_from_si_to_generic())
165+
result = ParticlesWithUnitsConverted(
166+
result, self.convert_nbody.as_converter_from_si_to_generic())
174167
result = result.copy()
175-
176-
return result
177168

169+
return result

src/amuse/support/helpers.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import warnings
2+
3+
4+
def rename_fn_par(new_name, new_value, old_name, old_value, default_value):
5+
"""Get value of a renamed function parameter.
6+
7+
If you have a function f with parameter ``x`` and default value 1, like this:
8+
9+
.. code-block::
10+
11+
def f(x=1):
12+
...
13+
14+
and you want to rename ``x`` to ``y`` while staying backwards compatible, then you
15+
can rename ``x`` to ``y``, add ``x`` back in at the end so that old keyword
16+
arguments still work, give both a default value of ``None``, and use this function
17+
like this:
18+
19+
.. code-block::
20+
21+
def f(y=None, x=None):
22+
value = rename_fn_par("y", y, "x", x, 1)
23+
24+
Any callers using ``x`` explicitly will receive a warning to change their code to
25+
use ``y`` in the future. If both ``x`` and ``y`` are set and the values are
26+
different, then an exception is raised.
27+
28+
Args:
29+
new_name (str): The new name of the variable
30+
new_value: The value passed using the new name
31+
old_name (str): The old name of the variable
32+
old_value: The value passed using the old name
33+
default_value: The default value if neither are set
34+
35+
Returns:
36+
Either new_value or old_value if only one is set or they're set to the same
37+
value, or default_value if neither is set.
38+
39+
Raises:
40+
ValueError: If both new_value and old_value are set, and to different values.
41+
"""
42+
if new_value is not None:
43+
if old_value is not None and old_value != new_value:
44+
raise ValueError(
45+
f"{old_name} and {new_name} have different values,"
46+
" which is not allowed because they represent the same thing.")
47+
return new_value
48+
49+
if old_value is not None:
50+
warnings.warn(
51+
f"{old_name} is deprecated, please use {new_name} instead",
52+
category=FutureWarning)
53+
return old_value
54+
55+
return default_value
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from amuse.support.helpers import rename_fn_par
2+
3+
import pytest
4+
5+
6+
@pytest.mark.filterwarnings("ignore: old is deprecated.*")
7+
def test_rename_fn_par():
8+
assert rename_fn_par("new", None, "old", None, 1) == 1
9+
assert rename_fn_par("new", None, "old", 1, 2) == 1
10+
assert rename_fn_par("new", 1, "old", None, 2) == 1
11+
assert rename_fn_par("new", 1, "old", 1, 2) == 1
12+
13+
with pytest.raises(ValueError):
14+
rename_fn_par("new", 1, "old", 2, 3)
15+
16+
17+
class _RenamedMethod:
18+
def __init__(self, new = None, a = 1, b = "test", old = None):
19+
self.y = rename_fn_par("new", new, "old", old, 42)
20+
21+
22+
@pytest.mark.filterwarnings("ignore: old is deprecated.*")
23+
def test_rename_fn_par_usage():
24+
rm = _RenamedMethod()
25+
assert rm.y == 42
26+
27+
rm = _RenamedMethod(1)
28+
assert rm.y == 1
29+
30+
rm = _RenamedMethod(new=1)
31+
assert rm.y == 1
32+
33+
rm = _RenamedMethod(old=1)
34+
assert rm.y == 1
35+
36+
rm = _RenamedMethod(1, 1, "test", 1)
37+
assert rm.y == 1
38+
39+
with pytest.raises(ValueError):
40+
_RenamedMethod(1, 2, "test", 3)

0 commit comments

Comments
 (0)