#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 14 18:24:26 2022
interpolateBadDICOMs.py
@author: fordb


This is my attempt at resolving the following error from dcm2niix:

FileSize < (ImageSize+HeaderSize): 720896 < (627200+158720) 
Warning: File not large enough to store image data: <filename>

When the offending sequence of DICOMs is inspected, one or more files turn out
to be smaller than the rest. If these DICOMs are simply deleted, dcm2niix will
process the series. However, the missing frames then make the timing
information unreliable. Instead, these frames could be interpolated and,
optionally, each one scrubbed by modelling its effect in a first-level model.

So, I would like to write a small script that takes a folder containing the
offending sequence, identifies the bad DICOMs, removes them, converts the
series, inserts frames where the removed ones were, interpolates data into
those frames, and then writes the new file along with a sidecar file
indicating which frames were interpolated and should therefore be scrubbed.

Not all DICOMs from a 'good' sequence are exactly the same size, so the script
must only flag files whose size differs significantly from the rest.

Detection currently relies on file sizes alone. An alternative would be to
load each file with pydicom,
ds = pydicom.dcmread(dicomfile)
and access ds.pixel_array: for the bad files this raises "ValueError: The
length of the pixel data in the dataset (369224 bytes) doesn't match the
expected length (401408 bytes). The dataset may be corrupted or there may be
an issue with the pixel data handler." That error could serve as the
detection criterion instead.



Replacing a bad file's pixel data with zeros makes programs such as SPM throw
errors when they realign the series. So instead, each bad file is given a
clearly intensity-increased copy of a good file's image, so that the frame
will still be flagged as an outlier. I am not sure this is the best way to go
about it. The other option is to interpolate the data between neighbouring
files, but that would ideally include motion correction and similar
preprocessing, which is not straightforward for me to produce here.
"""

import sys
import os 
import numpy as np
import pydicom

# Tuning knobs for main(); the script itself takes a folder as its only argument.
# A file size is flagged when the magnitude of its z-score exceeds this value.
zscorethresh: int = 2
# Factor applied to a good frame's pixel values before they are written into
# each flagged file, so the replaced frame stands out as an outlier.
globalOffsetMultiplier: int = 2

def _list_candidate_files(path):
    """Return the sorted names of the regular, non-hidden files directly in *path*."""
    # Sub-directories and hidden files (e.g. .DS_Store) are not DICOM frames:
    # sizing them skews the statistics, and if flagged they make dcmread fail.
    # Sorting makes the "first good file" choice reproducible.
    return sorted(
        name
        for name in os.listdir(path)
        if not name.startswith(".") and os.path.isfile(os.path.join(path, name))
    )


def _flag_outlier_sizes(sizes, threshold):
    """Print the file-size distribution and return the set of outlier sizes.

    A size is an outlier when the absolute z-score of that size, relative to
    the mean and population standard deviation over all files in *sizes*
    (a mapping of file name -> size in bytes), exceeds *threshold*.
    """
    all_sizes = list(sizes.values())
    values, counts = np.unique(all_sizes, return_counts=True)

    # compute and report size distributions
    print("Detected " + str(len(values)) + " unique file sizes:")
    mean_size = np.mean(all_sizes)
    std_size = np.std(all_sizes)
    hitlist = set()
    for value, count in zip(values, counts):
        # Identical sizes give std == 0; report a z-score of 0 instead of
        # dividing by zero (which produced nan and a RuntimeWarning).
        zscore = (value - mean_size) / std_size if std_size > 0 else 0.0
        print(str(count) + " files of size " + str(value) + ". Size Zscore = " + str(np.round(zscore, 2)), end='')
        if np.abs(zscore) > threshold:
            print(" *")
            hitlist.add(int(value))
        else:
            print("")
    print("Zscore threshold is " + str(threshold))
    return hitlist


def _scaled_pixel_bytes(pixels, multiplier):
    """Return ``pixels * multiplier`` as raw bytes in the dtype of *pixels*.

    For integer dtypes the product is rounded and clipped to the dtype's range,
    so bright voxels saturate instead of wrapping around (uint16 40000 * 2
    would otherwise become 14464). Casting back to the source dtype also keeps
    the byte length unchanged when *multiplier* is a float.
    """
    scaled = pixels.astype(np.float64) * multiplier
    if np.issubdtype(pixels.dtype, np.integer):
        info = np.iinfo(pixels.dtype)
        scaled = np.clip(np.rint(scaled), info.min, info.max)
    return scaled.astype(pixels.dtype).tobytes()


def main():
    """Overwrite, in place, the DICOM files in a folder whose size is an outlier.

    Usage: ``interpolateBadDICOMs.py <dicom_folder>``

    Every regular, non-hidden file in the folder is sized. Sizes whose z-score
    magnitude exceeds ``zscorethresh`` are flagged. The pixel data of each
    flagged file is replaced by the pixel data of the first non-flagged file
    multiplied by ``globalOffsetMultiplier``.

    Exits with status 2 when the folder argument is missing. Exits with
    status 1 when the argument is not a directory or no unflagged file exists.
    """
    if len(sys.argv) < 2:
        print("Usage: " + os.path.basename(sys.argv[0]) + " <dicom_folder>", file=sys.stderr)
        sys.exit(2)
    path = sys.argv[1]
    if not os.path.isdir(path):
        print(path + " is not a directory")
        sys.exit(1)

    filelist = _list_candidate_files(path)
    print("Found " + str(len(filelist)) + " files.")
    if len(filelist) <= 1:
        # A single file has no distribution to compare against.
        return

    print("Sizing files...")
    # Size each file once; the same sizes are reused for every lookup below.
    sizes = {name: os.path.getsize(os.path.join(path, name)) for name in filelist}
    hitlist = _flag_outlier_sizes(sizes, zscorethresh)
    if not hitlist:
        print("No files exceed size threshold")
        return

    print("Filesizes marked with an asterisk exceed threshold")
    print("These files will be OVERWRITTEN...")
    # find a good file to start with
    good_name = next((name for name in filelist if sizes[name] not in hitlist), None)
    if good_name is None:
        print("Couldn't find a good file...")
        sys.exit(1)
    goodfile = pydicom.dcmread(os.path.join(path, good_name))
    # Decode and scale once; the same replacement data goes into every bad file.
    replacement = _scaled_pixel_bytes(goodfile.pixel_array, globalOffsetMultiplier)

    # now find offending files and rewrite their pixel data
    for name in filelist:
        if sizes[name] not in hitlist:
            continue
        print("Overwriting " + name)
        target = os.path.join(path, name)
        bad_ds = pydicom.dcmread(target)
        # NOTE(review): assumes the bad file shares Rows/Columns/BitsAllocated
        # with the good file and uses an uncompressed transfer syntax, since
        # raw decoded bytes are written back as-is — confirm for your data.
        bad_ds.PixelData = replacement
        bad_ds.save_as(target)
    print("Done")

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()
