aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKapileshwar Singh <kapileshwarsingh@gmail.com>2016-01-16 20:03:49 -0800
committerKapileshwar Singh <kapileshwarsingh@gmail.com>2016-01-16 20:03:49 -0800
commit79310bfbf0050de7330efbc84a5dc5d0834aa3eb (patch)
tree99f674f8aabb856489b6360a98b8d3073634d689
parenta5c1fb04cfdfc963ce6c26e2102d3e4793481d24 (diff)
parent5af9d234eb3445a36c34695c5d1b1bd8b88d5c6f (diff)
downloadbart-79310bfbf0050de7330efbc84a5dc5d0834aa3eb.tar.gz
Merge pull request #46 from JaviMerino/assert_statement_dataframe
assert when the parsed statement returns a dataframe of bools
-rw-r--r--bart/common/Analyzer.py11
-rw-r--r--tests/test_common_utils.py32
2 files changed, 38 insertions, 5 deletions
diff --git a/bart/common/Analyzer.py b/bart/common/Analyzer.py
index 51194d7..d9dc74d 100644
--- a/bart/common/Analyzer.py
+++ b/bart/common/Analyzer.py
@@ -22,6 +22,7 @@ implemented yet.
from trappy.stats.grammar import Parser
import warnings
import numpy as np
+import pandas as pd
# pylint: disable=invalid-name
@@ -56,12 +57,12 @@ class Analyzer(object):
result = self.getStatement(statement, select=select)
- # pylint: disable=no-member
- if not (isinstance(result, bool) or isinstance(result, np.bool_)):
- warnings.warn(
- "solution of {} is not an instance of bool".format(statement))
+ if isinstance(result, pd.DataFrame):
+ result = result.all().all()
+ elif not(isinstance(result, bool) or isinstance(result, np.bool_)): # pylint: disable=no-member
+ warnings.warn("solution of {} is not boolean".format(statement))
+
return result
- # pylint: enable=no-member
def getStatement(self, statement, reference=False, select=None):
"""Evaluate the statement"""
diff --git a/tests/test_common_utils.py b/tests/test_common_utils.py
index 56398be..09b31e3 100644
--- a/tests/test_common_utils.py
+++ b/tests/test_common_utils.py
@@ -14,8 +14,10 @@
#
from bart.common import Utils
+from bart.common.Analyzer import Analyzer
import unittest
import pandas as pd
+import trappy
class TestCommonUtils(unittest.TestCase):
@@ -96,3 +98,33 @@ class TestCommonUtils(unittest.TestCase):
method="rect",
step="pre"),
0)
+
+
+class TestAnalyzer(unittest.TestCase):
+
+ def test_assert_statement_bool(self):
+ """Check that asssertStatement() works with a simple boolean case"""
+
+ rolls_dfr = pd.DataFrame({"results": [1, 3, 2, 6, 2, 4]})
+ trace = trappy.BareTrace()
+ trace.add_parsed_event("dice_rolls", rolls_dfr)
+ config = {"MAX_DICE_NUMBER": 6}
+
+ t = Analyzer(trace, config)
+ statement = "numpy.max(dice_rolls:results) <= MAX_DICE_NUMBER"
+ self.assertTrue(t.assertStatement(statement, select=0))
+
+ def test_assert_statement_dataframe(self):
+ """assertStatement() works if the generated statement creates a pandas.DataFrame of bools"""
+
+ rolls_dfr = pd.DataFrame({"results": [1, 3, 2, 6, 2, 4]})
+ trace = trappy.BareTrace()
+ trace.add_parsed_event("dice_rolls", rolls_dfr)
+ config = {"MIN_DICE_NUMBER": 1, "MAX_DICE_NUMBER": 6}
+ t = Analyzer(trace, config)
+
+ statement = "(dice_rolls:results <= MAX_DICE_NUMBER) & (dice_rolls:results >= MIN_DICE_NUMBER)"
+ self.assertTrue(t.assertStatement(statement))
+
+ statement = "dice_rolls:results == 3"
+ self.assertFalse(t.assertStatement(statement))