# Source code for graphid.core.mixin_simulation

"""
Mixin functionality for experiments, tests, and simulations.
This includes recording measures used to generate plots in JC's thesis.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import itertools as it
import ubelt as ub
import pandas as pd
from functools import partial
from graphid.core.state import (POSTV, NEGTV, INCMP, UNREV, UNKWN, NULL)
from graphid import util


class SimulationHelpers(object):
    """
    Mixin adding simulation / test-mode bookkeeping to an inference object.

    Methods name their first argument ``infr`` (instead of ``self``) to
    match the other graphid mixins.
    """

    def init_simulation(infr, oracle_accuracy=1.0, k_redun=2,
                        enable_autoreview=True, enable_inference=True,
                        classifiers=None, match_state_thresh=None,
                        max_outer_loops=None, name=None,
                        pb_state_thresh=None):
        """
        Configure ``infr`` to run a simulated review session with an oracle.

        Args:
            oracle_accuracy (float | tuple): accuracy given to ``UserOracle``;
                a tuple is ``(normal_accuracy, recover_accuracy)``.
            k_redun (int): used for both ``redun.pos`` and ``redun.neg``.
            enable_autoreview (bool): value of ``autoreview.enabled``.
            enable_inference (bool): value of ``inference.enabled``.
            classifiers: stored as ``infr.verifiers``.
            match_state_thresh (dict | None): thresholds for the
                ``match_state`` task. Defaults to 1.0 for POSTV, NEGTV and
                INCMP.
            max_outer_loops: stored as ``algo.max_outer_loops``.
            name (str | None): stored as ``infr.name``; also used to seed
                the oracle's random number generator.
            pb_state_thresh (dict | None): thresholds for the
                ``photobomb_state`` task. Defaults to
                ``{'pb': .5, 'notpb': .9}``.
        """
        infr.print('INIT SIMULATION', color='yellow')

        infr.name = name
        infr.simulation_mode = True

        infr.verifiers = classifiers
        infr.params['inference.enabled'] = enable_inference
        infr.params['autoreview.enabled'] = enable_autoreview

        infr.params['redun.pos'] = k_redun
        infr.params['redun.neg'] = k_redun

        # keeps track of edges where the decision != the groundtruth
        infr.mistake_edges = set()

        infr.queue = util.PriorityQueue()

        # Seeding from the name makes a named simulation repeatable.
        infr.oracle = UserOracle(oracle_accuracy, rng=infr.name)

        if match_state_thresh is None:
            match_state_thresh = {
                POSTV: 1.0,
                NEGTV: 1.0,
                INCMP: 1.0,
            }

        if pb_state_thresh is None:
            pb_state_thresh = {
                'pb': .5,
                'notpb': .9,
            }

        infr.task_thresh = {
            'photobomb_state': pd.Series(pb_state_thresh),
            'match_state': pd.Series(match_state_thresh)
        }
        infr.params['algo.max_outer_loops'] = max_outer_loops

    def init_test_mode(infr):
        """
        Reset the bookkeeping used to score decisions against the groundtruth.

        Creates the graph of groundtruth-correct positive edges and groups
        annotations by their original (true) name labels.
        """
        from graphid.core import nx_dynamic_graph
        infr.print('init_test_mode')
        infr.test_mode = True
        infr.metrics_list = []
        infr.test_state = {
            'n_decision': 0,
            'n_algo': 0,
            'n_manual': 0,
            'n_true_merges': 0,
            'n_error_edges': 0,
            'confusion': None,
        }
        # Only positive decisions that agree with the groundtruth are added
        # to this graph (by _dynamic_test_callback).
        infr.test_gt_pos_graph = nx_dynamic_graph.DynConnGraph()
        infr.test_gt_pos_graph.add_nodes_from(infr.aids)
        infr.nid_to_gt_cc = ub.group_items(infr.aids, infr.orig_name_labels)
        infr.node_truth = ub.dzip(infr.aids, infr.orig_name_labels)

        # A true component with n annotations needs n - 1 positive merges
        # (a spanning tree) to be fully connected.
        infr.real_n_pcc_mst_edges = sum(
            len(cc) - 1 for cc in infr.nid_to_gt_cc.values())
        infr.print('real_n_pcc_mst_edges = %r' % (
            infr.real_n_pcc_mst_edges,), color='red')

    def measure_error_edges(infr):
        """
        Yield reviewed edges whose decision disagrees with the groundtruth.

        Edges without an ``evidence_decision`` attribute count as UNREV and
        are skipped.

        Yields:
            tuple: ``(edge, error)`` where ``error`` is an ordered dict with
            ``'real'`` (the edge's ``truth`` attribute) and ``'pred'`` (its
            ``evidence_decision``).
        """
        for edge, data in infr.edges(data=True):
            true_state = data['truth']
            pred_state = data.get('evidence_decision', UNREV)
            if pred_state != UNREV:
                if true_state != pred_state:
                    error = ub.odict([('real', true_state),
                                      ('pred', pred_state)])
                    yield edge, error

    def measure_metrics(infr):
        """
        Summarize the current state of a test-mode simulation.

        Combines the counters in ``infr.test_state`` (maintained by
        ``_dynamic_test_callback``) with the current graph state.

        Returns:
            dict: metric name -> value.
        """
        test_state = infr.test_state
        n_true_merges = test_state['n_true_merges']
        confusion = test_state['confusion']

        if confusion is None:
            # No decision has been scored yet, so there are no errors.
            n_fn = n_fp = n_error_edges = 0
        else:
            # confusion[true][pred]: rows are groundtruth, columns predictions
            reviewed_cols = set(confusion.keys()) - {UNREV}
            non_postv = reviewed_cols - {POSTV}
            non_negtv = reviewed_cols - {NEGTV}
            n_fn = sum(confusion[POSTV][c] for c in non_postv)
            n_fp = sum(confusion[NEGTV][c] for c in non_negtv)
            # Every off-diagonal count between reviewed states is an error.
            n_error_edges = sum(
                confusion[r][c]
                for r in reviewed_cols for c in reviewed_cols if r != c)

        # Find all annotations involved in a mistake
        assert n_error_edges == len(infr.mistake_edges)
        direct_mistake_aids = {a for edge in infr.mistake_edges for a in edge}
        mistake_nids = set(infr.node_labels(*direct_mistake_aids))
        mistake_aids = {
            aid
            for nid in mistake_nids
            for aid in infr.pos_graph.component(nid)
        }

        pred_n_pcc_mst_edges = n_true_merges
        real_n_pcc_mst_edges = infr.real_n_pcc_mst_edges
        if real_n_pcc_mst_edges:
            pos_acc = pred_n_pcc_mst_edges / real_n_pcc_mst_edges
        else:
            # Every true name is a singleton: there is nothing left to merge.
            pos_acc = 1.0

        n_aids = len(infr.aids)
        frac_mistake_aids = len(mistake_aids) / n_aids if n_aids else 0.0

        # Graph.number_of_selfloops() was removed in networkx 2.4, so count
        # the self-loop edges directly.
        neg_metagraph = infr.neg_metagraph
        n_neg_selfloops = sum(1 for u, v in neg_metagraph.edges() if u == v)

        metrics = {
            'n_decision': test_state['n_decision'],
            'n_manual': test_state['n_manual'],
            'n_algo': test_state['n_algo'],
            'phase': infr.loop_phase,
            'pos_acc': pos_acc,
            'n_merge_total': real_n_pcc_mst_edges,
            'n_merge_remain': real_n_pcc_mst_edges - n_true_merges,
            'n_true_merges': n_true_merges,
            'recovering': infr.is_recovering(),
            'merge_remain': 1 - pos_acc,
            'n_mistake_aids': len(mistake_aids),
            'frac_mistake_aids': frac_mistake_aids,
            'n_mistake_nids': len(mistake_nids),
            'n_errors': n_error_edges,
            'n_fn': n_fn,
            'n_fp': n_fp,
            'refresh_support': len(infr.refresh.manual_decisions),
            'pprob_any': infr.refresh.prob_any_remain(),
            'mu': infr.refresh._ewma,
            # Per-decision fields are absent until the first decision.
            'test_action': test_state.get('test_action', None),
            'action': test_state.get('action', None),
            'user_id': test_state.get('user_id', None),
            'pred_decision': test_state.get('pred_decision', None),
            'true_decision': test_state.get('true_decision', None),
            'n_neg_redun': infr.neg_redun_metagraph.number_of_edges(),
            'n_neg_redun1': (neg_metagraph.number_of_edges() -
                             n_neg_selfloops),
        }
        return metrics

    def _print_previous_loop_statistics(infr, count):
        """
        Print summary histograms for the last ``count`` recorded metrics.

        Args:
            count (int): number of trailing ``infr.metrics_list`` entries
                to summarize. A non-positive count summarizes nothing (a
                plain ``[-0:]`` slice would return the whole history).
        """
        # Print stats about what happened in this loop
        history = infr.metrics_list[-count:] if count > 0 else []

        # Length of every consecutive run of decisions made while recovering
        recover_blocks = [
            sum(1 for _ in run)
            for flag, run in it.groupby(
                util.take_column(history, 'recovering'))
            if flag
        ]
        infr.print((
            'Recovery mode entered {} times, '
            'made {} recovery decisions.').format(
                len(recover_blocks), sum(recover_blocks)), color='green')

        testaction_hist = ub.dict_hist(util.take_column(history, 'test_action'))
        infr.print(
            'Test Action Histogram: {}'.format(
                ub.urepr(testaction_hist, si=True)), color='yellow')

        if infr.params['inference.enabled']:
            # measure_metrics records None when no action was set.
            actions = [
                frozenset() if action is None else frozenset(action)
                for action in util.take_column(history, 'action')
            ]
            action_hist = ub.dict_hist(actions)
            infr.print(
                'Inference Action Histogram: {}'.format(
                    ub.urepr(action_hist, si=True)), color='yellow')

        infr.print(
            'Decision Histogram: {}'.format(ub.urepr(ub.dict_hist(
                util.take_column(history, 'pred_decision')
            ), si=True)), color='yellow')

        infr.print(
            'User Histogram: {}'.format(ub.urepr(ub.dict_hist(
                util.take_column(history, 'user_id')
            ), si=True)), color='yellow')

    def _dynamic_test_callback(infr, edge, decision, prev_decision, user_id):
        """
        Score a new review decision against the groundtruth.

        Updates ``infr.test_gt_pos_graph``, ``infr.mistake_edges``, the
        running confusion matrix (``confusion[true][pred]``), and the
        counters and last-decision fields of ``infr.test_state``.

        Args:
            edge (tuple): the reviewed edge.
            decision: the new decision for ``edge``.
            prev_decision: the decision ``edge`` had before this review;
                UNREV when it was never reviewed.
            user_id (str): must start with ``'algo'`` or ``'user'``, or be
                exactly ``'oracle'``.

        Raises:
            AssertionError: if ``user_id`` is not recognized.
        """
        def test_print(x, **kw):
            infr.print('[ACTION] ' + x, level=2, **kw)

        was_gt_pos = infr.test_gt_pos_graph.has_edge(*edge)
        true_decision = infr.edge_truth[edge]
        was_within_pred = infr.pos_graph.are_nodes_connected(*edge)
        was_within_gt = infr.test_gt_pos_graph.are_nodes_connected(*edge)
        was_reviewed = prev_decision != UNREV
        is_within_gt = was_within_gt
        was_correct = prev_decision == true_decision
        is_correct = true_decision == decision

        # A correct positive decision joins the groundtruth-positive graph.
        # Any other decision on an edge that was a correct positive removes it.
        if decision == POSTV:
            if is_correct and not was_gt_pos:
                infr.test_gt_pos_graph.add_edge(*edge)
        elif was_gt_pos:
            test_print("UNDID GOOD POSITIVE EDGE", color='darkred')
            infr.test_gt_pos_graph.remove_edge(*edge)
            is_within_gt = infr.test_gt_pos_graph.are_nodes_connected(*edge)

        split_gt = is_within_gt != was_within_gt
        if split_gt:
            test_print("SPLIT A GOOD MERGE", color='darkred')
            infr.test_state['n_true_merges'] -= 1

        confusion = infr.test_state['confusion']
        if confusion is None:
            # initialize dynamic confusion matrix
            states = (POSTV, NEGTV, INCMP, UNREV, UNKWN)
            confusion = {r: {c: 0 for c in states} for r in states}
            infr.test_state['confusion'] = confusion

        if was_reviewed:
            # the edge moves out of its previously predicted column
            confusion[true_decision][prev_decision] -= 1
        confusion[true_decision][decision] += 1

        test_action = None
        action_color = None

        if is_correct:
            # CORRECT DECISION
            if was_reviewed:
                if prev_decision == decision:
                    test_action = 'correct duplicate'
                    action_color = 'darkyellow'
                else:
                    infr.mistake_edges.remove(edge)
                    test_action = 'correction'
                    action_color = 'darkgreen'
                    if decision == POSTV and not was_within_gt:
                        test_action = 'correction redid merge'
                        action_color = 'darkgreen'
                        infr.test_state['n_true_merges'] += 1
            elif decision == POSTV:
                if not was_within_gt:
                    test_action = 'correct merge'
                    action_color = 'darkgreen'
                    infr.test_state['n_true_merges'] += 1
                else:
                    test_action = 'correct redundant positive'
                    action_color = 'darkblue'
            elif decision == NEGTV:
                test_action = 'correct negative'
                action_color = 'teal'
            else:
                test_action = 'correct uninferrable'
                action_color = 'teal'
        else:
            # INCORRECT DECISION
            action_color = 'darkred'
            infr.mistake_edges.add(edge)
            if was_reviewed:
                if prev_decision == decision:
                    test_action = 'incorrect duplicate'
                elif was_correct:
                    test_action = 'incorrect undid good edge'
                else:
                    # A wrong decision was replaced by a different wrong one.
                    # Previously unhandled: test_action stayed None and
                    # test_print raised a TypeError.
                    test_action = 'incorrect changed mistake'
            elif decision == POSTV:
                if was_within_pred:
                    test_action = 'incorrect redundant merge'
                else:
                    test_action = 'incorrect new merge'
            else:
                test_action = 'incorrect new mistake'

        infr.test_state['test_action'] = test_action
        infr.test_state['pred_decision'] = decision
        infr.test_state['true_decision'] = true_decision
        infr.test_state['user_id'] = user_id
        infr.test_state['recovering'] = (
            infr.recover_graph.has_node(edge[0]) or
            infr.recover_graph.has_node(edge[1]))

        infr.test_state['n_decision'] += 1
        if user_id.startswith('algo'):
            infr.test_state['n_algo'] += 1
        elif user_id.startswith('user') or user_id == 'oracle':
            infr.test_state['n_manual'] += 1
        else:
            raise AssertionError('unknown user_id=%r' % (user_id,))

        # Sanity check before test_action is used as a string.
        assert test_action is not None, 'what happened?'
        test_print(test_action, color=action_color)
class UserOracle(object):
    """
    Simulated reviewer that usually answers with the groundtruth.

    On each review a uniform "hardness" value is drawn. The oracle answers
    incorrectly when its accuracy is below that hardness.
    """

    def __init__(oracle, accuracy, rng):
        """
        Args:
            accuracy (float | tuple): probability of answering correctly. A
                tuple is ``(normal_accuracy, recover_accuracy)``; the second
                value is used while ``infr.is_recovering()`` is true.
            rng: seed or random state passed to ``util.ensure_rng``. A str
                is first turned into an integer seed (the sum of its
                character codes), so a named simulation is repeatable.
        """
        if isinstance(rng, str):
            rng = sum(map(ord, rng))
        rng = util.ensure_rng(rng, api='python')
        if isinstance(accuracy, tuple):
            oracle.normal_accuracy = accuracy[0]
            oracle.recover_accuracy = accuracy[1]
        else:
            oracle.normal_accuracy = accuracy
            oracle.recover_accuracy = accuracy
        oracle.rng = rng
        oracle.states = {POSTV, NEGTV, INCMP}

    def review(oracle, edge, truth, infr, accuracy=None):
        """
        Produce simulated reviewer feedback for an edge.

        Args:
            edge: the edge being reviewed (not used by the oracle).
            truth: the groundtruth evidence decision for the edge.
            infr: inference object. ``is_recovering`` chooses the accuracy
                when none is given; ``print`` reports simulated errors.
            accuracy (float | None): overrides the configured accuracy.

        Returns:
            dict: feedback whose ``evidence_decision`` is ``truth``, or on a
            simulated error a different state other than INCMP.
        """
        feedback = {
            'user_id': 'user:oracle',
            'confidence': 'absolutely_sure',
            'evidence_decision': None,
            'meta_decision': NULL,
            'timestamp_s1': util.get_timestamp('int', isutc=True),
            'timestamp_c1': util.get_timestamp('int', isutc=True),
            'timestamp_c2': util.get_timestamp('int', isutc=True),
            'tags': [],
        }
        if accuracy is None:
            if infr.is_recovering():
                accuracy = oracle.recover_accuracy
            else:
                accuracy = oracle.normal_accuracy
        # The oracle can get anything where the hardness is less than its
        # accuracy
        hardness = oracle.rng.random()
        error = accuracy < hardness
        if error:
            # Sort so the choice depends only on the rng state. Set iteration
            # order of str values changes between processes because of hash
            # randomization, which made seeded runs non-reproducible.
            error_options = sorted(oracle.states - {truth} - {INCMP})
            observed = oracle.rng.choice(error_options)
        else:
            observed = truth
        if accuracy < 1.0:
            feedback['confidence'] = 'pretty_sure'
        if accuracy < .5:
            feedback['confidence'] = 'guessing'
        feedback['evidence_decision'] = observed
        if error:
            infr.print(
                'ORACLE ERROR real={} pred={} acc={:.2f} hard={:.2f}'.format(
                    truth, observed, accuracy, hardness), 2, color='red')
        return feedback
if __name__ == '__main__':
    """
    CommandLine:
        python ~/code/graphid/graphid/core/mixin_simulation.py all
    """
    # Run every doctest defined in this module.
    import xdoctest
    xdoctest.doctest_module(__file__)