"""
These check for certain invariants that should be maintained by the dynamic
data structure.
"""
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
import itertools as it
import networkx as nx
import ubelt as ub
from graphid.util.nx_utils import e_
# Module-level switch read by assert_consistency_invariant and
# assert_recovery_invariant: when it is falsy, both methods return
# immediately without inspecting the graph.
DEBUG_INCON = True
[docs]
class AssertInvariants(object):
[docs]
def assert_edge(infr, edge):
    """
    Check that the two endpoints of ``edge`` are in strictly ascending order.

    Args:
        edge: indexable pair; only ``edge[0]`` and ``edge[1]`` are compared.

    Raises:
        AssertionError: if ``edge[0] < edge[1]`` does not hold.
    """
    lower, upper = edge[0], edge[1]
    assert lower < upper, (
        'edge={} does not satisfy ordering constraint'.format(edge))
[docs]
def assert_invariants(infr, msg=''):
    """
    Run every invariant check on ``infr``, one after another.

    The disjoint, union, consistency and recovery checks are called in that
    order and each receives ``msg``. ``assert_neg_metagraph`` is called last
    and takes no arguments. Whatever a check raises propagates to the caller,
    so later checks are skipped once one fails.

    Args:
        msg (str): passed through unchanged to the checks that accept it.
    """
    checks_taking_msg = (
        'assert_disjoint_invariant',
        'assert_union_invariant',
        'assert_consistency_invariant',
        'assert_recovery_invariant',
    )
    for check_name in checks_taking_msg:
        # Look up and call one check at a time, in the listed order.
        getattr(infr, check_name)(msg)
    infr.assert_neg_metagraph()
[docs]
def assert_union_invariant(infr, msg=''):
    """
    Check that the review graphs together hold exactly the edges of
    ``infr.graph``.

    Every edge of every graph in ``infr.review_graphs`` is passed through
    ``e_`` (presumably to put its endpoints in a canonical order -- confirm
    against ``graphid.util.nx_utils``) and collected into one set. That set
    must equal the set built the same way from ``infr.graph``.

    Args:
        msg (str): optional context appended to the error message when it is
            non-empty.

    Raises:
        AssertionError: if the two edge sets differ. ``infr.status()`` is
            printed before raising.
    """
    # Accumulate into an initially empty set so that an empty
    # ``review_graphs`` is handled; ``set.union()`` called with no arguments
    # raises a TypeError instead of yielding an empty set.
    edge_union = set()
    for graph in infr.review_graphs.values():
        edge_union.update(it.starmap(e_, graph.edges()))

    all_edges = set(it.starmap(e_, infr.graph.edges()))
    if edge_union != all_edges:
        print('ERROR STATUS DUMP:')
        print(ub.urepr(infr.status()))
        error_text = (
            'edge sets must have full union. Found union=%d vs all=%d' % (
                len(edge_union), len(all_edges)
            ))
        if msg:
            # Surface the caller's context instead of silently dropping it.
            error_text = error_text + '. ' + msg
        raise AssertionError(error_text)
[docs]
def assert_disjoint_invariant(infr, msg=''):
    """
    Check that no edge appears in more than one review graph.

    The edges of each graph in ``infr.review_graphs`` are passed through
    ``e_`` and collected into one set per graph; every pair of those sets
    must be disjoint.

    Args:
        msg (str): accepted for signature parity with the other invariant
            checks; not used by this one.

    Raises:
        AssertionError: if any two review graphs share an edge.
    """
    per_graph_edges = [
        set(it.starmap(e_, review_graph.edges()))
        for _, review_graph in infr.review_graphs.items()
    ]
    for first, second in it.combinations(per_graph_edges, 2):
        assert first.isdisjoint(second), 'edge sets must be disjoint'
[docs]
def assert_consistency_invariant(infr, msg=''):
    """
    Check that ``infr`` reports no inconsistent components.

    The check is skipped when the module-level ``DEBUG_INCON`` flag is falsy
    or when ``infr.params['inference.enabled']`` is falsy.

    Args:
        msg (str): context appended to the error message.

    Raises:
        AssertionError: if ``infr.inconsistent_components()`` yields at least
            one component.
    """
    if not DEBUG_INCON:
        return
    if not infr.params['inference.enabled']:
        return
    inconsistent = list(infr.inconsistent_components())
    if inconsistent:
        raise AssertionError('The graph is not consistent. ' + msg)
[docs]
# Debug-only check concerning inconsistent components. In the lines below it
# only gathers data: nothing is asserted and `msg` is not read.
def assert_recovery_invariant(infr, msg=''):
# Checking is switched off via the module-level DEBUG_INCON flag.
if not DEBUG_INCON:
return
# infr.print('assert_recovery_invariant', 200)
# Materialize everything infr.inconsistent_components() yields.
inconsistent_ccs = list(infr.inconsistent_components())
# Collapse the components into one set of their members via ub.flatten.
# NOTE(review): the NOQA marker presumably silences an unused-variable
# warning -- confirm whether incon_cc is still needed.
incon_cc = set(ub.flatten(inconsistent_ccs)) # NOQA