# Tests for the STIR buildblock Python bindings, for use with pytest.
# Run them from the command line with
#
#     pytest test_buildblock.py


#    Copyright (C) 2013, 2023, 2025 University College London
#    This file is part of STIR.
#
#    SPDX-License-Identifier: Apache-2.0
#
#    See STIR/LICENSE.txt for details

try:
    import pytest
except ImportError:
    # No pytest, try older py.test
    try:
        import py.test as pytest
    except ImportError:
        raise ImportError('Tests require pytest or py<1.4')

from stir import *
import stir
import stirextra
import math

def test_Vector():
    """FloatVector element access, and Python reference vs. copy semantics."""
    vec = FloatVector(3)
    vec[2] = 1
    # elements 0 and 1 are probably not initialised, so only the one we set is checked
    assert vec[2] == 1

    # Binding the vector to a second name does not copy it: both names refer to
    # the same object. This matches Python lists but differs from C++ value semantics.
    alias = vec
    alias[2] = 4
    assert vec[2] == alias[2]

    # An independent object has to be created explicitly via the copy constructor.
    alias = FloatVector(vec)
    alias[2] = vec[2] + 2
    assert vec[2] + 2 == alias[2]

def test_Coordinate():
    """Coordinates are indexed from 1; check tuple and element-wise constructors."""
    # construction from a tuple, reading via __getitem__ and via []
    coord = Float3BasicCoordinate((1, 2, 3))
    assert coord.__getitem__(1) == 1
    assert coord[2] == 2

    # construction from separate elements
    coord = Float3Coordinate(1, 2, 3)
    assert coord[2] == 2

    # tuple constructor again, this time checking the last element
    coord = Float3BasicCoordinate((1, 2, 3))
    assert coord[3] == 3

def test_VectorWithOffset():
    """Element access on a FloatVectorWithOffset respects its index range."""
    vec = FloatVectorWithOffset(1, 3)
    vec[2] = 3
    assert vec[2] == 3
    # index 0 is outside the range the vector was constructed with
    with pytest.raises(IndexError):
        vec[0]

def test_Array1D():
    """FloatArray1D element access, and independence of a copy-constructed array."""
    original = FloatArray1D(1, 5)
    original[2] = 3
    assert original[2] == 3

    # the copy constructor yields a separate array: writing to it leaves the source alone
    duplicate = FloatArray1D(original)
    duplicate[2] = 4
    assert original[2] == 3
    assert duplicate[2] == 4

def test_Array1Diterator():
    """Iterating over a FloatArray1D visits every element.

    The total obtained through the Python iterator must equal the array's own
    ``sum()`` method.
    """
    a = FloatArray1D(1, 3)
    a[2] = 3
    a[1] = 2
    # The builtin sum() drives the array's Python iterator. This avoids binding a
    # local called ``sum``, which would shadow the builtin.
    assert sum(a) == a.sum()

def test_Array2D():
    """Element access on a FloatArray2D via an Int2BasicCoordinate and via a tuple."""
    lower = Int2BasicCoordinate((3, 3))
    upper = Int2BasicCoordinate((9, 9))
    arr = FloatArray2D(IndexRange2D(lower, upper))
    arr.fill(2)

    idx = Int2BasicCoordinate((4, 5))
    assert arr[idx] == 2
    arr[idx] = 4
    assert arr[idx] == 4
    # a plain tuple is accepted as an index as well
    assert arr[(4, 5)] == 4
    # NOTE: nested indexing (arr[4][5]) and assigning a whole row are not tested here.

