Source code for graphid.core.mixin_simulation

"""
Mixin functionality for experiments, tests, and simulations.
This includes recording the measures used to generate plots in JC's thesis.
"""
import itertools as it
import ubelt as ub
import pandas as pd
from functools import partial
from graphid.core.state import (POSTV, NEGTV, INCMP, UNREV, UNKWN, NULL)
from graphid import util


class SimulationHelpers:
    """
    Mixin adding simulation / test-mode bookkeeping to a graph inference
    object.

    Methods take ``infr`` in place of ``self``. They read and write many
    attributes that are provided by other parts of the inference object
    (for example ``print``, ``params``, ``aids``, ``pos_graph``,
    ``recover_graph``, ``refresh``, ``edge_truth``), which are not defined
    here.
    """

    def init_simulation(infr, oracle_accuracy=1.0, k_redun=2,
                        enable_autoreview=True, enable_inference=True,
                        classifiers=None, match_state_thresh=None,
                        max_outer_loops=None, name=None):
        """
        Configure ``infr`` for a simulated review run.

        Args:
            oracle_accuracy (float | tuple): accuracy given to the
                ``UserOracle``; a tuple is ``(normal, recovery)`` accuracy.
            k_redun (int): value written to both ``'redun.pos'`` and
                ``'redun.neg'``.
            enable_autoreview (bool): stored as ``'autoreview.enabled'``.
            enable_inference (bool): stored as ``'inference.enabled'``.
            classifiers: stored as ``infr.verifiers``.
            match_state_thresh (dict | None): per-state thresholds for the
                ``'match_state'`` task. Defaults to 1.0 for POSTV, NEGTV
                and INCMP.
            max_outer_loops: stored as ``'algo.max_outer_loops'``.
            name (str | None): run name; also used to seed the oracle RNG.
        """
        infr.print('INIT SIMULATION', color='yellow')

        infr.name = name
        infr.simulation_mode = True

        infr.verifiers = classifiers
        infr.params['inference.enabled'] = enable_inference
        infr.params['autoreview.enabled'] = enable_autoreview

        infr.params['redun.pos'] = k_redun
        infr.params['redun.neg'] = k_redun

        # keeps track of edges where the decision != the groundtruth
        infr.mistake_edges = set()

        infr.queue = util.PriorityQueue()

        # The oracle's random generator is seeded from the run name.
        infr.oracle = UserOracle(oracle_accuracy, rng=infr.name)

        if match_state_thresh is None:
            match_state_thresh = {
                POSTV: 1.0,
                NEGTV: 1.0,
                INCMP: 1.0,
            }

        # Photobomb thresholds are fixed; there is no parameter for them.
        pb_state_thresh = {
            'pb': .5,
            'notpb': .9,
        }

        infr.task_thresh = {
            'photobomb_state': pd.Series(pb_state_thresh),
            'match_state': pd.Series(match_state_thresh),
        }
        infr.params['algo.max_outer_loops'] = max_outer_loops

    def init_test_mode(infr):
        """
        Reset the counters and ground-truth structures used to score a run.

        Requires ``infr.aids`` and ``infr.orig_name_labels`` (the
        ground-truth name label of each aid, in the same order).
        """
        from graphid.core import nx_dynamic_graph
        infr.print('init_test_mode')
        infr.test_mode = True
        infr.metrics_list = []
        infr.test_state = {
            'n_decision': 0,
            'n_algo': 0,
            'n_manual': 0,
            'n_true_merges': 0,
            'n_error_edges': 0,
            'confusion': None,
        }
        # Graph of positive edges that were reviewed correctly so far.
        infr.test_gt_pos_graph = nx_dynamic_graph.DynConnGraph()
        infr.test_gt_pos_graph.add_nodes_from(infr.aids)
        infr.nid_to_gt_cc = ub.group_items(infr.aids, infr.orig_name_labels)
        infr.node_truth = ub.dzip(infr.aids, infr.orig_name_labels)

        # Minimum number of positive reviews needed to connect every
        # ground-truth name: a component of k aids needs k - 1 edges.
        infr.real_n_pcc_mst_edges = sum(
            len(cc) - 1 for cc in infr.nid_to_gt_cc.values())
        infr.print('real_n_pcc_mst_edges = %r' % (
            infr.real_n_pcc_mst_edges,), color='red')

    def measure_error_edges(infr):
        """
        Yield reviewed edges whose decision disagrees with the ground truth.

        Yields:
            tuple: ``(edge, error)``. ``error`` is an ordered dict with keys
            ``'real'`` (the edge's ``'truth'`` attribute) and ``'pred'``
            (its ``'evidence_decision'``). Unreviewed edges are skipped.
        """
        for edge, data in infr.edges(data=True):
            true_state = data['truth']
            pred_state = data.get('evidence_decision', UNREV)
            if pred_state != UNREV and true_state != pred_state:
                error = ub.odict([('real', true_state), ('pred', pred_state)])
                yield edge, error

    def measure_metrics(infr):
        """
        Snapshot the current test-mode metrics.

        Requires ``init_test_mode`` and at least one call to
        ``_dynamic_test_callback``. That callback creates the confusion
        matrix and sets ``test_state['test_action']``, ``'user_id'``,
        ``'pred_decision'`` and ``'true_decision'``, which are read here.

        Returns:
            dict: metric name -> value.
        """
        n_true_merges = infr.test_state['n_true_merges']
        # confusion[true_state][pred_state] -> count of edges
        confusion = infr.test_state['confusion']

        reviewed_cols = set(confusion) - {UNREV}
        non_postv = reviewed_cols - {POSTV}
        non_negtv = reviewed_cols - {NEGTV}

        # truly-positive edges reviewed as something else
        n_fn = sum(ub.take(confusion[POSTV], non_postv))
        # truly-negative edges reviewed as something else
        n_fp = sum(ub.take(confusion[NEGTV], non_negtv))

        # every off-diagonal (reviewed, wrong) cell
        n_error_edges = sum(
            confusion[r][c] + confusion[c][r]
            for r, c in it.combinations(reviewed_cols, 2))

        pred_n_pcc_mst_edges = n_true_merges

        # Find all annotations involved in a mistake
        assert n_error_edges == len(infr.mistake_edges)
        direct_mistake_aids = {a for edge in infr.mistake_edges for a in edge}
        mistake_nids = set(infr.node_labels(*direct_mistake_aids))
        mistake_aids = set(ub.flatten([
            infr.pos_graph.component(nid) for nid in mistake_nids]))

        real_n_merges = infr.real_n_pcc_mst_edges
        # When every ground-truth name is a singleton there is nothing to
        # merge; report full merge accuracy instead of dividing by zero.
        if real_n_merges:
            pos_acc = pred_n_pcc_mst_edges / real_n_merges
        else:
            pos_acc = 1.0

        n_aids = len(infr.aids)
        frac_mistake_aids = len(mistake_aids) / n_aids if n_aids else 0.0

        neg_metagraph = infr.neg_metagraph
        # The ``Graph.number_of_selfloops()`` method no longer exists in
        # networkx >= 2.4. Counting u == v edges is equivalent and works on
        # any version. (Assumes neg_metagraph is a networkx-style graph.)
        n_neg_selfloops = sum(1 for u, v in neg_metagraph.edges() if u == v)

        metrics = {
            'n_decision': infr.test_state['n_decision'],
            'n_manual': infr.test_state['n_manual'],
            'n_algo': infr.test_state['n_algo'],
            'phase': infr.loop_phase,
            'pos_acc': pos_acc,
            'n_merge_total': real_n_merges,
            'n_merge_remain': real_n_merges - n_true_merges,
            'n_true_merges': n_true_merges,
            'recovering': infr.is_recovering(),
            'merge_remain': 1 - pos_acc,
            'n_mistake_aids': len(mistake_aids),
            'frac_mistake_aids': frac_mistake_aids,
            'n_mistake_nids': len(mistake_nids),
            'n_errors': n_error_edges,
            'n_fn': n_fn,
            'n_fp': n_fp,
            'refresh_support': len(infr.refresh.manual_decisions),
            'pprob_any': infr.refresh.prob_any_remain(),
            'mu': infr.refresh._ewma,
            'test_action': infr.test_state['test_action'],
            'action': infr.test_state.get('action', None),
            'user_id': infr.test_state['user_id'],
            'pred_decision': infr.test_state['pred_decision'],
            'true_decision': infr.test_state['true_decision'],
            'n_neg_redun': infr.neg_redun_metagraph.number_of_edges(),
            'n_neg_redun1': (neg_metagraph.number_of_edges() -
                             n_neg_selfloops),
        }
        return metrics

    def _print_previous_loop_statistics(infr, count):
        """
        Print summaries of the last ``count`` entries of
        ``infr.metrics_list``.

        Args:
            count (int): number of most recent metric snapshots to summarize.
                A count of 0 summarizes nothing (a plain ``[-0:]`` slice would
                select the entire history).
        """
        # Print stats about what happened in this loop
        history = infr.metrics_list[-count:] if count else []

        # Length of each consecutive run of decisions made while recovering.
        # The 'recovering' values come from infr.is_recovering() and are
        # presumably booleans.
        recover_blocks = [
            sum(1 for _ in run)
            for flag, run in it.groupby(
                util.take_column(history, 'recovering'))
            if flag
        ]
        infr.print((
            'Recovery mode entered {} times, '
            'made {} recovery decisions.').format(
                len(recover_blocks), sum(recover_blocks)), color='green')

        testaction_hist = ub.dict_hist(util.take_column(history, 'test_action'))
        infr.print(
            'Test Action Histogram: {}'.format(
                ub.urepr(testaction_hist, si=True)), color='yellow')

        if infr.params['inference.enabled']:
            action_hist = ub.dict_hist(
                util.emap(frozenset, util.take_column(history, 'action')))
            infr.print(
                'Inference Action Histogram: {}'.format(
                    ub.urepr(action_hist, si=True)), color='yellow')

        infr.print(
            'Decision Histogram: {}'.format(ub.urepr(ub.dict_hist(
                util.take_column(history, 'pred_decision')
            ), si=True)), color='yellow')

        infr.print(
            'User Histogram: {}'.format(ub.urepr(ub.dict_hist(
                util.take_column(history, 'user_id')
            ), si=True)), color='yellow')

    def _dynamic_test_callback(infr, edge, decision, prev_decision, user_id):
        """
        Update test-mode bookkeeping after ``edge`` receives ``decision``.

        Args:
            edge (tuple): the reviewed edge; must be a key of
                ``infr.edge_truth``.
            decision: the new decision.
            prev_decision: the edge's previous decision (UNREV if it was
                never reviewed).
            user_id (str): who decided; must start with ``'algo'`` or
                ``'user'``, or be ``'oracle'``.

        This updates the ground-truth positive graph and ``n_true_merges``,
        the confusion matrix ``test_state['confusion']`` (indexed
        ``[true][pred]``), ``infr.mistake_edges``, and the ``test_state``
        fields read by ``measure_metrics``.

        Raises:
            AssertionError: for an unrecognized ``user_id``.
        """
        was_gt_pos = infr.test_gt_pos_graph.has_edge(*edge)
        true_decision = infr.edge_truth[edge]

        was_within_pred = infr.pos_graph.are_nodes_connected(*edge)
        was_within_gt = infr.test_gt_pos_graph.are_nodes_connected(*edge)
        was_reviewed = prev_decision != UNREV
        is_within_gt = was_within_gt
        was_correct = prev_decision == true_decision
        is_correct = true_decision == decision

        def test_print(x, **kw):
            infr.print('[ACTION] ' + x, level=2, **kw)

        # Keep the graph of correctly-reviewed positive edges up to date.
        if decision == POSTV:
            if is_correct and not was_gt_pos:
                infr.test_gt_pos_graph.add_edge(*edge)
        elif was_gt_pos:
            test_print("UNDID GOOD POSITIVE EDGE", color='darkred')
            infr.test_gt_pos_graph.remove_edge(*edge)
            is_within_gt = infr.test_gt_pos_graph.are_nodes_connected(*edge)

        split_gt = is_within_gt != was_within_gt
        if split_gt:
            test_print("SPLIT A GOOD MERGE", color='darkred')
            infr.test_state['n_true_merges'] -= 1

        confusion = infr.test_state['confusion']
        if confusion is None:
            # initialize dynamic confusion matrix (rows: truth, cols: pred)
            states = (POSTV, NEGTV, INCMP, UNREV, UNKWN)
            confusion = {r: {c: 0 for c in states} for r in states}
            infr.test_state['confusion'] = confusion

        if was_reviewed:
            # the earlier review of this edge is superseded
            confusion[true_decision][prev_decision] -= 1
        confusion[true_decision][decision] += 1

        test_action = None
        action_color = None
        if is_correct:
            # CORRECT DECISION
            if was_reviewed:
                if prev_decision == decision:
                    test_action = 'correct duplicate'
                    action_color = 'darkyellow'
                else:
                    # prev differs from the (correct) decision, so it was a
                    # recorded mistake
                    infr.mistake_edges.remove(edge)
                    test_action = 'correction'
                    action_color = 'darkgreen'
                    if decision == POSTV and not was_within_gt:
                        test_action = 'correction redid merge'
                        infr.test_state['n_true_merges'] += 1
            elif decision == POSTV:
                if not was_within_gt:
                    test_action = 'correct merge'
                    action_color = 'darkgreen'
                    infr.test_state['n_true_merges'] += 1
                else:
                    test_action = 'correct redundant positive'
                    action_color = 'darkblue'
            elif decision == NEGTV:
                test_action = 'correct negative'
                action_color = 'teal'
            else:
                test_action = 'correct uninferrable'
                action_color = 'teal'
        else:
            # INCORRECT DECISION
            action_color = 'darkred'
            infr.mistake_edges.add(edge)
            if was_reviewed:
                if prev_decision == decision:
                    test_action = 'incorrect duplicate'
                elif was_correct:
                    test_action = 'incorrect undid good edge'
                else:
                    # One mistake replaced by a different mistake. This is
                    # reachable, e.g. an INCMP edge that the oracle wrongly
                    # answers POSTV and later NEGTV.
                    test_action = 'incorrect changed mistake'
            elif decision == POSTV:
                if was_within_pred:
                    test_action = 'incorrect redundant merge'
                else:
                    test_action = 'incorrect new merge'
            else:
                test_action = 'incorrect new mistake'

        infr.test_state['test_action'] = test_action
        infr.test_state['pred_decision'] = decision
        infr.test_state['true_decision'] = true_decision
        infr.test_state['user_id'] = user_id
        infr.test_state['recovering'] = (infr.recover_graph.has_node(edge[0]) or
                                         infr.recover_graph.has_node(edge[1]))

        infr.test_state['n_decision'] += 1
        if user_id.startswith('algo'):
            infr.test_state['n_algo'] += 1
        elif user_id.startswith('user') or user_id == 'oracle':
            infr.test_state['n_manual'] += 1
        else:
            raise AssertionError('unknown user_id=%r' % (user_id,))

        # Check before printing so a missing action fails with this message
        # rather than a TypeError from string concatenation.
        assert test_action is not None, 'what happened?'
        test_print(test_action, color=action_color)
