Skip to content
Snippets Groups Projects
checkbit4bit.py 1.46 KiB
Newer Older
  • Learn to ignore specific revisions
  • import sys
    import numpy as np
    import xarray as xr
    from pathlib import Path
    
    # Per-file report line: test index, file name, total difference.
    PRINT_MESSAGE = 'Test {0: >3} - Filename: {1: <64} - Total difference: {2: <8}'
    # Failure text; the placeholder receives the indices of the differing tests.
    ERROR_MESSAGE = 'Found differences in test cases: {}'
    SUCCESS_MESSAGE = 'SUCCESSFUL TEST RUN: All files match!'
    USAGE = 'USAGE: ./check_bit_4_bit.py BENCHMARK_DIRECTORY TEST_DIRECTORY'

    # Exactly two positional arguments are required: benchmark dir, test dir.
    if len(sys.argv) != 3:
        print(USAGE)
        # sys.exit is always available; the bare exit() helper is injected by
        # the site module and is absent when Python runs without it.
        sys.exit(1)

    bench_dir = sys.argv[1]
    test_dir = sys.argv[2]

    # Recursively collect NetCDF files. Sorting makes the i-th benchmark file
    # pair up with the i-th test file.
    bench_files = sorted(Path(bench_dir).glob('**/*.nc'))
    test_files = sorted(Path(test_dir).glob('**/*.nc'))

    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, and the pairwise zip() used for the comparison would then
    # silently ignore the surplus files.
    if len(bench_files) != len(test_files):
        sys.exit(
            'Found {} files but need {}!'.format(len(test_files), len(bench_files))
        )
    
    def rem(li1, li2):
        """Return the items of ``li1`` that do not occur in ``li2``.

        The computation goes through sets, so duplicates are collapsed and the
        order of the returned list is unspecified.
        """
        kept = set(li1)
        dropped = set(li2)
        return list(kept.difference(dropped))
    
    def compute_diffs(bench_files, test_files):
        """Compare paired NetCDF files and total their element-wise differences.

        Files are paired by position. For each pair both datasets are opened,
        the absolute difference is summed over the ``time`` dimension, and the
        result is added up across all data variables. One report line is
        printed per pair.

        Args:
            bench_files: Sequence of benchmark file paths.
            test_files: Sequence of test file paths, in matching order.

        Returns:
            A ``(all_diffs, all_tots)`` tuple: the per-pair difference datasets
            (absolute differences reduced over ``time``) and the per-pair
            scalar totals. A total of 0 means no difference was found.

        Note:
            Every dataset is reduced over ``time``, so files without that
            dimension make xarray raise.
            NOTE(review): xarray's ``sum`` skips NaN by default for float
            data, so a NaN present in only one of the two files is presumably
            not counted -- confirm whether that case matters here.
        """
        all_diffs = []
        all_tots = []
        for i, (f1, f2) in enumerate(zip(bench_files, test_files)):
            # Context managers close both files even if the comparison raises.
            with xr.open_dataset(str(f1)) as ds1, xr.open_dataset(str(f2)) as ds2:
                # Take the absolute value before summing: signed differences
                # of opposite sign would otherwise cancel and hide a mismatch.
                # load() pulls the result into memory before the files close.
                diff = abs(ds1 - ds2).sum(dim='time').load()
            tot = 0.0
            # Only data variables hold differences; coordinates are carried
            # through the subtraction unchanged and must not be counted.
            for v in diff.data_vars:
                tot += np.sum(diff[v].values)
            all_diffs.append(diff)
            all_tots.append(tot)
            # Path.name is the final path component on every platform.
            print(PRINT_MESSAGE.format(i, Path(f1).name, tot))
        return all_diffs, all_tots
    
    all_diffs, all_tots = compute_diffs(bench_files, test_files)

    # Indices of the test cases whose total difference is non-zero. NaN totals
    # compare unequal to 0 and are therefore reported as failures too.
    failed_cases = np.flatnonzero(np.asarray(all_tots) != 0).tolist()

    # Explicit exit instead of `assert`: under `python -O` an assert is removed
    # and the script would report success regardless of the results.
    if failed_cases:
        sys.exit(ERROR_MESSAGE.format(failed_cases))

    print(SUCCESS_MESSAGE)