def test_Array2D_numerics():
    """Compare FloatArray2D arithmetic operators with Python arithmetic at one element."""
    a = FloatArray2D(IndexRange2D(Int2BasicCoordinate((3,3)), Int2BasicCoordinate((9,9))))
    b = FloatArray2D(a.get_index_range())
    a.fill(2)
    b.fill(3)
    ind = (4, 5)
    a[ind] = 5

    def check(result, expected):
        # STIR computes in float, Python in double, hence the relative tolerance
        assert math.isclose(result[ind], expected, rel_tol=1e-4)

    # right-hand operands (an array and a scalar), each paired with its value at ind
    operands = ((b, b[ind]), (3, 3))

    # operators that return a new object
    for rhs, rhs_val in operands:
        check(a + rhs, a[ind] + rhs_val)
        check(a - rhs, a[ind] - rhs_val)
        check(a * rhs, a[ind] * rhs_val)
        check(a / rhs, a[ind] / rhs_val)

    # augmented assignment, each applied to a fresh copy of a
    # (adding 0 is a simple way to obtain that copy)
    for rhs, rhs_val in operands:
        c = a + 0
        c += rhs
        check(c, a[ind] + rhs_val)
        c = a + 0
        c -= rhs
        check(c, a[ind] - rhs_val)
        c = a + 0
        c *= rhs
        check(c, a[ind] * rhs_val)
        c = a + 0
        c /= rhs
        check(c, a[ind] / rhs_val)
    
def test_Array2Diterator():
    """The flat iterator of a FloatArray2D visits every element exactly once."""
    index_range = IndexRange2D(Int2BasicCoordinate((3,3)), Int2BasicCoordinate((9,9)))
    arr = FloatArray2D(index_range)
    arr.fill(2)
    expected_total = arr.sum()
    assert expected_total == sum(arr.flat())

def test_Array3D():
    """FloatArray3D: index range, shape, element access, as_array() and fill()."""
    minind = Int3BasicCoordinate(3)
    maxind = Int3BasicCoordinate(9)
    maxind[3] = 11
    indrange = IndexRange3D(minind, maxind)
    a3 = FloatArray3D(indrange)

    # get_regular_range writes the bounds into the two coordinates passed in
    lo = Int3BasicCoordinate(1)
    hi = Int3BasicCoordinate(1)
    a3.get_regular_range(lo, hi)
    assert lo == minind
    assert hi == maxind
    # indices 3..9 give 7 elements, 3..11 give 9
    assert a3.shape() == (7, 7, 9)

    a3.fill(2)
    ind = Int3BasicCoordinate((4, 5, 6))
    assert a3[ind] == 2
    a3[ind] = 9
    assert a3[(4, 5, 6)] == 9
    assert a3.find_max() == 9

    # as_array() has the same shape but is indexed from 0
    np_arr = a3.as_array()
    assert np_arr.shape == a3.shape()
    offset = tuple(ind[dim] - minind[dim] for dim in (1, 2, 3))
    assert np_arr[offset] == a3[ind]
    assert np_arr[0, 0, 0] == 2

    other = FloatArray3D(indrange)
    # fill from a flat iterator over the numpy array
    other.fill(np_arr.flat)
    assert other[ind] == 9
    assert other[minind] == 2
    other.fill(1)
    assert other[ind] == 1
    # fill directly from the numpy array
    other.fill(np_arr)
    assert other[ind] == 9
    assert other[minind] == 2
    

def test_FloatVoxelsOnCartesianGrid():
    """FloatVoxelsOnCartesianGrid: geometry, element access, construction from an
    array, as_array() and fill()."""
    origin = FloatCartesianCoordinate3D(0, 1, 6)
    gridspacing = FloatCartesianCoordinate3D(1, 1, 2)
    minind = Int3BasicCoordinate(3)
    maxind = Int3BasicCoordinate(9)
    indrange = IndexRange3D(minind, maxind)
    image = FloatVoxelsOnCartesianGrid(indrange, origin, gridspacing)
    assert image.get_origin() == origin

    image.fill(2)
    ind = Int3BasicCoordinate((4, 4, 4))
    assert image[ind] == 2
    image[ind] = 4
    assert image[ind] == 4
    assert image[(5, 3, 4)] == 2

    # an image constructed from an array must not follow later changes to that array
    source = FloatArray3D(indrange)
    source.fill(1.4)
    from_array = FloatVoxelsOnCartesianGrid(source, origin, gridspacing)
    assert abs(from_array[ind] - 1.4) < .001
    source.fill(2)
    assert abs(from_array[ind] - 1.4) < .001

    # as_array() has the same shape but is indexed from 0
    np_arr = image.as_array()
    assert np_arr.shape == image.shape()
    offset = tuple(ind[dim] - minind[dim] for dim in (1, 2, 3))
    assert np_arr[offset] == image[ind]
    assert np_arr[0, 0, 0] == 2

    # fill from a flat iterator over the numpy array
    image.fill(0)
    image.fill(np_arr.flat)
    assert image[ind] == 4
    assert image[minind] == 2
    image.fill(1)
    assert image[ind] == 1
    # fill directly from the numpy array
    image.fill(np_arr)
    assert image[ind] == 4
    assert image[minind] == 2

