# Source code for graphid.demo.demo_script

import itertools as it
import ubelt as ub
from graphid import util
from os.path import join


def _xlabel_fontsize(num_lines):
    """
    Pick the font size for the log text drawn underneath the graph.

    Args:
        num_lines (int): number of log lines that will be rendered.

    Returns:
        int: 12 for up to 10 lines, then 16, 14, 12, 10 and 8 once the line
        count exceeds 10, 12, 14, 18 and 23 respectively.
    """
    # Checked from the largest threshold down so the first hit wins; this
    # yields the same value as a cascade of ``if`` statements that each
    # overwrite the previous choice.
    for threshold, fontsize in ((23, 8), (18, 10), (14, 12), (12, 14), (10, 16)):
        if num_lines > threshold:
            return fontsize
    return 12


def run_demo():
    """
    Run a simulated review session on a small synthetic graph.

    Builds demo data with ``demo.demodata_infr(pcc_sizes=[5, 2, 4])``, clears
    its feedback, name labels and edges, and then replays reviews produced by
    ``infr._generate_reviews`` using the simulated oracle. The first reviewed
    edge whose endpoints share a positive-graph label is sent to the oracle
    with ``accuracy=0`` so that a mistake is made early in the session.

    Command line flags (read through ``ub.argflag`` / ``ub.argval``):
        --viz: draw the graph and save each frame as
            ``~/Desktop/demo/demo_NNNN.png``. Without it nothing is drawn.
        --labelsize: value for matplotlib's ``axes.labelsize`` (default 8).
        --target: review number to stop after; frames are drawn starting
            one review before it.
        --start: first review number for which frames are drawn.
        --end: review number at which the loop stops before reviewing.

    Note:
        The ``~/Desktop/demo`` directory is created even without ``--viz``.

    CommandLine:
        python -m graphid.demo.demo_script run_demo --viz
        python -m graphid.demo.demo_script run_demo

    Example:
        >>> run_demo()
    """
    from graphid import demo
    import matplotlib as mpl

    mpl.rcParams.update({
        'axes.titlesize': 12,
        'axes.labelsize': int(ub.argval('--labelsize', default=8)),
        'font.family': 'sans-serif',
        'font.serif': 'CMU Serif',
        'font.sans-serif': 'CMU Sans Serif',
        'font.monospace': 'CMU Typewriter Text',
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'legend.fontsize': 12,
        'legend.facecolor': 'w',
    })

    # ---- Synthetic data params
    params = {
        'redun.pos': 2,
        'redun.neg': 2,
    }
    oracle_accuracy = (.85, 1.0)

    # ---- Draw params
    VISUALIZE = ub.argflag('--viz')
    # When set to 'embed', ``infr.EMBEDME`` becomes True at the target review
    # and the final ``show_if_requested`` call is skipped.
    QUIT_OR_EMBED = 'quit'

    def asint(value):
        # Leave a missing (None) command line value untouched.
        return None if value is None else int(value)

    TARGET_REVIEW = asint(ub.argval('--target', default=None))
    START = asint(ub.argval('--start', default=None))
    END = asint(ub.argval('--end', default=None))

    # ---- Build the demo inference object
    infr = demo.demodata_infr(pcc_sizes=[5, 2, 4])
    infr.verbose = 100
    infr.ensure_cliques()
    infr.ensure_full()
    # Dummy scoring
    infr.init_simulation(oracle_accuracy=oracle_accuracy, name='run_demo')

    dpath = ub.ensuredir(ub.Path('~/Desktop/demo').expand())
    # Developer toggle: flip to True to wipe frames left by a previous run.
    CLEAR_OLD_FRAMES = False
    if CLEAR_OLD_FRAMES:
        ub.delete(dpath)
        ub.ensuredir(dpath)

    fig_counter = it.count(0)

    def show_graph(infr, title, final=False, selected_edges=None):
        """
        Draw the current graph with the latest log text and save the frame.

        Does nothing unless ``--viz`` was given. ``final`` is accepted for
        the callers' benefit but is not used by the drawing code.
        """
        if not VISUALIZE:
            return
        # Imported after the guard so pyplot is only loaded when drawing.
        from matplotlib import pyplot as plt

        latest = '\n'.join(infr.latest_logs())
        showkw = dict(
            show_unreviewed_edges=True,
            show_inferred_same=False,
            show_inferred_diff=False,
            outof=(len(infr.aids)),
            selected_edges=selected_edges,
            show_labels=True,
            simple_labels=True,
            show_recent_review=False,
            reposition=False,
        )
        # The previous version built ``infr.copy()`` here and rebound the
        # name to ``infr`` on the next statement, so the copy was never used.
        # Verbosity is still restored after drawing, as before.
        saved_verbose = infr.verbose
        infr.show(pickable=True, verbose=0, **showkw)
        infr.verbose = saved_verbose

        ax = plt.gca()
        ax.set_title(title, fontsize=20)
        fig = plt.gcf()

        # Shrink the log text as the number of lines grows.
        fontsize = _xlabel_fontsize(len(latest.split('\n')))

        util.mplutil.adjust_subplots(top=.95, left=0, right=1, bottom=.45,
                                     fig=fig)
        ax.set_xlabel('\n' + latest)
        xlabel = ax.get_xaxis().get_label()
        xlabel.set_horizontalalignment('left')
        xlabel.set_x(-2.0)
        xlabel.set_fontname('Inconsolata')
        xlabel.set_fontsize(fontsize)
        ax.set_aspect('equal')

        fpath = join(dpath, 'demo_{:04d}.png'.format(next(fig_counter)))
        fig.savefig(fpath, dpi=300, edgecolor='none')

        # Return value intentionally discarded.
        # NOTE(review): presumably this drains the log buffer so the next
        # frame only shows new messages -- confirm against latest_logs().
        infr.latest_logs()

    if VISUALIZE:
        infr.update_visual_attrs(groupby='name_label')
        infr.set_node_attrs('pin', 'true')
        print(ub.urepr(infr.graph.nodes[1]))
        infr.latest_logs()
        # Pin Nodes into the target groundtruth position
        show_graph(infr, 'target-gt')

    # Reset to an unreviewed state; the status is printed before and after.
    print(ub.urepr(infr.status()))
    infr.clear_feedback()
    infr.clear_name_labels()
    infr.clear_edges()
    print(ub.urepr(infr.status()))
    infr.latest_logs()

    if VISUALIZE:
        infr.update_visual_attrs()

    infr.prioritize('prob_match')
    # show_graph is a no-op without --viz, and with --viz the previous guard
    # (VISUALIZE or TARGET_REVIEW is None or TARGET_REVIEW == 0) was always
    # true, so the call is made unconditionally.
    show_graph(infr, 'initial state')

    def on_new_candidate_edges(infr, edges):
        # hack: updating visual attrs as a callback
        if VISUALIZE:
            infr.update_visual_attrs()

    infr.on_new_candidate_edges = on_new_candidate_edges

    infr.params.update(**params)
    infr.refresh_candidate_edges()

    VIZ_ALL = (VISUALIZE and TARGET_REVIEW is None and START is None)
    print('VIZ_ALL = %r' % (VIZ_ALL,))

    if VIZ_ALL or TARGET_REVIEW == 0:
        show_graph(infr, 'find-candidates')

    # ``count`` is advanced by hand (not enumerate) because its value after
    # the loop is reported in the final title and must not advance when the
    # loop exits through ``break``.
    count = 1
    # Number of deliberate oracle mistakes still to be injected.
    first = 1
    for edge, priority in infr._generate_reviews(data=True):
        msg = 'review #%d, priority=%.3f' % (count, priority)
        print('\n----------')
        infr.print('pop edge {} with priority={:.3f}'.format(edge, priority))

        if START is not None:
            VIZ_ALL = count >= START

        if END is not None and count >= END:
            break

        infr.print(msg)
        if ub.allsame(infr.pos_graph.node_labels(*edge)) and first:
            # Have oracle make a mistake early
            feedback = infr.request_oracle_review(edge, accuracy=0)
            first -= 1
        else:
            feedback = infr.request_oracle_review(edge)

        AT_TARGET = TARGET_REVIEW is not None and count >= TARGET_REVIEW - 1

        if VIZ_ALL or AT_TARGET:
            infr.print(ub.urepr(infr.task_probs['match_state'][edge],
                                precision=4, si=True))
            infr.print('len(queue) = %r' % (len(infr.queue)))
            # Show edge selection
            infr.print('Oracle will predict: ' + feedback['evidence_decision'])
            show_graph(infr, 'pre' + msg, selected_edges=[edge])

        if count == TARGET_REVIEW:
            infr.EMBEDME = QUIT_OR_EMBED == 'embed'
        infr.add_feedback(edge, **feedback)
        infr.print('len(queue) = %r' % (len(infr.queue)))

        # Show the result
        if VIZ_ALL or AT_TARGET:
            show_graph(infr, msg)

        if count == TARGET_REVIEW:
            break
        count += 1

    infr.print('status = ' + ub.urepr(infr.status(extended=False)))
    show_graph(infr, 'post-review (#reviews={})'.format(count), final=True)

    if VISUALIZE and not getattr(infr, 'EMBEDME', False):
        util.mplutil.show_if_requested()


if __name__ == '__main__':
    # Script entry point: pass this file to xdoctest's module runner.
    #
    # CommandLine:
    #     python -m graphid.demo.demo_script all
    import xdoctest

    xdoctest.doctest_module(__file__)