statcmp.py 7.93 KB
Newer Older
1 2 3
#!/usr/bin/env python3

import argparse
4
import logging
5 6 7 8 9 10 11 12 13 14 15
import sys
from typing import Optional, Sequence

from respdiff import cli
from respdiff.dataformat import Summary
from respdiff.stats import Stats, SummaryStatistics
from respdiff.typing import FieldLabel

# pylint: disable=wrong-import-order,wrong-import-position
import matplotlib
import matplotlib.axes
16
import matplotlib.ticker
17 18 19 20 21 22 23
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa


# Matplotlib colour names used by the plots in this module.
COLOR_OK = 'tab:blue'            # marker of a sample evaluated as NORMAL
COLOR_GOOD = 'tab:green'         # marker/title of a sample below the reference
COLOR_BAD = 'xkcd:bright red'    # marker/title of a sample above reference/threshold
COLOR_THRESHOLD = 'tab:orange'   # threshold marker
COLOR_BG = 'tab:gray'            # violin body, min/median/max markers, sample size text
COLOR_LABEL = 'black'            # default graph title colour

# Size of a single violin graph; the overview figure multiplies the first item
# by the number of columns and the second one by the number of rows.
VIOLIN_FIGSIZE = (3, 6)

30
# Marker colour for each result of Stats.evaluate_sample(); both "above"
# positions share the same (bad) colour.
SAMPLE_COLORS = {
    Stats.SamplePosition.ABOVE_REF: COLOR_BAD,
    Stats.SamplePosition.ABOVE_THRESHOLD: COLOR_BAD,
    Stats.SamplePosition.NORMAL: COLOR_OK,
    Stats.SamplePosition.BELOW_REF: COLOR_GOOD,
}
36 37


38
class AxisMarker:
    """Horizontal line drawn into a graph at a given position of the y axis."""

    def __init__(self, position: float, width: float = 0.7, color: str = COLOR_BG) -> None:
        """
        Store the marker parameters.

        position -- y coordinate the line is drawn at
        width -- portion of the axes width covered by the line
        color -- colour of the line
        """
        self.position = position
        self.width = width
        self.color = color

    def draw(self, ax: matplotlib.axes.Axes) -> None:
        """Draw the marker into ``ax`` as a horizontally centered line."""
        # keep the same margin on both sides of the line
        xmin = (1 - self.width) / 2
        xmax = 1 - xmin
        ax.axhline(self.position, color=self.color, xmin=xmin, xmax=xmax)
48

49 50 51 52 53 54 55 56

def plot_violin(
            ax: matplotlib.axes.Axes,
            violin_data: Sequence[float],
            markers: Sequence[AxisMarker],
            label: str,
            color: str = COLOR_LABEL
        ) -> None:
    """
    Draw a single violin graph with axis markers into ``ax``.

    ax -- axes to draw into
    violin_data -- values the violin is computed from; when empty, no violin
                   is drawn and only the markers and the axis are shown
    markers -- markers drawn over the violin
    label -- title of the graph
    color -- colour of the title
    """
    ax.set_title(label, fontdict={'fontsize': 14}, color=color)

    # max() and violinplot() can't handle an empty sequence
    has_data = len(violin_data) > 0

    if has_data:
        # plot violin graph
        violin_parts = ax.violinplot(violin_data, bw_method=0.07,
                                     showmedians=False, showextrema=False)
        # set violin background color
        for body in violin_parts['bodies']:
            body.set_facecolor(COLOR_BG)
            body.set_edgecolor(COLOR_BG)

    # draw axis markers
    for marker in markers:
        marker.draw(ax)

    # turn off axis spines
    for sp in ['right', 'top', 'bottom']:
        ax.spines[sp].set_color('none')
    # move the left ax spine to center
    ax.spines['left'].set_position(('data', 1))

    # customize axis ticks
    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
    ax.xaxis.set_minor_locator(matplotlib.ticker.NullLocator())
    if not has_data or max(violin_data) == 0:  # fix tick at 0 when there's no data
        ax.yaxis.set_major_locator(matplotlib.ticker.FixedLocator([0]))
    else:
        ax.yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(
            nbins='auto', steps=[1, 2, 4, 5, 10], integer=True))
    ax.yaxis.set_minor_locator(matplotlib.ticker.NullLocator())
    ax.tick_params(labelsize=14)