def test_FloatVoxelsOnCartesianGrid_numerics():
    """Compare image arithmetic operators with Python arithmetic at one voxel."""
    origin = FloatCartesianCoordinate3D(0, 1, 6)
    gridspacing = FloatCartesianCoordinate3D(1, 1, 2)
    minind = Int3BasicCoordinate(3)
    maxind = Int3BasicCoordinate(9)
    indrange = IndexRange3D(minind, maxind)
    a = FloatVoxelsOnCartesianGrid(indrange, origin, gridspacing)
    b = FloatVoxelsOnCartesianGrid(indrange, origin, gridspacing)
    a.fill(2)
    b.fill(3)
    ind = (3, 4, 5)
    a[ind] = 5

    def check(result, expected):
        # STIR computes in float, Python in double, hence the relative tolerance
        assert math.isclose(result[ind], expected, rel_tol=1e-4)

    # right-hand operands (an image and a scalar), each paired with its value at ind
    operands = ((b, b[ind]), (3, 3))

    # operators that return a new object
    for rhs, rhs_val in operands:
        check(a + rhs, a[ind] + rhs_val)
        check(a - rhs, a[ind] - rhs_val)
        check(a * rhs, a[ind] * rhs_val)
        check(a / rhs, a[ind] / rhs_val)

    # augmented assignment, each applied to a fresh copy of a
    # (adding 0 is a simple way to obtain that copy)
    for rhs, rhs_val in operands:
        c = a + 0
        c += rhs
        check(c, a[ind] + rhs_val)
        c = a + 0
        c -= rhs
        check(c, a[ind] - rhs_val)
        c = a + 0
        c *= rhs
        check(c, a[ind] * rhs_val)
        c = a + 0
        c /= rhs
        check(c, a[ind] / rhs_val)

def test_zoom_image():
    """Check zoom_image() scaling for each ZoomOptions value on a uniform image.

    An image filled with ones is zoomed by a factor of 2. The value at the voxel
    closest to the image centre is then compared with the scaling expected for
    the default option and for each explicit ZoomOptions value.
    """
    # create test image
    origin=FloatCartesianCoordinate3D(3,1,6)
    gridspacing=FloatCartesianCoordinate3D(1,1,2)
    minind=Int3BasicCoordinate((0,-9,-9))
    maxind=Int3BasicCoordinate(9)
    indrange=IndexRange3D(minind,maxind)
    image=FloatVoxelsOnCartesianGrid(indrange, origin,gridspacing)
    image.fill(1)
    # find coordinate of middle of image for later use (independent of image sizes etc)
    [min_in_mm, max_in_mm]=stirextra.get_physical_coordinates_for_bounding_box(image)
    try:
        middle_in_mm=FloatCartesianCoordinate3D((min_in_mm+max_in_mm)/2.)
    except Exception:
        # SWIG versions pre 3.0.11 had a bug, which we try to work around here
        middle_in_mm=FloatCartesianCoordinate3D((min_in_mm+max_in_mm).__div__(2))

    # ZoomOptions must reject an out-of-range value.
    # pytest.raises is required here: with a try/except around "assert False",
    # a bare except would also catch the AssertionError, so the check could never fail.
    with pytest.raises(Exception):
        ZoomOptions(42)

    zoom=2
    offset=1
    new_size=6
    # default options: same expectation as preserve_sum below
    zoomed_image=zoom_image(image, zoom, offset, offset, new_size)
    ind=zoomed_image.get_indices_closest_to_physical_coordinates(middle_in_mm)
    assert abs(zoomed_image[ind]-1./(zoom*zoom))<.001
    # awkward syntax...
    zoomed_image=zoom_image(image, zoom, offset, offset, new_size, ZoomOptions(ZoomOptions.preserve_sum))
    assert abs(zoomed_image[ind]-1./(zoom*zoom))<.001
    zoomed_image=zoom_image(image, zoom, offset, offset, new_size, ZoomOptions(ZoomOptions.preserve_values))
    assert abs(zoomed_image[ind]-1)<.001
    zoomed_image=zoom_image(image, zoom, offset, offset, new_size, ZoomOptions(ZoomOptions.preserve_projections))
    assert abs(zoomed_image[ind]-1./(zoom))<.001

