aboutsummaryrefslogtreecommitdiff
path: root/analysis/R/fast_em.R
diff options
context:
space:
mode:
Diffstat (limited to 'analysis/R/fast_em.R')
-rwxr-xr-xanalysis/R/fast_em.R137
1 files changed, 137 insertions, 0 deletions
diff --git a/analysis/R/fast_em.R b/analysis/R/fast_em.R
new file mode 100755
index 0000000..c19862c
--- /dev/null
+++ b/analysis/R/fast_em.R
@@ -0,0 +1,137 @@
+# fast_em.R: Wrapper around analysis/cpp/fast_em.cc.
+#
+# This serializes the input, shells out, and deserializes the output.
+
+.Flatten <- function(list_of_matrices) {
+ listOfVectors <- lapply(list_of_matrices, as.vector)
+ #print(listOfVectors)
+
+ # unlist takes list to vector.
+ unlist(listOfVectors)
+}
+
+.WriteListOfMatrices <- function(list_of_matrices, f) {
+ flattened <- .Flatten(list_of_matrices)
+
+ # NOTE: UpdateJointConditional does outer product of dimensions!
+
+ # 3 letter strings are null terminated
+ writeBin('ne ', con = f)
+ num_entries <- length(list_of_matrices)
+ writeBin(num_entries, con = f)
+
+ Log('Wrote num_entries = %d', num_entries)
+
+ # For 2x3, this is 6
+ writeBin('es ', con = f)
+
+ entry_size <- as.integer(prod(dim(list_of_matrices[[1]])))
+ writeBin(entry_size, con = f)
+
+ Log('Wrote entry_size = %d', entry_size)
+
+ # now write the data
+ writeBin('dat', con = f)
+ writeBin(flattened, con = f)
+}
+
+.ExpectTag <- function(f, tag) {
+ # Read a single NUL-terminated character string.
+ actual <- readBin(con = f, what = "char", n = 1)
+
+ # Assert that we got what was expected.
+ if (length(actual) != 1) {
+ stop(sprintf("Failed to read a tag '%s'", tag))
+ }
+ if (actual != tag) {
+ stop(sprintf("Expected '%s', got '%s'", tag, actual))
+ }
+}
+
+.ReadResult <- function (f, entry_size, matrix_dims) {
+ .ExpectTag(f, "emi")
+ # NOTE: assuming R integers are 4 bytes (uint32_t)
+ num_em_iters <- readBin(con = f, what = "int", n = 1)
+
+ .ExpectTag(f, "pij")
+ pij <- readBin(con = f, what = "double", n = entry_size)
+
+ # Adjust dimensions
+ dim(pij) <- matrix_dims
+
+ Log("Number of EM iterations: %d", num_em_iters)
+ Log("PIJ read from external implementation:")
+ print(pij)
+
+ # est, sd, var_cov, hist
+ list(est = pij, num_em_iters = num_em_iters)
+}
+
+.SanityChecks <- function(joint_conditional) {
+ # Display some stats before sending it over to C++.
+
+ inf_counts <- lapply(joint_conditional, function(m) {
+ sum(m == Inf)
+ })
+ total_inf <- sum(as.numeric(inf_counts))
+
+ nan_counts <- lapply(joint_conditional, function(m) {
+ sum(is.nan(m))
+ })
+ total_nan <- sum(as.numeric(nan_counts))
+
+ zero_counts <- lapply(joint_conditional, function(m) {
+ sum(m == 0.0)
+ })
+ total_zero <- sum(as.numeric(zero_counts))
+
+ #sum(joint_conditional[joint_conditional == Inf, ])
+ Log('total inf: %s', total_inf)
+ Log('total nan: %s', total_nan)
+ Log('total zero: %s', total_zero)
+}
+
+ConstructFastEM <- function(em_executable, tmp_dir) {
+
+ return(function(joint_conditional, max_em_iters = 1000,
+ epsilon = 10 ^ -6, verbose = FALSE,
+ estimate_var = FALSE) {
+ matrix_dims <- dim(joint_conditional[[1]])
+ # Check that number of dimensions is 2.
+ if (length(matrix_dims) != 2) {
+ Log('FATAL: Expected 2 dimensions, got %d', length(matrix_dims))
+ stop()
+ }
+
+ entry_size <- prod(matrix_dims)
+ Log('entry size: %d', entry_size)
+
+ .SanityChecks(joint_conditional)
+
+ input_path <- file.path(tmp_dir, 'list_of_matrices.bin')
+ Log("Writing flattened list of matrices to %s", input_path)
+ f <- file(input_path, 'wb') # binary file
+ .WriteListOfMatrices(joint_conditional, f)
+ close(f)
+ Log("Done writing %s", input_path)
+
+ output_path <- file.path(tmp_dir, 'pij.bin')
+
+ cmd <- sprintf("%s %s %s %s", em_executable, input_path, output_path,
+ max_em_iters)
+
+ Log("Shell command: %s", cmd)
+ exit_code <- system(cmd)
+
+ Log("Done running shell command")
+ if (exit_code != 0) {
+ stop(sprintf("Command failed with code %d", exit_code))
+ }
+
+ f <- file(output_path, 'rb')
+ result <- .ReadResult(f, entry_size, matrix_dims)
+ close(f)
+
+ result
+ })
+}