def _axes_iter(axes, width: int):
    """
    Yield the items of a two-dimensional ``axes`` grid one by one, row by row.

    Items are looked up as ``axes[row, column]`` with ``width`` columns per
    row. The iteration ends at the first position which raises IndexError.
    """
    position = 0
    while True:
        row, column = divmod(position, width)
        position += 1
        try:
            yield axes[row, column]
        except IndexError:
            return


101 102 103 104 105 106 107 108 109 110 111 112
def eval_and_plot_single(
            ax: matplotlib.axes.Axes,
            stats: Stats,
            label: str,
            samples: Sequence[float]
        ) -> bool:
    """
    Evaluate samples against the statistics and plot both into a single graph.

    Every sample is classified with ``stats.evaluate_sample()``, logged and
    drawn as a marker coloured according to SAMPLE_COLORS. The statistics are
    drawn as a violin graph of ``stats.samples`` with markers for the minimum,
    median, maximum and threshold. The graph title is highlighted when any
    sample is above the threshold/reference or, failing that, below the reference.

    Returns False if any sample was evaluated as ABOVE_REF or ABOVE_THRESHOLD,
    True otherwise (including when no samples are given).
    """
    markers = []
    below_min = False
    above_thr = False
    for sample in samples:
        result = stats.evaluate_sample(sample)
        markers.append(AxisMarker(sample, color=SAMPLE_COLORS[result]))
        if result in (Stats.SamplePosition.ABOVE_REF,
                      Stats.SamplePosition.ABOVE_THRESHOLD):
            above_thr = True
            logging.error(
                '  %s: threshold exceeded! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%',
                label, sample, stats.get_percentile_rank(sample),
                stats.threshold, stats.get_percentile_rank(stats.threshold))
        elif result == Stats.SamplePosition.BELOW_REF:
            below_min = True
            logging.info(
                '  %s: new minimum found! new: %d vs prev: %d',
                label, sample, stats.min)
        else:
            logging.info(
                '  %s: ok! sample: %d / %4.2f%% vs threshold: %d / %4.2f%%',
                label, sample, stats.get_percentile_rank(sample),
                stats.threshold, stats.get_percentile_rank(stats.threshold))

    # add min/med/max markers
    markers.append(AxisMarker(stats.min, 0.5, COLOR_BG))
    markers.append(AxisMarker(stats.median, 0.5, COLOR_BG))
    markers.append(AxisMarker(stats.max, 0.5, COLOR_BG))
    markers.append(AxisMarker(stats.threshold, 0.9, COLOR_THRESHOLD))

    # select label color; an exceeded threshold takes precedence over a new minimum
    if above_thr:
        color = COLOR_BAD
    elif below_min:
        color = COLOR_GOOD
    else:
        color = COLOR_LABEL

    plot_violin(ax, stats.samples, markers, label, color)

    return not above_thr