def test_DetectionPositionPair():
    """A DetectionPositionPair exposes its two positions and timing position,
    and modifying its members leaves the original positions untouched."""
    first = DetectionPosition(1, 2, 0)
    second = DetectionPosition(4, 5, 6)
    pair = DetectionPositionPair(first, second, 3)
    assert first == pair.pos1
    assert second == pair.pos2
    assert pair.timing_pos == 3

    # changing the pair's member must not propagate back to `first`
    pair.pos1.tangential_coord = 7
    assert pair.pos1.tangential_coord == 7
    assert first.tangential_coord == 1

def test_Scanner():
    """Look up scanners by name, then round-trip a detection position through
    its Cartesian coordinate."""
    scanner = Scanner.get_scanner_from_name("ECAT 962")
    assert scanner.get_num_rings() == 32
    assert scanner.get_num_detectors_per_ring() == 576
    # NOTE: iterating over Scanner.get_all_names() was noted as not working,
    # so it is not tested here.

    scanner = Scanner.get_scanner_from_name("SAFIRDualRingPrototype")
    scanner.set_scanner_geometry("BlocksOnCylindrical")
    scanner.set_up()
    det_pos = DetectionPosition(1, 1, 0)
    coord = scanner.get_coordinate_for_det_pos(det_pos)
    # the lookup writes its result into `found`
    found = DetectionPosition()
    status = scanner.find_detection_position_given_cartesian_coordinate(found, coord)
    assert status.succeeded()
    assert det_pos == found

def test_Radionuclide():
    """Look up radionuclides in the database and check their half-lives (seconds)."""
    pt_modality = ImagingModality(ImagingModality.PT)
    database = RadionuclideDB()
    fluorine = database.get_radionuclide(pt_modality, "^18^Fluorine")
    assert abs(fluorine.get_half_life() - 6584) < 1

    nm_modality = ImagingModality(ImagingModality.NM)
    technetium = database.get_radionuclide(nm_modality, "^99m^Technetium")
    # 6.0058 hours expressed in seconds
    assert abs(technetium.get_half_life() - 6.0058*3600) < 10

def test_Bin():
    """Construct Bins and read/write their coordinate, value and frame attributes."""
    segment_num = 1
    view_num = 2
    axial_pos_num = 3
    tangential_pos_num = 4
    coords = (segment_num, view_num, axial_pos_num, tangential_pos_num)

    bin_obj = Bin(*coords)
    # without an explicit value the bin value is 0
    assert bin_obj.bin_value == 0
    assert bin_obj.segment_num == segment_num
    assert bin_obj.view_num == view_num
    assert bin_obj.axial_pos_num == axial_pos_num
    assert bin_obj.tangential_pos_num == tangential_pos_num
    bin_obj.segment_num = 5
    assert bin_obj.segment_num == 5

    # bin_value is stored as float, so compare with a tolerance
    bin_value = 0.3
    bin_obj.bin_value = bin_value
    assert abs(bin_obj.bin_value - bin_value) < .01
    # the value can also be passed to the constructor
    bin_obj = Bin(*coords, bin_value)
    assert abs(bin_obj.bin_value - bin_value) < .01

    bin_obj.time_frame_num = 3
    assert bin_obj.time_frame_num == 3
    
