import sys
import numpy as np
import xarray as xr
from pathlib import Path

# Report line printed per file pair: test index (right-aligned, width 3),
# file name (left-aligned, width 64) and total difference (left-aligned,
# width 8).
PRINT_MESSAGE = (
    'Test {0: >3} - '
    'Filename: {1: <64} - '
    'Total difference: {2: <8}'
)
# Failure text; the placeholder receives the indices of the differing tests.
ERROR_MESSAGE = 'Found differences in test cases: {}'
# Printed once when no differences were found.
SUCCESS_MESSAGE = 'SUCCESSFUL TEST RUN: All files match!'
# Shown when the script is not called with exactly two arguments.
USAGE = 'USAGE: ./check_bit_4_bit.py BENCHMARK_DIRECTORY TEST_DIRECTORY'

# Exactly two positional arguments are required: the benchmark directory and
# the test directory.
if len(sys.argv) != 3:
    print(USAGE)
    # sys.exit instead of the `exit` helper, which is only injected by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(1)

bench_dir = sys.argv[1]
test_dir = sys.argv[2]

# Recursively collect the NetCDF files of both trees. Sorting gives each list
# a deterministic order; the files are later compared position by position.
bench_files = sorted(Path(bench_dir).glob('**/*.nc'))
test_files = sorted(Path(test_dir).glob('**/*.nc'))

# Explicit checks rather than `assert`, which is stripped under `python -O`.
# sys.exit(<str>) prints the message to stderr and exits with status 1.
if not bench_files:
    # Without this guard two empty directories would compare as "equal" and
    # the run would be reported as successful without checking anything.
    sys.exit('No .nc files found in benchmark directory: {}'.format(bench_dir))

if len(bench_files) != len(test_files):
    sys.exit(
        'Found {} files but need {}!'.format(len(test_files), len(bench_files))
    )

def rem(li1, li2):
    """Return the items of ``li1`` that do not occur in ``li2``.

    Both inputs are converted to sets, so duplicates collapse and the order
    of the returned list is arbitrary.
    """
    keep = set(li1)
    drop = set(li2)
    remaining = keep - drop
    return list(remaining)

def compute_diffs(bench_files, test_files):
    """Compare benchmark and test NetCDF files pairwise.

    Files are paired positionally (``zip``), so both sequences are expected
    to be ordered the same way. One report line is printed per pair.

    Args:
        bench_files: Paths of the reference files.
        test_files: Paths of the files under test.

    Returns:
        A tuple ``(all_diffs, all_tots)`` with one entry per file pair:
        the dataset difference summed over the ``time`` dimension, and the
        total absolute difference accumulated over all data variables.
    """
    all_diffs = []
    all_tots = []
    for i, (f1, f2) in enumerate(zip(bench_files, test_files)):
        # The context managers close both files once the pair is processed
        # instead of leaving every dataset open until the script ends.
        with xr.open_dataset(str(f1)) as ds1, xr.open_dataset(str(f2)) as ds2:
            raw = ds1 - ds2
            # Load while the files are still open so the returned diff stays
            # usable after they are closed.
            diff = raw.sum(dim='time').load()
            tot = 0.0
            # Only data variables hold differences; coordinates carried along
            # by the subtraction are not compared here.
            for v in raw.data_vars:
                # Sum magnitudes of the element-wise differences so that
                # offsetting errors (+x and -x) cannot cancel to zero.
                # NOTE: nansum skips NaN entries, so positions that are NaN
                # in either file do not contribute to the total.
                tot += np.nansum(np.abs(raw[v].values))
        all_diffs.append(diff)
        all_tots.append(tot)
        # Path.name extracts the file name independent of the path separator.
        print(PRINT_MESSAGE.format(i, Path(f1).name, tot))
    return all_diffs, all_tots

all_diffs, all_tots = compute_diffs(bench_files, test_files)

# Indices of the file pairs whose total difference is not exactly zero
# (a NaN total also counts as a difference, since NaN != 0).
failed = np.flatnonzero(np.asarray(all_tots) != 0)
if failed.size:
    # Explicit exit rather than `assert`: assertions are stripped under
    # `python -O`, which would make every run report success.
    # sys.exit(<str>) prints the message to stderr and exits with status 1.
    sys.exit(ERROR_MESSAGE.format(failed.tolist()))

print(SUCCESS_MESSAGE)