diff --git a/atintegrators/BeamLoadingCavityPass.c b/atintegrators/BeamLoadingCavityPass.c index 0cee79530..3f435647a 100644 --- a/atintegrators/BeamLoadingCavityPass.c +++ b/atintegrators/BeamLoadingCavityPass.c @@ -223,7 +223,14 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, Energy=atGetOptionalDouble(ElemData,"Energy",Param->energy); check_error(); z_cuts=atGetOptionalDoubleArray(ElemData,"ZCuts"); check_error(); feedback_angle_offset=atGetOptionalDouble(ElemData,"feedback_angle_offset", 0.0); check_error(); - + + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + int dimsth[] = {Param->nbunch*nslice*nturns, 4}; atCheckArrayDims(ElemData,"_turnhistory", 2, dimsth); check_error(); int dimsvb[] = {Param->nbunch, 2}; diff --git a/atintegrators/BndMPoleSymplectic4E2RadPass.c b/atintegrators/BndMPoleSymplectic4E2RadPass.c index 2c5ae5ec4..d446a4330 100644 --- a/atintegrators/BndMPoleSymplectic4E2RadPass.c +++ b/atintegrators/BndMPoleSymplectic4E2RadPass.c @@ -232,6 +232,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, RApertures=atGetOptionalDoubleArray(ElemData,"RApertures"); check_error(); KickAngle=atGetOptionalDoubleArray(ElemData,"KickAngle"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/BndMPoleSymplectic4QuantPass.c b/atintegrators/BndMPoleSymplectic4QuantPass.c index b4bb1304f..d94586046 100644 --- a/atintegrators/BndMPoleSymplectic4QuantPass.c +++ b/atintegrators/BndMPoleSymplectic4QuantPass.c @@ -215,6 +215,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, RApertures=atGetOptionalDoubleArray(ElemData,"RApertures"); check_error(); KickAngle=atGetOptionalDoubleArray(ElemData,"KickAngle"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/BndMPoleSymplectic4RadPass.c b/atintegrators/BndMPoleSymplectic4RadPass.c index e4c35e230..7d6042cbb 100644 --- a/atintegrators/BndMPoleSymplectic4RadPass.c +++ b/atintegrators/BndMPoleSymplectic4RadPass.c @@ -168,6 +168,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, RApertures=atGetOptionalDoubleArray(ElemData,"RApertures"); check_error(); KickAngle=atGetOptionalDoubleArray(ElemData,"KickAngle"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/CrabCavityPass.c b/atintegrators/CrabCavityPass.c index f61ee1554..d03686752 100644 --- a/atintegrators/CrabCavityPass.c +++ b/atintegrators/CrabCavityPass.c @@ -101,6 +101,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, SigPhi = atGetOptionalDouble(ElemData,"SigPhi",0.0); check_error(); SigVV = atGetOptionalDouble(ElemData,"SigVV",0.0); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->Vx=Voltages[0]; diff --git a/atintegrators/ExactMultipoleRadPass.c b/atintegrators/ExactMultipoleRadPass.c index 1140ca0df..3beab66f4 100644 --- a/atintegrators/ExactMultipoleRadPass.c +++ b/atintegrators/ExactMultipoleRadPass.c @@ -138,6 +138,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData, struct elem *Elem, atError("NumIntSteps must be positive"); check_error(); } + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem *)atMalloc(sizeof(struct elem)); Elem->Length = Length; Elem->PolynomA = PolynomA; diff --git a/atintegrators/ExactRectBendRadPass.c b/atintegrators/ExactRectBendRadPass.c index 86d080724..723e10047 100644 --- a/atintegrators/ExactRectBendRadPass.c +++ b/atintegrators/ExactRectBendRadPass.c @@ -173,6 +173,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, atError("NumIntSteps == 0 not allowed with radiation"); check_error(); } + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/ExactRectangularBendRadPass.c b/atintegrators/ExactRectangularBendRadPass.c index f1a957c28..dff706fb9 100644 --- a/atintegrators/ExactRectangularBendRadPass.c +++ b/atintegrators/ExactRectangularBendRadPass.c @@ -175,6 +175,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, atError("NumIntSteps must be positive"); check_error(); } + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/ExactSectorBendRadPass.c b/atintegrators/ExactSectorBendRadPass.c index 9c72907bc..e194b9438 100644 --- a/atintegrators/ExactSectorBendRadPass.c +++ b/atintegrators/ExactSectorBendRadPass.c @@ -160,6 +160,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, atError("NumIntSteps == 0 not allowed with radiation"); check_error(); } + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/GWigSymplecticPass.c b/atintegrators/GWigSymplecticPass.c index 58bc8e7f8..dcc40b619 100644 --- a/atintegrators/GWigSymplecticPass.c +++ b/atintegrators/GWigSymplecticPass.c @@ -144,6 +144,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, T1 = atGetOptionalDoubleArray(ElemData, "T1"); check_error(); T2 = atGetOptionalDoubleArray(ElemData, "T2"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Energy=Energy; Elem->Length=Ltot; diff --git a/atintegrators/GWigSymplecticRadPass.c b/atintegrators/GWigSymplecticRadPass.c index 3b508915b..577389640 100644 --- a/atintegrators/GWigSymplecticRadPass.c +++ b/atintegrators/GWigSymplecticRadPass.c @@ -270,6 +270,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, T1 = atGetOptionalDoubleArray(ElemData, "T1"); check_error(); T2 = atGetOptionalDoubleArray(ElemData, "T2"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Energy=Energy; Elem->Length=Ltot; diff --git a/atintegrators/RFCavityPass.c b/atintegrators/RFCavityPass.c index 60bf4c72a..f12e9fe7e 100755 --- a/atintegrators/RFCavityPass.c +++ b/atintegrators/RFCavityPass.c @@ -47,6 +47,14 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, Energy=atGetOptionalDouble(ElemData,"Energy",energy); check_error(); TimeLag=atGetOptionalDouble(ElemData,"TimeLag",0); check_error(); PhaseLag=atGetOptionalDouble(ElemData,"PhaseLag",0); check_error(); + + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->Voltage=Voltage; @@ -56,7 +64,7 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, Elem->TimeLag=TimeLag; Elem->PhaseLag=PhaseLag; } - if (energy == 0.0) energy = Elem->Energy; + energy = atEnergy(energy, Elem->Energy); RFCavityPass(r_in, Elem->Length, Elem->Voltage/energy, Elem->Frequency, Elem->HarmNumber, Elem->TimeLag, Elem->PhaseLag, nturn, T0, num_particles); diff --git a/atintegrators/StrMPoleSymplectic4QuantPass.c b/atintegrators/StrMPoleSymplectic4QuantPass.c index e641f151f..00b1eab6c 100644 --- a/atintegrators/StrMPoleSymplectic4QuantPass.c +++ b/atintegrators/StrMPoleSymplectic4QuantPass.c @@ -184,6 +184,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, RApertures=atGetOptionalDoubleArray(ElemData,"RApertures"); check_error(); KickAngle=atGetOptionalDoubleArray(ElemData,"KickAngle"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atintegrators/StrMPoleSymplectic4RadPass.c b/atintegrators/StrMPoleSymplectic4RadPass.c index b2daf7a78..c292829d4 100644 --- a/atintegrators/StrMPoleSymplectic4RadPass.c +++ b/atintegrators/StrMPoleSymplectic4RadPass.c @@ -144,6 +144,13 @@ ExportMode struct elem *trackFunction(const atElem *ElemData,struct elem *Elem, RApertures=atGetOptionalDoubleArray(ElemData,"RApertures"); check_error(); KickAngle=atGetOptionalDoubleArray(ElemData,"KickAngle"); check_error(); + /* Check energy */ + Energy = atEnergy(Param->energy, Energy); + if (Energy == 0) { + atError("Energy needs to be defined. Check lattice parameters or pass method options.\n"); + check_error(); + } + Elem = (struct elem*)atMalloc(sizeof(struct elem)); Elem->Length=Length; Elem->PolynomA=PolynomA; diff --git a/atmat/atphysics/Radiation/atdiffmat.m b/atmat/atphysics/Radiation/atdiffmat.m index 1124c542b..8546cb39d 100644 --- a/atmat/atphysics/Radiation/atdiffmat.m +++ b/atmat/atphysics/Radiation/atdiffmat.m @@ -56,7 +56,7 @@ end % Calculate 6-by-6 linear transfer matrix in each element % near the equilibrium orbit - m=findelemm66(elem,passm,orbit); + m=findelemm66(elem,passm,orbit,'Energy',energy); % Cumulative diffusion matrix of the entire ring BCUM = m*BCUM*m' + bdiff; btx=BCUM; diff --git a/atmat/attrack/elempass.m b/atmat/attrack/elempass.m index 8d567a71c..8f13b0bff 100644 --- a/atmat/attrack/elempass.m +++ b/atmat/attrack/elempass.m @@ -26,6 +26,28 @@ props.Particle=particle; end +method_req_energy = { 'BeamLoadingCavityPass' , ... + 'BndMPoleSymplectic4E2RadPass', ... + 'BndMPoleSymplectic4QuantPass', ... + 'BndMPoleSymplectic4RadPass', ... + 'CrabCavityPass', ... + 'ExactMultipoleRadPass', ... + 'ExactRectangularBendRadPass', ... + 'ExactRectBendRadPass', ... + 'ExactSectorBendRadPass', ... + 'GWigSymplecticPass', ... + 'GWigSymplecticRadPass', ... + 'RFCavityPass', ... + 'StrMPoleSymplectic4QuantPass', ... + 'StrMPoleSymplectic4RadPass' ... + }; + +if any(strcmp(method_req_energy, methodname)) + if props.Energy <= 0 + error("Energy parameter must be defined."); + end +end + rout = feval(methodname,elem,rin,props); end \ No newline at end of file diff --git a/pyat/test/test_basic_elements.py b/pyat/test/test_basic_elements.py index 5cab9dbbb..30fe460b3 100644 --- a/pyat/test/test_basic_elements.py +++ b/pyat/test/test_basic_elements.py @@ -1,43 +1,51 @@ -import pytest -import numpy import warnings -from at import element_track, lattice_track -from at import lattice_pass, internal_lpass -from at import element_pass, internal_epass + +import numpy as np +import pytest +from at import ( + element_pass, + element_track, + elements, + internal_epass, + internal_lpass, + lattice, + lattice_pass, + lattice_track, +) from at.lattice.elements.conversions import _array, _array66 -from at import elements, lattice -from numpy.testing import assert_equal def test_data_checks(): - val = numpy.zeros([6, 6]) + val = np.zeros([6, 6]) assert _array(val).shape == (36,) assert _array66(val).shape == (6, 6) def test_element_string_ordering(): - d = elements.Drift('D0', 1, attr=numpy.array(0)) - assert d.__str__() == ("Drift:\n FamName: D0\n Length: 1.0\n" - " PassMethod: DriftPass\n attr: 0") + d = elements.Drift("D0", 1, attr=np.array(0)) + assert d.__str__() == ( + "Drift:\n FamName: D0\n Length: 1.0\n" + " PassMethod: DriftPass\n attr: 0" + ) assert d.__repr__() == "Drift('D0', 1.0, attr=array(0))" def test_element_creation_raises_exception(): with pytest.raises(ValueError): - elements.Element('family_name', R1='not_an_array') + elements.Element("family_name", R1="not_an_array") def test_base_element_methods(): - e = elements.Element('family_name') + e = elements.Element("family_name") assert e.divide([0.2, 0.5, 0.3]) == [e] assert id(e.copy()) != id(e) def test_argument_checks(): - q = elements.Quadrupole('quad', 1.0, 0.5) + q = elements.Quadrupole("quad", 1.0, 0.5) # Test type with pytest.raises(ValueError): - q.Length = 'a' + q.Length = "a" # Test shape with pytest.raises(ValueError): q.T1 = [0.0, 0.0] @@ -51,94 +59,93 @@ def test_argument_checks(): def test_dipole(): - d = elements.Dipole('dipole', 1.0, 0.01) + d = elements.Dipole("dipole", 1.0, 0.01) assert d.MaxOrder == 0 assert len(d.PolynomA) == 2 assert d.K == 0.0 - d = elements.Dipole('dipole', 1.0, 0.01, -0.5) + d = elements.Dipole("dipole", 1.0, 0.01, -0.5) assert d.MaxOrder == 1 assert len(d.PolynomA) == 2 assert d.K == -0.5 - d = elements.Dipole('dipole', 1.0, 0.01, PolynomB=[0.0, 0.1, 0.0]) + d = elements.Dipole("dipole", 1.0, 0.01, PolynomB=[0.0, 0.1, 0.0]) assert d.MaxOrder == 1 assert len(d.PolynomA) == 3 assert d.K == 0.1 - d = elements.Dipole('dipole', 1.0, 0.01, PolynomB=[0.0, 0.0, 0.005]) + d = elements.Dipole("dipole", 1.0, 0.01, PolynomB=[0.0, 0.0, 0.005]) assert d.MaxOrder == 2 assert len(d.PolynomA) == 3 assert d.K == 0.0 - d = elements.Dipole('dipole', 1.0, 0.01, PolynomB=[0.0, 0.0, 0.005], - MaxOrder=0) + d = elements.Dipole("dipole", 1.0, 0.01, PolynomB=[0.0, 0.0, 0.005], MaxOrder=0) assert d.MaxOrder == 0 assert len(d.PolynomA) == 3 assert d.K == 0.0 def test_quadrupole(): - q = elements.Quadrupole('quadrupole', 1.0) + q = elements.Quadrupole("quadrupole", 1.0) assert q.MaxOrder == 1 assert len(q.PolynomA) == 2 assert q.K == 0.0 - q = elements.Quadrupole('quadrupole', 1.0, -0.5) + q = elements.Quadrupole("quadrupole", 1.0, -0.5) assert q.MaxOrder == 1 assert len(q.PolynomA) == 2 assert q.K == -0.5 - q = elements.Quadrupole('quadrupole', 1.0, PolynomB=[0.0, 0.0, 0.005]) + q = elements.Quadrupole("quadrupole", 1.0, PolynomB=[0.0, 0.0, 0.005]) assert q.MaxOrder == 2 assert len(q.PolynomA) == 3 assert q.K == 0.0 - q = elements.Quadrupole('quadrupole', 1.0, PolynomB=[0.0, 0.5, 0.005], - MaxOrder=1) + q = elements.Quadrupole("quadrupole", 1.0, PolynomB=[0.0, 0.5, 0.005], MaxOrder=1) assert q.MaxOrder == 1 assert len(q.PolynomA) == 3 assert q.K == 0.5 def test_sextupole(): - s = elements.Sextupole('sextupole', 1.0) + s = elements.Sextupole("sextupole", 1.0) assert s.MaxOrder == 2 assert len(s.PolynomA) == 3 assert s.H == 0.0 - s = elements.Sextupole('sextupole', 1.0, -0.5) + s = elements.Sextupole("sextupole", 1.0, -0.5) assert s.MaxOrder == 2 assert len(s.PolynomA) == 3 assert s.H == -0.5 - s = elements.Sextupole('sextupole', 1.0, PolynomB=[0.0, 0.0, 0.005, 0.0]) + s = elements.Sextupole("sextupole", 1.0, PolynomB=[0.0, 0.0, 0.005, 0.0]) assert s.MaxOrder == 2 assert len(s.PolynomA) == 4 assert s.H == 0.005 - s = elements.Sextupole('sextupole', 1.0, PolynomB=[0.0, 0.0, 0.005, 0.001]) + s = elements.Sextupole("sextupole", 1.0, PolynomB=[0.0, 0.0, 0.005, 0.001]) assert s.MaxOrder == 3 assert len(s.PolynomA) == 4 assert s.H == 0.005 - s = elements.Sextupole('sextupole', 1.0, PolynomB=[0.0, 0.5, 0.005, 0.001], - MaxOrder=2) + s = elements.Sextupole( + "sextupole", 1.0, PolynomB=[0.0, 0.5, 0.005, 0.001], MaxOrder=2 + ) assert s.MaxOrder == 2 assert len(s.PolynomA) == 4 assert s.H == 0.005 def test_octupole(): - o = elements.Octupole('octupole', 1.0, [], [0.0, 0.0, 0.0, 0.0]) + o = elements.Octupole("octupole", 1.0, [], [0.0, 0.0, 0.0, 0.0]) assert o.MaxOrder == 3 assert len(o.PolynomA) == 4 def test_thinmultipole(): - m = elements.ThinMultipole('thin', [], [0.0, 0.0, 0.0, 0.0]) + m = elements.ThinMultipole("thin", [], [0.0, 0.0, 0.0, 0.0]) assert m.MaxOrder == 0 assert len(m.PolynomA) == 4 - m = elements.ThinMultipole('thin', [], [0.0, 0.0, 1.0, 0.0]) + m = elements.ThinMultipole("thin", [], [0.0, 0.0, 1.0, 0.0]) assert m.MaxOrder == 2 assert len(m.PolynomA) == 4 def test_multipole(): - m = elements.Multipole('multi', 1.0, [], [0.0, 0.0, 0.0, 0.0]) + m = elements.Multipole("multi", 1.0, [], [0.0, 0.0, 0.0, 0.0]) assert m.Length == 1.0 assert m.MaxOrder == 0 assert m.NumIntSteps == 10 - assert m.PassMethod == 'StrMPoleSymplectic4Pass' + assert m.PassMethod == "StrMPoleSymplectic4Pass" @pytest.mark.parametrize( @@ -306,19 +313,20 @@ def test_sextupolar_strength_prioritisation(): def test_divide_splits_attributes_correctly(): - pre = elements.Drift('drift', 1) + pre = elements.Drift("drift", 1) post = pre.divide([0.2, 0.5, 0.3]) assert len(post) == 3 assert sum([e.Length for e in post]) == pre.Length - pre = elements.Dipole('dipole', 1, KickAngle=[0.5, -0.5], BendingAngle=0.2) + pre = elements.Dipole("dipole", 1, KickAngle=[0.5, -0.5], BendingAngle=0.2) post = pre.divide([0.2, 0.5, 0.3]) assert len(post) == 3 assert sum([e.Length for e in post]) == pre.Length assert sum([e.KickAngle[0] for e in post]) == pre.KickAngle[0] assert sum([e.KickAngle[1] for e in post]) == pre.KickAngle[1] assert sum([e.BendingAngle for e in post]) == pre.BendingAngle - pre = elements.RFCavity('rfc', 1, voltage=187500, frequency=3.5237e+8, - harmonic_number=31, energy=6.e+9) + pre = elements.RFCavity( + "rfc", 1, voltage=187500, frequency=3.5237e8, harmonic_number=31, energy=6.0e9 + ) post = pre.divide([0.2, 0.5, 0.3]) assert len(post) == 3 assert sum([e.Length for e in post]) == pre.Length @@ -327,46 +335,47 @@ def test_divide_splits_attributes_correctly(): def test_insert_into_drift(): # Create elements - drift = elements.Drift('drift', 1) - monitor = elements.Monitor('bpm') - quad = elements.Quadrupole('quad', 0.3) + drift = elements.Drift("drift", 1) + monitor = elements.Monitor("bpm") + quad = elements.Quadrupole("quad", 0.3) # Test None splitting behaviour - el_list = drift.insert([(0., None), (0.3, None), (0.7, None), (1., None)]) + el_list = drift.insert([(0.0, None), (0.3, None), (0.7, None), (1.0, None)]) assert len(el_list) == 3 - numpy.testing.assert_almost_equal([e.Length for e in el_list], - [0.3, 0.4, 0.3]) + np.testing.assert_almost_equal([e.Length for e in el_list], [0.3, 0.4, 0.3]) # Test normal insertion el_list = drift.insert([(0.3, monitor), (0.7, quad)]) assert len(el_list) == 5 - numpy.testing.assert_almost_equal([e.Length for e in el_list], - [0.3, 0.0, 0.25, 0.3, 0.15]) + np.testing.assert_almost_equal( + [e.Length for e in el_list], [0.3, 0.0, 0.25, 0.3, 0.15] + ) # Test insertion at either end produces -ve length drifts el_list = drift.insert([(0.0, quad), (1.0, quad)]) assert len(el_list) == 5 - numpy.testing.assert_almost_equal([e.Length for e in el_list], - [-0.15, 0.3, 0.7, 0.3, -0.15]) + np.testing.assert_almost_equal( + [e.Length for e in el_list], [-0.15, 0.3, 0.7, 0.3, -0.15] + ) -@pytest.mark.parametrize('func', (lattice_track, lattice_pass, internal_lpass)) +@pytest.mark.parametrize("func", (lattice_track, lattice_pass, internal_lpass)) def test_correct_dimensions_does_not_raise_error(rin, func): func([], rin, 1) - rin = numpy.zeros((6,)) + rin = np.zeros((6,)) func([], rin, 1) - rin = numpy.array(numpy.zeros((6, 2), order='F')) + rin = np.array(np.zeros((6, 2), order="F")) func([], rin, 1) @pytest.mark.parametrize("dipole_class", (elements.Dipole, elements.Bend)) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_dipole_bend_synonym(rin, dipole_class, func): - b = dipole_class('dipole', 1.0, 0.1, EntranceAngle=0.05, ExitAngle=0.05) + b = dipole_class("dipole", 1.0, 0.1, EntranceAngle=0.05, ExitAngle=0.05) rin[0, 0] = 1e-6 if func == element_track: func(b, rin, in_place=True) else: func(b, rin) - rin_expected = numpy.array([1e-6, 0, 0, 0, 0, 1e-7]).reshape((6, 1)) - numpy.testing.assert_almost_equal(rin, rin_expected) + rin_expected = np.array([1e-6, 0, 0, 0, 0, 1e-7]).reshape((6, 1)) + np.testing.assert_almost_equal(rin, rin_expected) assert b.K == 0.0 b.PolynomB[1] = 0.2 assert b.K == 0.2 @@ -374,35 +383,35 @@ def test_dipole_bend_synonym(rin, dipole_class, func): assert b.PolynomB[1] == 0.1 -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_marker(rin, func): - m = elements.Marker('marker') + m = elements.Marker("marker") assert m.Length == 0 - rin = numpy.array(numpy.random.rand(*rin.shape), order='F') - rin_orig = numpy.array(rin, copy=True, order='F') + rin = np.array(np.random.rand(*rin.shape), order="F") + rin_orig = np.array(rin, copy=True, order="F") if func == element_track: func(m, rin, in_place=True) else: func(m, rin) - numpy.testing.assert_equal(rin, rin_orig) + np.testing.assert_equal(rin, rin_orig) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_monitor(rin, func): - mon = elements.Monitor('monitor') + mon = elements.Monitor("monitor") assert mon.Length == 0 - rin = numpy.array(numpy.random.rand(*rin.shape), order='F') + rin = np.array(np.random.rand(*rin.shape), order="F") rin_orig = rin.copy() if func == element_track: func(mon, rin, in_place=True) else: func(mon, rin) - numpy.testing.assert_equal(rin, rin_orig) + np.testing.assert_equal(rin, rin_orig) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_aperture_inside_limits(rin, func): - a = elements.Aperture('aperture', [-1e-3, 1e-3, -1e-4, 1e-4]) + a = elements.Aperture("aperture", [-1e-3, 1e-3, -1e-4, 1e-4]) assert a.Length == 0 rin[0, 0] = 1e-5 rin[2, 0] = -1e-5 @@ -411,12 +420,12 @@ def test_aperture_inside_limits(rin, func): func(a, rin, in_place=True) else: func(a, rin) - numpy.testing.assert_equal(rin, rin_orig) + np.testing.assert_equal(rin, rin_orig) -@pytest.mark.parametrize('func', (lattice_track, lattice_pass, internal_lpass)) +@pytest.mark.parametrize("func", (lattice_track, lattice_pass, internal_lpass)) def test_aperture_outside_limits(rin, func): - a = elements.Aperture('aperture', [-1e-3, 1e-3, -1e-4, 1e-4]) + a = elements.Aperture("aperture", [-1e-3, 1e-3, -1e-4, 1e-4]) assert a.Length == 0 lattice = [a] rin[0, 0] = 1e-2 @@ -425,13 +434,13 @@ def test_aperture_outside_limits(rin, func): func(lattice, rin, in_place=True) else: func(lattice, rin) - assert numpy.isnan(rin[0, 0]) + assert np.isnan(rin[0, 0]) assert rin[2, 0] == 0.0 # Only the 1st coordinate is nan, the rest is zero -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_drift_offset(rin, func): - d = elements.Drift('drift', 1) + d = elements.Drift("drift", 1) rin[0, 0] = 1e-6 rin[2, 0] = 2e-6 rin_orig = rin.copy() @@ -439,12 +448,12 @@ def test_drift_offset(rin, func): func(d, rin, in_place=True) else: func(d, rin) - numpy.testing.assert_equal(rin, rin_orig) + np.testing.assert_equal(rin, rin_orig) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_drift_divergence(rin, func): - d = elements.Drift('drift', 1.0) + d = elements.Drift("drift", 1.0) assert d.Length == 1 rin[1, 0] = 1e-6 rin[3, 0] = -2e-6 @@ -453,16 +462,15 @@ def test_drift_divergence(rin, func): else: func(d, rin) # results from Matlab - rin_expected = numpy.array([1e-6, 1e-6, -2e-6, -2e-6, 0, - 2.5e-12]).reshape(6, 1) - numpy.testing.assert_equal(rin, rin_expected) + rin_expected = np.array([1e-6, 1e-6, -2e-6, -2e-6, 0, 2.5e-12]).reshape(6, 1) + np.testing.assert_equal(rin, rin_expected) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_drift_two_particles(rin, func): - d = elements.Drift('drift', 1.0) + d = elements.Drift("drift", 1.0) assert d.Length == 1 - two_rin = numpy.array(numpy.concatenate((rin, rin), axis=1), order='F') + two_rin = np.array(np.concatenate((rin, rin), axis=1), order="F") # particle one is offset two_rin[0, 0] = 1e-6 two_rin[2, 0] = 2e-6 @@ -475,24 +483,27 @@ def test_drift_two_particles(rin, func): else: func(d, two_rin) # results from Matlab - p1_expected = numpy.array(two_rin_orig[:, 0]).reshape(6, 1) - p2_expected = numpy.array([1e-6, 1e-6, -2e-6, -2e-6, 0, - 2.5e-12]).reshape(6, 1) - two_rin_expected = numpy.concatenate((p1_expected, p2_expected), axis=1) - numpy.testing.assert_equal(two_rin, two_rin_expected) + p1_expected = np.array(two_rin_orig[:, 0]).reshape(6, 1) + p2_expected = np.array([1e-6, 1e-6, -2e-6, -2e-6, 0, 2.5e-12]).reshape(6, 1) + two_rin_expected = np.concatenate((p1_expected, p2_expected), axis=1) + np.testing.assert_equal(two_rin, two_rin_expected) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_quad(rin, func): - q = elements.Quadrupole('quad', 0.4, k=1) + q = elements.Quadrupole("quad", 0.4, k=1) rin[0, 0] = 1e-6 if func == element_track: func(q, rin, in_place=True) else: func(q, rin) - expected = numpy.array([0.9210610203854122, -0.3894182419439, 0, - 0, 0, 0.0000000103303797478]).reshape(6, 1) * 1e-6 - numpy.testing.assert_allclose(rin, expected) + expected = ( + np.array( + [0.9210610203854122, -0.3894182419439, 0, 0, 0, 0.0000000103303797478] + ).reshape(6, 1) + * 1e-6 + ) + np.testing.assert_allclose(rin, expected) assert q.K == 1 q.PolynomB[1] = 0.2 assert q.K == 0.2 @@ -500,40 +511,48 @@ def test_quad(rin, func): assert q.PolynomB[1] == 0.1 -@pytest.mark.parametrize('func', (lattice_track, lattice_pass, internal_lpass)) -def test_rfcavity(rin, func): - rf = elements.RFCavity('rfcavity', 0.0, 187500, 3.5237e+8, 31, 6.e+9) - lattice = [rf, rf, rf, rf] +@pytest.mark.parametrize("func", [element_track, element_pass, internal_epass]) +def test_rfcavity(rin: np.ndarray, func: any) -> None: + """Test rf cavity tracking. + + Arguments: + func: element track method + rin: np.array of dims (6,1) + """ + rfc = elements.RFCavity("rfcavity", 0.0, 187500, 3.5237e8, 31, 6.0e9) rin[4, 0] = 1e-6 rin[5, 0] = 1e-6 - if func == lattice_track: - func(lattice, rin, in_place=True) + if func is element_track: + func(rfc, rin, in_place=True, energy=+6.0e9) else: - func(lattice, rin) - expected = numpy.array([0., 0., 0., 0., 9.990769e-7, 1.e-6]).reshape(6, 1) - numpy.testing.assert_allclose(rin, expected, atol=1e-12) + func(rfc, rin, energy=+6.0e9) + + result = np.array([0.0, 0.0, 0.0, 0.0, 9.99769215e-07, 1.0e-6]).reshape(6, 1) + + np.testing.assert_allclose(rin, result, atol=1e-12) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) @pytest.mark.parametrize("n", (0, 1, 2, 3, 4, 5)) def test_m66(rin, n, func): - m = numpy.random.rand(6, 6) - m66 = elements.M66('m66', m) + m = np.random.rand(6, 6) + m66 = elements.M66("m66", m) assert m66.Length == 0 rin[n, 0] = 1e-6 if func == element_track: func(m66, rin, in_place=True) else: func(m66, rin) - expected = numpy.array([m[0, n], m[1, n], m[2, n], m[3, n], m[4, n], - m[5, n]]).reshape(6, 1) * 1e-6 - numpy.testing.assert_equal(rin, expected) + expected = ( + np.array([m[0, n], m[1, n], m[2, n], m[3, n], m[4, n], m[5, n]]).reshape(6, 1) + * 1e-6 + ) + np.testing.assert_equal(rin, expected) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_corrector(rin, func): - c = elements.Corrector('corrector', 0.0, numpy.array([0.9, 0.5], - dtype=numpy.float64)) + c = elements.Corrector("corrector", 0.0, np.array([0.9, 0.5], dtype=np.float64)) assert c.Length == 0 rin[0, 0] = 1e-6 rin_orig = rin.copy() @@ -543,29 +562,31 @@ def test_corrector(rin, func): func(c, rin, in_place=True) else: func(c, rin) - numpy.testing.assert_equal(rin, rin_orig) + np.testing.assert_equal(rin, rin_orig) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_wiggler(rin, func): period = 0.05 periods = 23 bmax = 1 - by = numpy.array([1, 1, 0, 1, 1, 0], dtype=numpy.float64) - c = elements.Wiggler('wiggler', period * periods, period, bmax, By=by) + by = np.array([1, 1, 0, 1, 1, 0], dtype=np.float64) + c = elements.Wiggler("wiggler", period * periods, period, bmax, By=by) assert abs(c.Length - 1.15) < 1e-10 # Expected value from Matlab AT. - expected = numpy.array(rin, copy=True) + expected = np.array(rin, copy=True) expected[5] = 0.000000181809691064259 if func == element_track: func(c, rin, energy=3e9, in_place=True) else: func(c, rin, energy=3e9) - numpy.testing.assert_allclose(rin, expected, atol=1e-12) + np.testing.assert_allclose(rin, expected, atol=1e-12) def test_exit_entrance(): - q = elements.Quadrupole('quad', 0.4, k=1) + q = elements.Quadrupole("quad", 0.4, k=1) for kin, kout in zip(q._entrance_fields, q._exit_fields): - assert_equal(kin.replace('Entrance', ''). replace('1', ''), - kout.replace('Exit', '').replace('2', '')) + np.testing.assert_equal( + kin.replace("Entrance", "").replace("1", ""), + kout.replace("Exit", "").replace("2", ""), + ) diff --git a/pyat/test/test_integrators.py b/pyat/test/test_integrators.py index 2a896cd99..c754c2189 100644 --- a/pyat/test/test_integrators.py +++ b/pyat/test/test_integrators.py @@ -10,57 +10,54 @@ from at import element_pass, internal_epass -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_exact_hamiltonian_pass(rin, func): - drift = elements.Multipole('m1', 1, [0, 0, 0, 0], [0, 0, 0, 0]) + drift = elements.Multipole("m1", 1, [0, 0, 0, 0], [0, 0, 0, 0]) drift.Type = 0 - drift.PassMethod = 'ExactHamiltonianPass' + drift.PassMethod = "ExactHamiltonianPass" drift.BendingAngle = 0 func(drift, rin) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_exact_hamiltonian_pass_with_dls_dipole(rin, func): - bend = elements.Multipole('rb', 0.15, [0, 0, 0, 0], - [-0.0116333, 3.786786, 0, 0]) + bend = elements.Multipole("rb", 0.15, [0, 0, 0, 0], [-0.0116333, 3.786786, 0, 0]) bend.Type = 1 - bend.PassMethod = 'ExactHamiltonianPass' + bend.PassMethod = "ExactHamiltonianPass" bend.BendingAngle = -0.001745 bend.Energy = 3.5e9 bend.MaxOrder = 3 - if func==element_track: + if func == element_track: func(bend, rin, in_place=True) else: func(bend, rin) # Results from Matlab - expected = numpy.array([9.23965e-9, 1.22319e-5, 0, - 0, 0, -4.8100e-10]).reshape(6, 1) + expected = numpy.array([9.23965e-9, 1.22319e-5, 0, 0, 0, -4.8100e-10]).reshape(6, 1) numpy.testing.assert_allclose(rin, expected, rtol=1e-5, atol=1e-6) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) -@pytest.mark.parametrize('passmethod', - ('GWigSymplecticPass', 'GWigSymplecticRadPass')) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("passmethod", ("GWigSymplecticPass", "GWigSymplecticRadPass")) def test_gwig_symplectic_pass(rin, passmethod, func): # Parameters copied from one of the Diamond wigglers. - wiggler = elements.Wiggler('w', 1.15, 0.05, 0.8) + wiggler = elements.Wiggler("w", 1.15, 0.05, 0.8) wiggler.PassMethod = passmethod - func(wiggler, rin) + func(wiggler, rin, energy=1e9) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_bndstrmpole_symplectic_4_pass(rin, func): - bend = elements.Dipole('b', 1.0) - bend.PassMethod = 'BndStrMPoleSymplectic4Pass' + bend = elements.Dipole("b", 1.0) + bend.PassMethod = "BndStrMPoleSymplectic4Pass" func(bend, rin) -@pytest.mark.parametrize('func', (element_track, element_pass, internal_epass)) +@pytest.mark.parametrize("func", (element_track, element_pass, internal_epass)) def test_pydrift(func): - pydrift = elements.Drift('drift', 1.0, PassMethod='pyDriftPass') - cdrift = elements.Drift('drift', 1.0, PassMethod='DriftPass') - pyout, *_ = func(pydrift, numpy.zeros(6)+1.0e-6) - cout, *_ = func(cdrift, numpy.zeros(6)+1.0e-6) + pydrift = elements.Drift("drift", 1.0, PassMethod="pyDriftPass") + cdrift = elements.Drift("drift", 1.0, PassMethod="DriftPass") + pyout, *_ = func(pydrift, numpy.zeros(6) + 1.0e-6) + cout, *_ = func(cdrift, numpy.zeros(6) + 1.0e-6) numpy.testing.assert_equal(pyout, cout) shift_elem(pydrift, 1.0e-3, 1.0e-3) @@ -71,23 +68,24 @@ def test_pydrift(func): tilt_elem(pydrift, 1.0e-3, 1.0e-3) tilt_elem(cdrift, 1.0e-3, 1.0e-3) - pyout, *_ = func(pydrift, numpy.zeros(6)+1.0e-6) - cout, *_ = func(cdrift, numpy.zeros(6)+1.0e-6) + pyout, *_ = func(pydrift, numpy.zeros(6) + 1.0e-6) + cout, *_ = func(cdrift, numpy.zeros(6) + 1.0e-6) numpy.testing.assert_equal(pyout, cout) # Multiple particles - pyout, *_ = func(pydrift, numpy.zeros((6, 2))+1.0e-6) - cout, *_ = func(cdrift, numpy.zeros((6, 2))+1.0e-6) + pyout, *_ = func(pydrift, numpy.zeros((6, 2)) + 1.0e-6) + cout, *_ = func(cdrift, numpy.zeros((6, 2)) + 1.0e-6) numpy.testing.assert_equal(pyout, cout) -@pytest.mark.parametrize('func', (lattice_track, lattice_pass, internal_lpass)) +@pytest.mark.parametrize("func", (lattice_track, lattice_pass, internal_lpass)) def test_pyintegrator(hmba_lattice, func): - params = {'Length': 0, - 'PassMethod': 'pyIdentityPass', - } - id_elem = Element('py_id', **params) - pin = numpy.zeros((6, 2))+1.0e-6 + params = { + "Length": 0, + "PassMethod": "pyIdentityPass", + } + id_elem = Element("py_id", **params) + pin = numpy.zeros((6, 2)) + 1.0e-6 pout1, *_ = func(hmba_lattice, pin.copy(), nturns=1) hmba_lattice = hmba_lattice + [id_elem] pout2, *_ = func(hmba_lattice, pin.copy(), nturns=1)