def test_ProjDataInfo():
    """Construct a ProjDataInfo and extract an empty sinogram from it."""
    scanner = Scanner.get_scanner_from_name("ECAT 962")
    # arguments after the scanner: span, max_delta, num_views, num_tangential_poss
    info = ProjDataInfo.construct_proj_data_info(scanner, 3, 9, 8, 6)
    assert info.get_scanner().get_num_rings() == 32

    # arc-correction specific setting; changing it by hand is risky in general,
    # but acceptable for this test
    info.set_tangential_sampling(5)
    assert info.get_tangential_sampling() == 5

    # empty sinogram for axial position 1 of segment 2
    sino = info.get_empty_sinogram(1, 2)
    assert sino.sum() == 0
    assert sino.get_segment_num() == 2
    assert sino.get_axial_pos_num() == 1
    assert sino.get_num_views() == info.get_num_views()
    print(sino.get_proj_data_info())
    # TODO: comparing the ProjDataInfo objects directly currently fails with a
    # TypeError, so their parameter_info() is compared instead
    assert sino.get_proj_data_info().parameter_info() == info.parameter_info()

def test_ProjDataInMemory_numerics():
    """Compare ProjDataInMemory arithmetic operators with Python arithmetic at one bin."""
    scanner = Scanner.get_scanner_from_name("ECAT 962")
    info = ProjDataInfo.construct_proj_data_info(scanner, 3, 9, 8, 6)
    a = ProjDataInMemory(ExamInfo(), info)
    b = ProjDataInMemory(a)
    probe = Bin(0, 1, 2, 3)
    probe.bin_value = 5
    a.fill(2)
    b.fill(3)
    # give `a` a distinct value (5) at the probed bin
    a.set_bin_value(probe)

    def value(projdata):
        return projdata.get_bin_value(probe)

    def check(result, expected):
        # STIR computes in float, Python in double, hence the relative tolerance
        assert math.isclose(value(result), expected, rel_tol=1e-4)

    # right-hand operands (projection data and a scalar), each paired with its value at probe
    operands = ((b, value(b)), (3, 3))

    # operators that return a new object
    for rhs, rhs_val in operands:
        check(a + rhs, value(a) + rhs_val)
        check(a - rhs, value(a) - rhs_val)
        check(a * rhs, value(a) * rhs_val)
        check(a / rhs, value(a) / rhs_val)

    # augmented assignment, each applied to a fresh copy of a
    # (adding 0 is a simple way to obtain that copy)
    for rhs, rhs_val in operands:
        c = a + 0
        c += rhs
        check(c, value(a) + rhs_val)
        c = a + 0
        c -= rhs
        check(c, value(a) - rhs_val)
        c = a + 0
        c *= rhs
        check(c, value(a) * rhs_val)
        c = a + 0
        c /= rhs
        check(c, value(a) / rhs_val)

def helper_ProjDataInMemory_from_to_Array(projdata, new_projdata):
    """Copy ``projdata`` into ``new_projdata`` via arrays and check that the data agree.

    ``new_projdata`` is filled twice from ``projdata.to_array()``: first through
    the STIR array's flat iterator, then from its numpy representation
    (``as_array()``). After each fill, ``new_projdata.as_array()`` must contain
    exactly the same number of elements as ``projdata``, with equal values.
    """
    stir_array = projdata.to_array()

    def assert_same_data():
        # keep references to the arrays while iterating over them
        expected_array = projdata.to_array()
        actual_array = new_projdata.as_array()
        expected = list(expected_array.flat())
        actual = list(actual_array.flat)
        # zip() silently stops at the shorter sequence, so a size mismatch
        # (e.g. an empty result) would otherwise go unnoticed
        assert len(actual) == len(expected)
        assert all(a == b for a, b in zip(expected, actual))

    # fill with iterator
    new_projdata.fill(stir_array.flat())
    assert_same_data()
    # fill with numpy array
    new_projdata.fill(stir_array.as_array())
    assert_same_data()