150
def plot_overview(
            sumstats: SummaryStatistics,
            fields: Sequence[FieldLabel],
            summaries: Optional[Sequence[Summary]] = None,
            label: str = 'fields_overview'
        ) -> bool:
    """
    Plot an overview of all fields using violin graphs and save it as '<label>.png'.

    If any summaries are provided, they are drawn in the graphs and also evaluated.
    If any sample in any field exceeds the threshold, the function returns False.
    True is returned otherwise.

    Fields missing in the statistics are skipped with a warning. The graphs are
    laid out in a grid 7 columns wide with at least 3 rows; rows are added when
    there are more graphs than fit into the default grid.
    """
    if summaries is None:
        summaries = []

    # check the statistics before any figure is created
    assert sumstats.target_disagreements is not None
    assert sumstats.upstream_unstable is not None
    assert sumstats.not_reproducible is not None
    assert sumstats.fields is not None

    columns = 7
    min_rows = 3
    # three fixed graphs plus one for every field present in the statistics
    n_graphs = 3 + sum(1 for field in fields if field in sumstats.fields)
    # grow the grid instead of running out of axes when there are many fields
    rows = max(min_rows, (n_graphs + columns - 1) // columns)

    # prepare subplot axis; squeeze=False keeps the axes two-dimensional
    fig, axes = plt.subplots(
        rows,
        columns,
        figsize=(columns * VIOLIN_FIGSIZE[0], rows * VIOLIN_FIGSIZE[1]),
        squeeze=False)

    try:
        ax_it = _axes_iter(axes, columns)
        passed = True

        # target disagreements
        samples = [len(summary) for summary in summaries]
        passed &= eval_and_plot_single(
            next(ax_it), sumstats.target_disagreements, 'target_disagreements', samples)

        # upstream unstable
        samples = [summary.upstream_unstable for summary in summaries]
        passed &= eval_and_plot_single(
            next(ax_it), sumstats.upstream_unstable, 'upstream_unstable', samples)

        # not 100% reproducible
        samples = [summary.not_reproducible for summary in summaries]
        passed &= eval_and_plot_single(
            next(ax_it), sumstats.not_reproducible, 'not_reproducible', samples)

        # fields
        fcs = [summary.get_field_counters() for summary in summaries]
        for field in fields:
            if field not in sumstats.fields:
                logging.warning('Field "%s" missing in statistics, omitting...', field)
                continue
            passed &= eval_and_plot_single(
                next(ax_it),
                sumstats.fields[field].total,
                field,
                [sum(1 for _ in fc[field].elements()) for fc in fcs])

        # hide unused axis
        for ax in ax_it:
            ax.set_visible(False)

        # display sample size
        fig.text(
            0.95, 0.95,
            'stat sample size: {}'.format(len(sumstats.target_disagreements.samples)),
            fontsize=18, color=COLOR_BG, ha='right', va='bottom', alpha=0.7)

        # save image
        fig.tight_layout()
        fig.subplots_adjust(top=0.9)
        fig.suptitle(label, fontsize=22)
        fig.savefig('{}.png'.format(label))
    finally:
        # release the figure even if the evaluation or plotting failed
        plt.close(fig)

    return passed


def main() -> None:
    """
    Command-line entry point: compare reports against statistics and plot them.

    The fields to plot are taken from the 'field_weights' entry of the
    'report' section of the configuration.

    The process always ends with sys.exit():
    0 -- no threshold was exceeded
    1 -- cli.load_summaries() raised ValueError
    3 -- at least one threshold was exceeded
    """
    cli.setup_logging()
    parser = argparse.ArgumentParser(
        description=("Plot and compare reports against statistical data. "
                     "Returns non-zero exit code if any threshold is exceeded."))

    cli.add_arg_stats(parser)
    cli.add_arg_report(parser)
    cli.add_arg_config(parser)
    parser.add_argument('-l', '--label', default='fields_overview',
                        help='Set plot label. It is also used for the filename.')

    args = parser.parse_args()
    sumstats = args.stats
    field_weights = args.cfg['report']['field_weights']

    try:
        summaries = cli.load_summaries(args.report)
    except ValueError:
        sys.exit(1)

    logging.info('Start Comparison: %s', args.label)
    passed = plot_overview(sumstats, field_weights, summaries, args.label)

    sys.exit(0 if passed else 3)


# run the comparison only when executed as a script, not when imported
if __name__ == "__main__":
    main()