Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import sys
import numpy as np
import xarray as xr
from pathlib import Path
PRINT_MESSAGE = 'Test {0: >3} - Filename: {1: <64} - Total difference: {2: <8}'
ERROR_MESSAGE = 'Found differences in test cases: {}'
SUCCESS_MESSAGE = 'SUCCESSFUL TEST RUN: All files match!'
USAGE = 'USAGE: ./check_bit_4_bit.py BENCHMARK_DIRECTORY TEST_DIRECTORY'

# Expect exactly two arguments: the benchmark and the test output directory.
# `sys.exit` is used rather than the `exit` builtin, which is only injected by
# the `site` module and is unavailable under e.g. `python -S`.
if len(sys.argv) != 3:
    print(USAGE, file=sys.stderr)
    sys.exit(1)
bench_dir = sys.argv[1]
test_dir = sys.argv[2]

# A mistyped directory would otherwise glob to zero files on both sides and
# the run would report success without comparing anything.
for _directory in (bench_dir, test_dir):
    if not Path(_directory).is_dir():
        sys.exit('Not a directory: {}'.format(_directory))

# Every NetCDF file below each directory, recursively. Sorting makes the two
# lists pair up positionally when both trees share the same layout.
bench_files = sorted(Path(bench_dir).glob('**/*.nc'))
test_files = sorted(Path(test_dir).glob('**/*.nc'))

# Explicit checks instead of `assert`, which `python -O` strips out.
if not bench_files:
    sys.exit('No .nc files found under {}'.format(bench_dir))
if len(bench_files) != len(test_files):
    sys.exit('Found {} files but need {}!'.format(
        len(test_files), len(bench_files)))
def rem(li1, li2):
    """Return the distinct items of *li1* that do not occur in *li2*.

    The computation goes through ``set``, so duplicates collapse and the
    order of the returned list is arbitrary.
    """
    excluded = set(li2)
    remaining = set(li1).difference(excluded)
    return list(remaining)
def compute_diffs(bench_files, test_files):
    """Compare benchmark and test NetCDF files pairwise.

    Files are paired positionally with ``zip``: both sequences must list
    corresponding files in the same order, and surplus entries in the longer
    sequence are ignored. One progress line per pair is printed using
    ``PRINT_MESSAGE``.

    Parameters
    ----------
    bench_files, test_files : sequence of path-like
        Benchmark and test files to compare.

    Returns
    -------
    all_diffs : list of xarray.Dataset
        Per pair, the signed difference ``bench - test``, summed over the
        ``time`` dimension when that dimension exists, loaded into memory.
    all_tots : list of float
        Per pair, the sum of absolute differences across all data variables.
        It is 0.0 only when every compared value matches. It is NaN when the
        two files hold different sets of data variables, or when a value is
        NaN in one file but not in the other.
    """
    all_diffs = []
    all_tots = []
    for i, (f1, f2) in enumerate(zip(bench_files, test_files)):
        # `with` closes both file handles once the pair has been compared.
        with xr.open_dataset(str(f1)) as ds1, xr.open_dataset(str(f2)) as ds2:
            # NOTE(review): subtraction aligns on shared coordinate labels, so
            # labels present in only one file are not compared -- confirm the
            # benchmark and test files always share identical coordinates.
            diff = ds1 - ds2

            # Subtraction keeps only the data variables present in both files,
            # so a differing variable set must be flagged explicitly.
            comparable = set(ds1.data_vars) == set(ds2.data_vars)
            tot = 0.0
            # Data variables only: coordinates carried into `diff` hold the
            # original coordinate values, not differences.
            for name in diff.data_vars:
                # Absolute values so positive and negative differences cannot
                # cancel out to a spurious zero.
                tot += float(abs(diff[name]).sum())
                # .sum() skips NaN, which would hide a value that is missing
                # (NaN) in only one of the two files.
                if bool((ds1[name].isnull() != ds2[name].isnull()).any()):
                    comparable = False
            if not comparable:
                tot = float('nan')

            summed = diff.sum(dim='time') if 'time' in diff.dims else diff
            # Load now so the stored result does not depend on the closed files.
            all_diffs.append(summed.load())
        all_tots.append(tot)
        print(PRINT_MESSAGE.format(i, Path(f1).name, tot))
    return all_diffs, all_tots
all_diffs, all_tots = compute_diffs(bench_files, test_files)

# Indices of file pairs that differ. `tot != 0` is also True for NaN, so pairs
# whose total is NaN are reported as failures as well.
failed = [i for i, tot in enumerate(all_tots) if tot != 0]

# Explicit exit instead of `assert`: under `python -O` asserts are removed and
# a failing comparison would otherwise be reported as a success.
# sys.exit(message) prints the message to stderr and exits with status 1.
if failed:
    sys.exit(ERROR_MESSAGE.format(failed))
print(SUCCESS_MESSAGE)