def test_ProjDataInMemory_from_to_Array():
    """Round-trip projection data through arrays, both in memory and via an Interfile file."""
    # projection data in which every element equals its segment number
    scanner = stir.Scanner.get_scanner_from_name("ECAT 962")
    projdatainfo = stir.ProjDataInfo.construct_proj_data_info(scanner, 3, 9, 8, 6)
    examinfo = stir.ExamInfo()
    projdata = stir.ProjDataInMemory(ExamInfo(), projdatainfo)
    segment_nums = range(projdata.get_min_segment_num(), projdata.get_max_segment_num() + 1)
    for seg_num in segment_nums:
        segment = projdata.get_empty_segment_by_sinogram(seg_num)
        segment.fill(seg_num)
        projdata.set_segment(segment)

    # make sure the data really went in (i.e. it is not just zeros)
    assert all(
        all(value == seg for value in projdata.get_segment_by_sinogram(seg).flat())
        for seg in segment_nums
    )

    # round trip into another in-memory object
    new_projdata = stir.ProjDataInMemory(ExamInfo(), projdatainfo)
    helper_ProjDataInMemory_from_to_Array(projdata, new_projdata)

    # round trips involving a file-backed object
    # NOTE: this writes test_projdata.hs into the current working directory
    projdata.write_to_file("test_projdata.hs")
    open_mode = stir.ios.trunc | stir.ios.ios_base_in | stir.ios.out
    new_projdata = stir.ProjDataInterfile(examinfo, projdatainfo, "test_projdata.hs", open_mode)
    helper_ProjDataInMemory_from_to_Array(projdata, new_projdata)
    helper_ProjDataInMemory_from_to_Array(new_projdata, projdata)
    helper_ProjDataInMemory_from_to_Array(new_projdata, new_projdata)

def test_xapyb_and_sapyb():
    """
    Test the xapyb and sapyb methods for FloatVoxelsOnCartesianGrid and ProjDataInMemory
    """
    test_value = 1.4
    # expected value of every element after each xapyb/sapyb call below
    approx_val = pytest.approx(2*test_value + 3*test_value)

    def check_uniform(arr):
        # every element equals the expected value iff both max and min do
        assert arr.find_max() == approx_val
        assert arr.find_min() == approx_val

    # --- FloatVoxelsOnCartesianGrid ---
    origin = FloatCartesianCoordinate3D(0, 1, 6)
    gridspacing = FloatCartesianCoordinate3D(1, 1, 2)
    indrange = IndexRange3D(Int3BasicCoordinate(3), Int3BasicCoordinate(9))
    image = FloatVoxelsOnCartesianGrid(indrange, origin, gridspacing)

    image.fill(test_value)
    image.xapyb(image, 2.0, image, 3.0)
    check_uniform(image)

    image.fill(test_value)
    image.sapyb(3, image, 2)
    check_uniform(image)

    # --- ProjDataInMemory ---
    scanner = Scanner.get_scanner_from_name("ECAT 962")
    projdatainfo = ProjDataInfo.construct_proj_data_info(scanner, 3, 9, 8, 6)
    projdata = ProjDataInMemory(ExamInfo(), projdatainfo)

    projdata.fill(test_value)
    projdata.xapyb(projdata, 2.0, projdata, 3.0)
    check_uniform(projdata.to_array())

    projdata.fill(test_value)
    projdata.sapyb(3, projdata, 2)
    check_uniform(projdata.to_array())

def test_multiply_crystal_factors():
    """multiply_crystal_factors with unit crystal efficiencies gives uniform data."""
    # span=1 projection data filled with ones
    scanner = Scanner.get_scanner_from_name("ECAT 962")
    projdatainfo = ProjDataInfo.construct_proj_data_info(scanner, 1, 9, 8, 6, False)
    projdata = ProjDataInMemory(ExamInfo(), projdatainfo)
    projdata.fill(1)

    # one efficiency per (ring, detector) pair, all set to 1
    first_crystal = Int2BasicCoordinate((0, 0))
    last_crystal = Int2BasicCoordinate((scanner.get_num_rings() - 1,
                                        scanner.get_num_detectors_per_ring() - 1))
    efficiencies = FloatArray2D(IndexRange2D(first_crystal, last_crystal))
    efficiencies.fill(1)

    multiply_crystal_factors(projdata, efficiencies, 1.0)

    assert projdata.find_max() == projdata.find_min()
    # expected value is the view mashing factor; only true for span=1, as used here
    view_mash_factor = scanner.get_num_detectors_per_ring() / 2 / projdatainfo.get_num_views()
    assert projdata.find_max() == view_mash_factor