class UserOracle:
    """
    Simulated reviewer that usually reports the ground-truth decision.

    Args:
        accuracy (float | tuple): probability of answering correctly. A tuple
            is ``(normal_accuracy, recover_accuracy)``; the second value is
            used while ``infr.is_recovering()`` is true.
        rng: seed or generator passed to ``util.ensure_rng(..., api='python')``.
            A string seed is first reduced to the sum of its character codes,
            so a run name can serve as a seed.
    """

    def __init__(oracle, accuracy, rng):
        if isinstance(rng, str):
            # hash-seed independent mapping from a name to an integer seed
            rng = sum(map(ord, rng))
        rng = util.ensure_rng(rng, api='python')
        if isinstance(accuracy, tuple):
            oracle.normal_accuracy = accuracy[0]
            oracle.recover_accuracy = accuracy[1]
        else:
            oracle.normal_accuracy = accuracy
            oracle.recover_accuracy = accuracy
        oracle.rng = rng
        oracle.states = {POSTV, NEGTV, INCMP}

    def review(oracle, edge, truth, infr, accuracy=None):
        """
        Produce a feedback dict for ``edge`` whose true state is ``truth``.

        Args:
            edge: the edge being reviewed (not used in the decision itself).
            truth: the ground-truth decision for the edge.
            infr: inference object; used for ``is_recovering()`` and
                ``print``.
            accuracy (float | None): overrides the oracle's own accuracy.

        Returns:
            dict: feedback whose ``'evidence_decision'`` is ``truth``, or,
            with probability ``1 - accuracy``, a different state. The wrong
            state is never INCMP.
        """
        timestamp = util.get_timestamp('int', isutc=True)
        feedback = {
            'user_id': 'user:oracle',
            'confidence': 'absolutely_sure',
            'evidence_decision': None,
            'meta_decision': NULL,
            'timestamp_s1': timestamp,
            'timestamp_c1': timestamp,
            'timestamp_c2': timestamp,
            'tags': [],
        }
        if accuracy is None:
            if infr.is_recovering():
                accuracy = oracle.recover_accuracy
            else:
                accuracy = oracle.normal_accuracy

        # The oracle can get anything where the hardness is less than its
        # accuracy
        hardness = oracle.rng.random()
        error = accuracy < hardness
        if error:
            # Sort so the pick depends only on the seeded rng. Set iteration
            # order of strings varies with PYTHONHASHSEED. (Assumes the state
            # constants are mutually orderable, e.g. strings.)
            error_options = sorted(oracle.states - {truth} - {INCMP})
            observed = oracle.rng.choice(error_options)
        else:
            observed = truth

        if accuracy < 1.0:
            feedback['confidence'] = 'pretty_sure'
        if accuracy < .5:
            feedback['confidence'] = 'guessing'
        feedback['evidence_decision'] = observed
        if error:
            infr.print(
                'ORACLE ERROR real={} pred={} acc={:.2f} hard={:.2f}'.format(
                    truth, observed, accuracy, hardness), 2, color='red')
        return feedback
if __name__ == '__main__':
    """
    CommandLine:
        python ~/code/graphid/graphid/core/mixin_simulation.py all
    """
    # Run this module's doctests via xdoctest.
    import xdoctest
    xdoctest.doctest_module(__file__)