aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-12-06 00:18:44 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2023-12-06 00:18:44 +0000
commit8d9bf3868e943e2fb7132b501a15cbc2b4280b14 (patch)
tree2801ad9997cc9de3c09893b9826eab4843217f36
parent634c38a28fc72363d9823c243f7e51f59d7ed83e (diff)
parentf55eb52017cb64e6fd0e4aba2dfd9334c0eb6887 (diff)
downloadprivate-join-and-compute-android14-mainline-os-statsd-release.tar.gz
Snap for 11181721 from f55eb52017cb64e6fd0e4aba2dfd9334c0eb6887 to mainline-os-statsd-releaseaml_sta_341615000aml_sta_341511040aml_sta_341410000android14-mainline-os-statsd-release
Change-Id: Ib5ae17e1d9f07fc8c9d08374b8267aeb48d270f5
-rw-r--r--Android.bp10
-rw-r--r--bazel/pjc_deps.bzl10
-rw-r--r--external/requirements.txt4
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD3
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc33
-rw-r--r--private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc35
-rw-r--r--private_join_and_compute/crypto/proto/BUILD4
-rw-r--r--private_join_and_compute/crypto/proto/proto_util.cc6
-rw-r--r--private_join_and_compute/crypto/proto/proto_util.h8
-rw-r--r--private_join_and_compute/crypto/proto/proto_util_test.cc22
-rw-r--r--private_join_and_compute/py/BUILD43
-rw-r--r--private_join_and_compute/py/README16
-rw-r--r--private_join_and_compute/py/__init__.py13
-rw-r--r--private_join_and_compute/py/ciphers/BUILD43
-rw-r--r--private_join_and_compute/py/ciphers/ec_cipher.py127
-rw-r--r--private_join_and_compute/py/ciphers/ec_cipher_test.py78
-rw-r--r--private_join_and_compute/py/crypto_util/BUILD104
-rw-r--r--private_join_and_compute/py/crypto_util/converters.py83
-rw-r--r--private_join_and_compute/py/crypto_util/converters_test.py70
-rw-r--r--private_join_and_compute/py/crypto_util/elliptic_curve.py390
-rw-r--r--private_join_and_compute/py/crypto_util/elliptic_curve_test.py122
-rw-r--r--private_join_and_compute/py/crypto_util/ssl_util.py1098
-rw-r--r--private_join_and_compute/py/crypto_util/ssl_util_test.py543
-rw-r--r--private_join_and_compute/py/crypto_util/supported_curves.py32
-rw-r--r--private_join_and_compute/py/crypto_util/supported_hashes.py37
25 files changed, 2881 insertions, 53 deletions
diff --git a/Android.bp b/Android.bp
index a70c383..34ce2e4 100644
--- a/Android.bp
+++ b/Android.bp
@@ -68,10 +68,12 @@ cc_library {
"private_join_and_compute/crypto/",
".",
],
+ include_dirs: [
+ "external/protobuf",
+ ],
shared_libs: [
"libcrypto",
"liblog",
-
],
static_libs: [
"libpjc_third_party_libabsl",
@@ -83,6 +85,9 @@ cc_library {
local_include_dirs: [
".",
],
+ include_dirs: [
+ "external/protobuf",
+ ],
},
sanitize: {
integer_overflow: true,
@@ -122,6 +127,9 @@ cc_test {
"libgmock",
"libpjc_third_party_libabsl",
],
+ include_dirs: [
+ "external/protobuf",
+ ],
test_suites: ["general-tests"],
target: {
host: {
diff --git a/bazel/pjc_deps.bzl b/bazel/pjc_deps.bzl
index 62a0e15..61c6fc5 100644
--- a/bazel/pjc_deps.bzl
+++ b/bazel/pjc_deps.bzl
@@ -59,3 +59,13 @@ def pjc_deps():
"https://github.com/protocolbuffers/protobuf/archive/f0dc78d7e6e331b8c6bb2d5283e06aa26883ca7c.tar.gz",
],
)
+
+
+ # Six (python compatibility)
+ if "six" not in native.existing_rules():
+ http_archive(
+ name = "six",
+ build_file = "@com_google_protobuf//:six.BUILD",
+ sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
+ url = "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz#md5=34eed507548117b2ab523ab14b2f8b55",
+ )
diff --git a/external/requirements.txt b/external/requirements.txt
new file mode 100644
index 0000000..2f321c8
--- /dev/null
+++ b/external/requirements.txt
@@ -0,0 +1,4 @@
+# repositories to install via Pip for compiling private-join-and-compute
+# python code externally
+six
+absl-py
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD
index ebf5965..11530fa 100644
--- a/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/BUILD
@@ -49,10 +49,7 @@ cc_library(
"//private_join_and_compute/crypto:bn_util",
"//private_join_and_compute/crypto:ec_util",
"//private_join_and_compute/crypto:pedersen_over_zn",
- "//private_join_and_compute/crypto/proto:big_num_cc_proto",
- "//private_join_and_compute/crypto/proto:ec_point_cc_proto",
"//private_join_and_compute/crypto/proto:proto_util",
- "//private_join_and_compute/util:status_includes",
"@com_google_absl//absl/strings",
"@com_google_protobuf//:protobuf_lite",
],
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc
index cd35f69..e2f1390 100644
--- a/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/bb_oblivious_signature.cc
@@ -26,6 +26,7 @@
#include <utility>
#include <vector>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include "absl/strings/str_cat.h"
#include "private_join_and_compute/crypto/big_num.h"
#include "private_join_and_compute/crypto/camenisch_shoup.h"
@@ -37,7 +38,6 @@
#include "private_join_and_compute/crypto/proto/camenisch_shoup.pb.h"
#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
#include "private_join_and_compute/crypto/proto/proto_util.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
namespace private_join_and_compute {
@@ -71,10 +71,9 @@ GenerateHomomorphicCsCiphertexts(
for (size_t i = 0; i < num_camenisch_shoup_ciphertexts; ++i) {
size_t batch_start_index =
i * public_camenisch_shoup->vector_encryption_length();
- size_t batch_size =
- std::min(
- public_camenisch_shoup->vector_encryption_length(),
- static_cast<uint64_t>(masked_messages.size() - batch_start_index));
+ size_t batch_size = std::min(
+ public_camenisch_shoup->vector_encryption_length(),
+ static_cast<uint64_t>(masked_messages.size() - batch_start_index));
size_t batch_end_index = batch_start_index + batch_size;
// Determine the messages for the i'th batch.
std::vector<BigNum> masked_messages_for_batch_i(
@@ -1513,18 +1512,10 @@ StatusOr<BigNum> BbObliviousSignature::GenerateRequestProofChallenge(
challenge_sos.get());
challenge_cos->SetSerializationDeterministic(true);
challenge_cos->WriteVarint64(proof_statement.ByteSizeLong());
- if (!proof_statement.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "BbObliviousSignature::GenerateRequestProofChallenge: Failed to "
- "serialize statement.");
- }
+ challenge_cos->WriteString(SerializeAsStringInOrder(proof_statement));
challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong());
- if (!proof_message_1.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "BbObliviousSignature::GenerateRequestProofChallenge: Failed to "
- "serialize proof_message_1.");
- }
+ challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1));
// Delete the CodedOutputStream and StringOutputStream to make sure they are
// cleaned up before hashing.
@@ -1571,18 +1562,10 @@ StatusOr<BigNum> BbObliviousSignature::GenerateResponseProofChallenge(
challenge_sos.get());
challenge_cos->SetSerializationDeterministic(true);
challenge_cos->WriteVarint64(statement.ByteSizeLong());
- if (!statement.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "BbObliviousSignature::GenerateResponseProofChallenge: Failed to "
- "serialize statement.");
- }
+ challenge_cos->WriteString(SerializeAsStringInOrder(statement));
challenge_cos->WriteVarint64(proof_message_1.ByteSizeLong());
- if (!proof_message_1.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "BbObliviousSignature::GenerateResponseProofChallenge: Failed to "
- "serialize proof_message_1.");
- }
+ challenge_cos->WriteString(SerializeAsStringInOrder(proof_message_1));
// Delete the CodedOutputStream and StringOutputStream to make sure they are
// cleaned up before hashing.
diff --git a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc
index e95564a..ed711c8 100644
--- a/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc
+++ b/private_join_and_compute/crypto/dodis_yampolskiy_prf/dy_verifiable_random_function.cc
@@ -28,8 +28,9 @@
#include "private_join_and_compute/crypto/ec_point.h"
#include "private_join_and_compute/crypto/pedersen_over_zn.h"
#include "private_join_and_compute/crypto/proto/proto_util.h"
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+#include "src/google/protobuf/io/coded_stream.h"
+#include "src/google/protobuf/io/zero_copy_stream_impl_lite.h"
namespace private_join_and_compute {
@@ -224,18 +225,12 @@ DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof(
challenge_sos.get());
challenge_cos->SetSerializationDeterministic(true);
challenge_cos->WriteVarint64(statement.ByteSizeLong());
- if (!statement.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof: "
- "Failed to serialize statement.");
- }
- challenge_cos->WriteVarint64(message_1.ByteSizeLong());
- if (!message_1.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "DyVerifiableRandomFunction::GenerateChallengeForGenerateKeysProof: "
- "Failed to serialize message_1.");
- }
+ challenge_cos->WriteString(SerializeAsStringInOrder(statement));
+
+ challenge_cos->WriteVarint64(message_1.ByteSizeLong());
+ challenge_cos->WriteString(SerializeAsStringInOrder(message_1));
+
BigNum challenge_bound =
context_->One().Lshift(parameters_proto_.challenge_length_bits());
@@ -552,17 +547,11 @@ StatusOr<BigNum> DyVerifiableRandomFunction::GenerateApplyProofChallenge(
challenge_sos.get());
challenge_cos->SetSerializationDeterministic(true);
challenge_cos->WriteVarint64(statement.ByteSizeLong());
- if (!statement.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "DyVerifiableRandomFunction::GenerateApplyProofChallenge: Failed to "
- "serialize statement.");
- }
+
+ challenge_cos->WriteString(SerializeAsStringInOrder(statement));
+
challenge_cos->WriteVarint64(message_1.ByteSizeLong());
- if (!message_1.SerializeToCodedStream(challenge_cos.get())) {
- return absl::InternalError(
- "DyVerifiableRandomFunction::GenerateApplyProofChallenge: Failed to "
- "serialize message_1.");
- }
+ challenge_cos->WriteString(SerializeAsStringInOrder(message_1));
BigNum challenge_bound =
context_->One().Lshift(parameters_proto_.challenge_length_bits());
diff --git a/private_join_and_compute/crypto/proto/BUILD b/private_join_and_compute/crypto/proto/BUILD
index 34797fb..e1bac60 100644
--- a/private_join_and_compute/crypto/proto/BUILD
+++ b/private_join_and_compute/crypto/proto/BUILD
@@ -71,6 +71,7 @@ cc_library(
"//private_join_and_compute/crypto:bn_util",
"//private_join_and_compute/crypto:ec_util",
"//private_join_and_compute/util:status_includes",
+ "@com_google_protobuf//:protobuf",
],
)
@@ -80,13 +81,14 @@ cc_test(
deps = [
":big_num_cc_proto",
":ec_point_cc_proto",
+ ":pedersen_cc_proto",
":proto_util",
"//private_join_and_compute/crypto:bn_util",
"//private_join_and_compute/crypto:ec_util",
"//private_join_and_compute/crypto:openssl_includes",
+ "//private_join_and_compute/crypto:pedersen_over_zn",
"//private_join_and_compute/util:status_includes",
"//private_join_and_compute/util:status_testing_includes",
"@com_github_google_googletest//:gtest_main",
- "@com_google_absl//absl/memory",
],
)
diff --git a/private_join_and_compute/crypto/proto/proto_util.cc b/private_join_and_compute/crypto/proto/proto_util.cc
index be368c3..f063412 100644
--- a/private_join_and_compute/crypto/proto/proto_util.cc
+++ b/private_join_and_compute/crypto/proto/proto_util.cc
@@ -15,6 +15,8 @@
#include "private_join_and_compute/crypto/proto/proto_util.h"
+
+#include <string>
#include <utility>
#include <vector>
@@ -75,4 +77,8 @@ StatusOr<std::vector<ECPoint>> ParseECPointVectorProto(
return std::move(ec_point_vector);
}
+std::string SerializeAsStringInOrder(const google::protobuf::MessageLite& proto) {
+ return proto.SerializeAsString();
+}
+
} // namespace private_join_and_compute
diff --git a/private_join_and_compute/crypto/proto/proto_util.h b/private_join_and_compute/crypto/proto/proto_util.h
index 6449897..4576403 100644
--- a/private_join_and_compute/crypto/proto/proto_util.h
+++ b/private_join_and_compute/crypto/proto/proto_util.h
@@ -16,6 +16,7 @@
#ifndef PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_
#define PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_
+#include <string>
#include <vector>
#include "private_join_and_compute/crypto/context.h"
@@ -23,6 +24,8 @@
#include "private_join_and_compute/crypto/proto/big_num.pb.h"
#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "src/google/protobuf/message_lite.h"
+
namespace private_join_and_compute {
// Converts a std::vector<BigNum> into a protocol buffer BigNumVector.
proto::BigNumVector BigNumVectorToProto(
@@ -41,6 +44,11 @@ StatusOr<std::vector<ECPoint>> ParseECPointVectorProto(
Context* context, ECGroup* ec_group,
const proto::ECPointVector& ec_point_vector_proto);
+// Serializes a proto to a string by serializing the fields in tag order. This
+// will guarantee deterministic encoding, as long as there are no cross-language
+// strings, and no unknown fields across different serializations.
+std::string SerializeAsStringInOrder(const google::protobuf::MessageLite& proto);
+
} // namespace private_join_and_compute
#endif // PRIVATE_JOIN_AND_COMPUTE_CRYPTO_PROTO_PROTO_UTIL_H_
diff --git a/private_join_and_compute/crypto/proto/proto_util_test.cc b/private_join_and_compute/crypto/proto/proto_util_test.cc
index 199db0b..2d4b381 100644
--- a/private_join_and_compute/crypto/proto/proto_util_test.cc
+++ b/private_join_and_compute/crypto/proto/proto_util_test.cc
@@ -19,6 +19,8 @@
#include <gtest/gtest.h>
#include <memory>
+
+#include <string>
#include <utility>
#include <vector>
@@ -26,8 +28,12 @@
#include "private_join_and_compute/crypto/ec_group.h"
#include "private_join_and_compute/crypto/ec_point.h"
#include "private_join_and_compute/crypto/openssl.inc"
+
+#include "private_join_and_compute/crypto/pedersen_over_zn.h"
#include "private_join_and_compute/crypto/proto/big_num.pb.h"
#include "private_join_and_compute/crypto/proto/ec_point.pb.h"
+#include "private_join_and_compute/crypto/proto/pedersen.pb.h"
+
#include "private_join_and_compute/util/status.inc"
#include "private_join_and_compute/util/status_testing.inc"
@@ -90,5 +96,21 @@ TEST(ProtoUtilTest, ParseEmptyECPointVector) {
EXPECT_EQ(empty_ec_point_vector, deserialized);
}
+TEST(ProtoUtilTest, SerializeAsStringInOrderIsConsistent) {
+ Context ctx;
+ std::vector<BigNum> big_num_vector = {ctx.One(), ctx.Two(), ctx.Three()};
+
+ proto::PedersenParameters pedersen_parameters_proto;
+ pedersen_parameters_proto.set_n(ctx.CreateBigNum(37).ToBytes());
+ *pedersen_parameters_proto.mutable_gs() = BigNumVectorToProto(big_num_vector);
+ pedersen_parameters_proto.set_h(ctx.CreateBigNum(4).ToBytes());
+
+ const std::string kExpectedSerialized =
+ "\n\x1%\x12\t\n\x1\x1\n\x1\x2\n\x1\x3\x1A\x1\x4";
+ std::string serialized = SerializeAsStringInOrder(pedersen_parameters_proto);
+
+ EXPECT_EQ(serialized, kExpectedSerialized);
+}
+
} // namespace
} // namespace private_join_and_compute
diff --git a/private_join_and_compute/py/BUILD b/private_join_and_compute/py/BUILD
new file mode 100644
index 0000000..59bebec
--- /dev/null
+++ b/private_join_and_compute/py/BUILD
@@ -0,0 +1,43 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:packaging.bzl", "py_package", "py_wheel")
+
+package(default_visibility = ["//visibility:public"])
+
+# Creates private_join_and_compute-0.0.1.whl
+py_wheel(
+ name = "private_join_and_compute_wheel",
+ classifiers = [
+ "License :: OSI Approved :: Apache Software License",
+ ],
+ description_file = "README",
+ # This should match the project name on PyPI. It's also the name that is used to refer to the
+ # package in other packages' dependencies.
+ distribution = "private_join_and_compute",
+ python_tag = "py3",
+ requires = [
+ "absl-py",
+ "six",
+ ],
+ version = "0.0.1",
+ deps = [
+ "//private_join_and_compute/py/ciphers:ec_cipher",
+ "//private_join_and_compute/py/crypto_util:converters",
+ "//private_join_and_compute/py/crypto_util:elliptic_curve",
+ "//private_join_and_compute/py/crypto_util:ssl_util",
+ "//private_join_and_compute/py/crypto_util:supported_curves",
+ "//private_join_and_compute/py/crypto_util:supported_hashes",
+ ],
+)
diff --git a/private_join_and_compute/py/README b/private_join_and_compute/py/README
new file mode 100644
index 0000000..2758e0f
--- /dev/null
+++ b/private_join_and_compute/py/README
@@ -0,0 +1,16 @@
+This library contains a python wrapper over OpenSSL/BoringSSL elliptic curves.
+
+Example Usage:
+
+::
+
+ from private_join_and_compute.py.ciphers import ec_cipher
+ from private_join_and_compute.py.crypto_util import supported_curves
+ from private_join_and_compute.py.crypto_util import supported_hashes
+
+ client_cipher = ec_cipher.EcCipher(
+ curve_id=supported_curves.SupportedCurve.SECP256R1.id,
+ hash_type=supported_hashes.HashType.SHA256,
+ private_key_bytes=None) # "None" generates a new key
+ encrypted_point = client_cipher.Encrypt(b"id_bytes")
+
diff --git a/private_join_and_compute/py/__init__.py b/private_join_and_compute/py/__init__.py
new file mode 100644
index 0000000..7489074
--- /dev/null
+++ b/private_join_and_compute/py/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/private_join_and_compute/py/ciphers/BUILD b/private_join_and_compute/py/ciphers/BUILD
new file mode 100644
index 0000000..1ff2d69
--- /dev/null
+++ b/private_join_and_compute/py/ciphers/BUILD
@@ -0,0 +1,43 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# Contains libraries for openssl big num operations.
+load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+load("@pip_deps//:requirements.bzl", "requirement")
+
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "ec_cipher",
+ srcs = [
+ "ec_cipher.py",
+ ],
+ deps = [
+ "//private_join_and_compute/py/crypto_util:elliptic_curve",
+ "//private_join_and_compute/py/crypto_util:supported_hashes",
+ ],
+)
+
+py_test(
+ name = "ec_cipher_test",
+ size = "small",
+ srcs = ["ec_cipher_test.py"],
+ deps = [
+ ":ec_cipher",
+ "//private_join_and_compute/py/crypto_util:supported_curves",
+ "//private_join_and_compute/py/crypto_util:supported_hashes",
+ ],
+)
diff --git a/private_join_and_compute/py/ciphers/ec_cipher.py b/private_join_and_compute/py/ciphers/ec_cipher.py
new file mode 100644
index 0000000..36ae8ec
--- /dev/null
+++ b/private_join_and_compute/py/ciphers/ec_cipher.py
@@ -0,0 +1,127 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""EC based commutative cipher."""
+
+from typing import Optional
+
+from private_join_and_compute.py.crypto_util import elliptic_curve
+from private_join_and_compute.py.crypto_util import supported_hashes
+
+NID_secp224r1 = 713 # pylint: disable=invalid-name
+DEFAULT_CURVE_ID = NID_secp224r1
+POINT_CONVERSION_COMPRESSED = 2
+
+
+class EcCipher(object):
+ """A commutative cipher based on Elliptic Curves."""
+
+ # key is an address.
+ def __init__(
+ self,
+ curve_id: int = DEFAULT_CURVE_ID,
+ private_key_bytes: Optional[bytes] = None,
+ hash_type: Optional[supported_hashes.HashType] = None,
+ ) -> None:
+ """Generate a new EC key pair, if the key is not passed as a parameter.
+
+ The private key is a random value and the private point is the result of
+ performing a scalar point multiplication of that value with the curve's
+ base point.
+
+ Args:
+ curve_id: the id of the curve to use, given as an int value.
+ private_key_bytes: an ec key in bytes, if the key has already been
+ generated.
+ hash_type: the hash to use in order to map a string to the elliptic curve.
+
+ Raises:
+ TypeError: If curve_id is not an int.
+ Exception: If the key could not be generated.
+ """
+ self._ec_key = elliptic_curve.ECKey(curve_id, private_key_bytes, hash_type)
+
+ def Encrypt(self, id_bytes: bytes) -> bytes:
+ """Hashes the client id to a point on the curve.
+
+ It then encrypts the point by multiplying it with the private key.
+
+ Args:
+ id_bytes: a client id encoded as a string/byte value.
+
+ Returns:
+ the compressed encoded EC Point in bytes.
+
+ Raises:
+ TypeError: If id_bytes is not a str type.
+ """
+ ec_point = self._ec_key.elliptic_curve.GetPointByHashingToCurve(id_bytes)
+ return self.EncryptPoint(ec_point)
+
+ def EncryptPoint(self, ec_point) -> bytes:
+ """Encrypts a point on the curve.
+
+ Args:
+ ec_point: the point to encrypt.
+
+ Returns:
+ the compressed encoded encrypted point in bytes
+ """
+ ec_point *= self._ec_key.priv_key_bn
+ return ec_point.GetAsBytes()
+
+ def ReEncrypt(self, enc_id_bytes: bytes) -> bytes:
+ """Re-encrypts the id by multiplying with the private key.
+
+ Args:
+ enc_id_bytes: an encrypted client id as a bytes value.
+
+ Returns:
+ the compressed encoded re-encrypted EC Point in bytes.
+
+ Raises:
+ TypeError: If enc_id_bytes id is not a str type.
+ """
+ ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(enc_id_bytes)
+ return self.EncryptPoint(ec_point)
+
+ @property
+ def ec_key(self):
+ return self._ec_key
+
+ @property
+ def elliptic_curve(self):
+ return self._ec_key.elliptic_curve
+
+ def DecryptReEncryptedId(self, reenc_id_bytes: bytes) -> bytes:
+ """Decrypts a reencrypted id to its encrypted id form.
+
+ Assuming reenc_id_bytes=E_k1(E_k2(m)) where E(.) is the ec_cipher and k1/k2
+ are private keys. This function with decryption key, k1', returns E_k2(m) or
+ with decryption key, k2', E_k1(m). Essentially this removes one layer of
+ encryption from the reenc_id_bytes.
+
+ This function *cannot* be applied to encrypted ids as the return value would
+ be the message one-way hashed to a point on the curve.
+
+ Args:
+ reenc_id_bytes: a reencrypted client id, encoded with a key and then
+ reencoded with another key.
+
+ Returns:
+ An encoded id in bytes.
+ """
+ ec_point = self._ec_key.elliptic_curve.GetPointFromBytes(reenc_id_bytes)
+ ec_point *= self._ec_key.decrypt_key_bignum
+ return ec_point.GetAsBytes()
diff --git a/private_join_and_compute/py/ciphers/ec_cipher_test.py b/private_join_and_compute/py/ciphers/ec_cipher_test.py
new file mode 100644
index 0000000..5bcf082
--- /dev/null
+++ b/private_join_and_compute/py/ciphers/ec_cipher_test.py
@@ -0,0 +1,78 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test class for EcCommutativeCipher."""
+
+import unittest
+from private_join_and_compute.py.ciphers import ec_cipher
+from private_join_and_compute.py.crypto_util import supported_curves
+from private_join_and_compute.py.crypto_util import supported_hashes
+
+
+class EcCommutativeCipherTest(unittest.TestCase):
+
+ def setUp(self):
+ super(EcCommutativeCipherTest, self).setUp()
+ self.client_cipher = ec_cipher.EcCipher(713)
+ self.server_cipher = ec_cipher.EcCipher(713)
+
+ def ReEncryptionSameId(self, cipher1, cipher2):
+ user_id = b'3274646578436540569872403985702934875092834502'
+ enc_id1 = cipher1.Encrypt(user_id)
+ enc_id2 = cipher2.Encrypt(user_id)
+ result1 = cipher2.ReEncrypt(enc_id1)
+ result2 = cipher1.ReEncrypt(enc_id2)
+ self.assertEqual(result1, result2)
+
+ def testReEncryptionSameId(self):
+ self.ReEncryptionSameId(self.client_cipher, self.server_cipher)
+
+ def testReEncryptionDifferentId(self):
+ user_id1 = b'3274646578436540569872403985702934875092834502'
+ user_id2 = b'7402039857096829483572943875209348524958235824'
+ enc_id1 = self.client_cipher.Encrypt(user_id1)
+ enc_id2 = self.server_cipher.Encrypt(user_id2)
+ result1 = self.server_cipher.ReEncrypt(enc_id1)
+ result2 = self.client_cipher.ReEncrypt(enc_id2)
+ self.assertNotEqual(result1, result2)
+
+ def testDecode(self):
+ user_id = b'7402039857096829483572943875209348524958235824'
+ enc_id1 = self.client_cipher.Encrypt(user_id)
+ enc_id2 = self.server_cipher.Encrypt(user_id)
+ result1 = self.server_cipher.ReEncrypt(enc_id1)
+ actual_enc_id1 = self.client_cipher.DecryptReEncryptedId(result1)
+ actual_enc_id2 = self.server_cipher.DecryptReEncryptedId(result1)
+ self.assertEqual(enc_id1, actual_enc_id2)
+ self.assertEqual(enc_id2, actual_enc_id1)
+
+ def testDifferentHashFunctions(self):
+ # freshly sampled key
+ sha256_cipher = ec_cipher.EcCipher(
+ curve_id=supported_curves.SupportedCurve.SECP256R1.id,
+ hash_type=supported_hashes.HashType.SHA256,
+ )
+ sha512_cipher = ec_cipher.EcCipher(
+ curve_id=supported_curves.SupportedCurve.SECP256R1.id,
+ hash_type=supported_hashes.HashType.SHA512,
+ private_key_bytes=sha256_cipher.ec_key.priv_key_bytes,
+ )
+ user_id = b'7402039857096829483572943875209348524958235824'
+ enc_id1 = sha256_cipher.Encrypt(user_id)
+ enc_id2 = sha512_cipher.Encrypt(user_id)
+ self.assertNotEqual(enc_id1, enc_id2)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/BUILD b/private_join_and_compute/py/crypto_util/BUILD
new file mode 100644
index 0000000..a015e35
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/BUILD
@@ -0,0 +1,104 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Description:
+# Contains libraries for openssl big num operations.
+
+load("@rules_python//python:defs.bzl", "py_library", "py_test")
+load("@pip_deps//:requirements.bzl", "requirement")
+
+package(default_visibility = ["//visibility:public"])
+
+py_library(
+ name = "converters",
+ srcs = [
+ "converters.py",
+ ],
+ deps = [
+ requirement("six"),
+ ],
+)
+
+py_test(
+ name = "converters_test",
+ size = "small",
+ srcs = ["converters_test.py"],
+ deps = [
+ ":converters",
+ ],
+)
+
+py_library(
+ name = "ssl_util",
+ srcs = [
+ "ssl_util.py",
+ ],
+ deps = [
+ ":converters",
+ ":supported_hashes",
+ requirement("six"),
+ requirement("absl-py"),
+ ],
+)
+
+py_library(
+ name = "supported_curves",
+ srcs = [
+ "supported_curves.py",
+ ],
+)
+
+py_library(
+ name = "supported_hashes",
+ srcs = [
+ "supported_hashes.py",
+ ],
+)
+
+py_test(
+ name = "ssl_util_test",
+ size = "small",
+ srcs = ["ssl_util_test.py"],
+ deps = [
+ ":ssl_util",
+ requirement("absl-py"),
+ ],
+)
+
+py_library(
+ name = "elliptic_curve",
+ srcs = [
+ "elliptic_curve.py",
+ ],
+ deps = [
+ ":converters",
+ ":ssl_util",
+ ":supported_curves",
+ ":supported_hashes",
+ requirement("six"),
+ ],
+)
+
+py_test(
+ name = "elliptic_curve_test",
+ size = "small",
+ srcs = ["elliptic_curve_test.py"],
+ deps = [
+ ":converters",
+ ":elliptic_curve",
+ ":ssl_util",
+ ":supported_curves",
+ ":supported_hashes",
+ ],
+)
diff --git a/private_join_and_compute/py/crypto_util/converters.py b/private_join_and_compute/py/crypto_util/converters.py
new file mode 100644
index 0000000..02fe28f
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/converters.py
@@ -0,0 +1,83 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Module providing conversion functions like long to bytes or bytes to long."""
+
+import operator
+import struct
+
+import six
+
+
+def _PadZeroBytes(byte_str, blocksize):
+ """Pads the front of byte_str with binary zeros.
+
+ Args:
+ byte_str: byte string to pad the binary zeros.
+ blocksize: the byte_str will be padded so that the length of the output will
+ be a multiple of blocksize.
+
+ Returns:
+ a new byte string padded with binary zeros if necessary.
+ """
+ if len(byte_str) % blocksize:
+ return (blocksize - len(byte_str) % blocksize) * b'\000' + byte_str
+ return byte_str
+
+
+def LongToBytes(number: int, blocksize: int = 0) -> bytes:
+ """Converts an arbitrary length number to a byte string.
+
+ Args:
+ number: number to convert to bytes.
+ blocksize: if specified, the output bytes length will be a multiple of
+ blocksize.
+
+ Returns:
+ byte string for the number.
+
+ Raises:
+ ValueError: when the number is negative.
+ """
+ if number < 0:
+ raise ValueError('number needs to be >=0, given: {}'.format(number))
+ number_32bitunit_components = []
+ while number != 0:
+ number_32bitunit_components.insert(0, number & 0xFFFFFFFF)
+ number >>= 32
+ converter = struct.Struct('>' + str(len(number_32bitunit_components)) + 'I')
+ n_bytes = six.ensure_binary(converter.pack(*number_32bitunit_components))
+ for idx in range(len(n_bytes)):
+ if operator.getitem(n_bytes, idx) != 0:
+ break
+ else:
+ n_bytes = b'\000'
+ idx = 0
+ n_bytes = n_bytes[idx:]
+ if blocksize > 0:
+ n_bytes = _PadZeroBytes(n_bytes, blocksize)
+ return six.ensure_binary(n_bytes)
+
+
+def BytesToLong(byte_string: bytes) -> int:
+ """Converts given byte string to a long."""
+ result = 0
+ padded_byte_str = _PadZeroBytes(byte_string, 4)
+ component_length = len(padded_byte_str) // 4
+ converter = struct.Struct('>' + str(component_length) + 'I')
+ unpacked_data = converter.unpack(padded_byte_str)
+ for i in range(0, component_length):
+ result += unpacked_data[i] << (32 * (component_length - i - 1))
+ return result
diff --git a/private_join_and_compute/py/crypto_util/converters_test.py b/private_join_and_compute/py/crypto_util/converters_test.py
new file mode 100644
index 0000000..3722ab3
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/converters_test.py
@@ -0,0 +1,70 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Test class for Convertors."""
+
+import unittest
+
+from private_join_and_compute.py.crypto_util import converters
+
+
+class ConvertorsTest(unittest.TestCase):
+
+ def testLongToBytes(self):
+ bytes_n = converters.LongToBytes(5)
+ self.assertEqual(b'\005', bytes_n)
+
+ def testZeroToBytes(self):
+ bytes_n = converters.LongToBytes(0)
+ self.assertEqual(b'\000', bytes_n)
+
+ def testLongToBytesForBigNum(self):
+ bytes_n = converters.LongToBytes(2**72 - 1)
+ self.assertEqual(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff', bytes_n)
+
+ def testBytesToLong(self):
+ number = converters.BytesToLong(b'\005')
+ self.assertEqual(5, number)
+
+ def testBytesToLongForBigNum(self):
+ number = converters.BytesToLong(b'\xff\xff\xff\xff\xff\xff\xff\xff\xff')
+ self.assertEqual(2**72 - 1, number)
+
+ def testLongToBytesCompatibleWithBytesToLong(self):
+ long_num = 4239423984023840823047823975923401283971204812394723040127401238
+ self.assertEqual(
+ long_num, converters.BytesToLong(converters.LongToBytes(long_num))
+ )
+
+ def testLongToBytesWithPadding(self):
+ bytes_n = converters.LongToBytes(5, 6)
+ self.assertEqual(b'\000\000\000\000\000\005', bytes_n)
+
+ def testBytesToLongWithPadding(self):
+ number = converters.BytesToLong(b'\000\000\000\000\000\005')
+ self.assertEqual(5, number)
+
+ def testLongToBytesCompatibleWithBytesToLongWithPadding(self):
+ long_num = 4239423984023840823047823975923401283971204812394723040127401238
+ self.assertEqual(
+ long_num, converters.BytesToLong(converters.LongToBytes(long_num, 51))
+ )
+
+ def testLongToBytesRaisesValueErrorForNegativeNumbers(self):
+ self.assertRaises(ValueError, converters.LongToBytes, -1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve.py b/private_join_and_compute/py/crypto_util/elliptic_curve.py
new file mode 100644
index 0000000..6d02670
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/elliptic_curve.py
@@ -0,0 +1,390 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module for elliptic curve related classes."""
+
+import ctypes
+from typing import Optional, Union
+
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util import ssl_util
+from private_join_and_compute.py.crypto_util.ssl_util import BigNum
+from private_join_and_compute.py.crypto_util.ssl_util import OpenSSLHelper
+from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
+from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve
+from private_join_and_compute.py.crypto_util.supported_hashes import HashType
+import six
+
+POINT_CONVERSION_COMPRESSED = 2
+
+
+class ECPoint(object):
+ """The ECPoint class."""
+
+ def __init__(self, group, ec_point_bn):
+ self._group = group
+ self._point = ec_point_bn
+ self.ctx = OpenSSLHelper().ctx
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl_util.ssl
+
+ @classmethod
+ def FromPoint(cls, group: int, x: int, y: int):
+ """Creates an EC_POINT object with the given x, y affine coordinates.
+
+ Args:
+ group: the EC_GROUP for the given point's elliptic curve
+ x: the x coordinate of the point as long value
+ y: the y coordinate of the point as long value
+
+ Returns:
+ <x, y> ECPoint on the elliptic curve defined by group
+
+ Raises:
+ TypeError: If the x, y coordinates are not of type long.
+ """
+ ec_point = cls._EmptyPoint(group)
+ with TempBNs(x=x, y=y) as bn:
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_set_affine_coordinates_GFp(
+ group, ec_point._point, bn.x, bn.y, None
+ )
+ # pylint: enable=protected-access
+ ec_point.CheckValidity()
+ return ec_point
+
+ @classmethod
+ def FromLongOrBytes(cls, group: int, point_long_or_bytes: Union[int, bytes]):
+ """Creates an EC_POINT object from its serialized bytes representation.
+
+ Args:
+ group: the EC_GROUP for the point's elliptic curve.
+ point_long_or_bytes: the serialized bytes representations of the point.
+
+ Returns:
+ The point encoded by point_long_or_bytes
+
+ Raises:
+ ValueError: if point_long_or_bytes is not a valid encoding of a point
+ from the EC group.
+ """
+ ec_point = cls._EmptyPoint(group)
+ if isinstance(point_long_or_bytes, int):
+ point_long_or_bytes = converters.LongToBytes(point_long_or_bytes)
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_oct2point(
+ group,
+ ec_point._point,
+ point_long_or_bytes,
+ len(point_long_or_bytes),
+ None,
+ )
+ # pylint: enable=protected-access
+ ec_point.CheckValidity()
+ return ec_point
+
+ @classmethod
+ def GetPointAtInfinity(cls, group):
+ p = ssl_util.ssl.EC_POINT_new(group)
+ ssl_util.ssl.EC_POINT_set_to_infinity(group, p)
+ return ECPoint(group, p)
+
+ @classmethod
+ def _EmptyPoint(cls, group):
+ return ECPoint(group, ssl_util.ssl.EC_POINT_new(group))
+
+ def __del__(self):
+ self.ssl.EC_POINT_free(self._point)
+
+ def CheckValidity(self):
+ """Checks if this point is valid and can be multiplied with the key.
+
+ If the point is corrupted as a result of a faulty computation, this might
+ leak data about the key.
+
+ Raises:
+ ValueError: If the point is not on the curve or if the point is the
+ neutral element.
+ """
+ if not self.IsOnCurve():
+ raise ValueError('The point is not on the curve.')
+
+ if self.IsAtInfinity():
+ raise ValueError('The point is the neutral element.')
+
+ def __mul__(self, scalar):
+ new_ec_point = self._EmptyPoint(self._group)
+ # pylint: disable=protected-access
+ if isinstance(scalar, BigNum):
+ ssl_util.ssl.EC_POINT_mul(
+ self._group,
+ new_ec_point._point,
+ None,
+ self._point,
+ scalar._bn_num,
+ self.ctx,
+ )
+ else:
+ ssl_util.ssl.EC_POINT_mul(
+ self._group, new_ec_point._point, None, self._point, scalar, self.ctx
+ )
+ # pylint: enable=protected-access
+ return new_ec_point
+
+ def __imul__(self, scalar):
+ if isinstance(scalar, BigNum):
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_mul(
+ self._group, self._point, None, self._point, scalar._bn_num, self.ctx
+ )
+ # pylint: enable=protected-access
+ else:
+ ssl_util.ssl.EC_POINT_mul(
+ self._group, self._point, None, self._point, scalar, self.ctx
+ )
+ return self
+
+ def __add__(self, ec_point):
+ new_ec_point = self._EmptyPoint(self._group)
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_add(
+ self._group, new_ec_point._point, self._point, ec_point._point, self.ctx
+ )
+ # pylint: enable=protected-access
+ return new_ec_point
+
+ def __iadd__(self, ec_point):
+ # pylint: disable=protected-access
+ ssl_util.ssl.EC_POINT_add(
+ self._group, self._point, self._point, ec_point._point, self.ctx
+ )
+ # pylint: enable=protected-access
+ return self
+
+ def IsOnCurve(self) -> bool:
+ return 1 == ssl_util.ssl.EC_POINT_is_on_curve(
+ self._group, self._point, None
+ )
+
+ def IsAtInfinity(self) -> bool:
+ return 1 == ssl_util.ssl.EC_POINT_is_at_infinity(self._group, self._point)
+
+ def GetAsLong(self) -> int:
+ return converters.BytesToLong(self.GetAsBytes())
+
+ def GetAsBytes(self) -> bytes:
+ buf_len = ssl_util.ssl.EC_POINT_point2oct(
+ self._group, self._point, POINT_CONVERSION_COMPRESSED, None, 0, None
+ )
+ buf = ctypes.create_string_buffer(buf_len)
+ ssl_util.ssl.EC_POINT_point2oct(
+ self._group,
+ self._point,
+ POINT_CONVERSION_COMPRESSED,
+ buf,
+ buf_len,
+ None,
+ )
+ return six.ensure_binary(buf.raw)
+
+ def __eq__(self, other: 'ECPoint'):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return 0 == ssl_util.ssl.EC_POINT_cmp(
+ self._group, self._point, other._point, self.ctx
+ )
+ raise ValueError('Cannot compare ECPoint with type {}'.format(type(other)))
+ # pylint: enable=protected-access
+
+ def __ne__(self, other: 'ECPoint'):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return str(self.GetAsLong())
+
+
+class EllipticCurve(object):
+ """Class for representing the elliptic curve."""
+
+ def __init__(
+ self,
+ curve_id: Union[int, SupportedCurve],
+ hash_type: Optional[HashType] = None,
+ ):
+ if isinstance(curve_id, SupportedCurve):
+ curve_id = curve_id.id
+ if hash_type is None:
+ hash_type = HashType.SHA512
+ self._hash_type = hash_type
+ self._group = ssl_util.ssl.EC_GROUP_new_by_curve_name(curve_id)
+ with TempBNs(p=None, a=None, b=None, order=None) as bn:
+ ssl_util.ssl.EC_GROUP_get_curve_GFp(self._group, bn.p, bn.a, bn.b, None)
+ ssl_util.ssl.EC_GROUP_get_order(
+ self._group, bn.order, OpenSSLHelper().ctx
+ )
+ self._order = ssl_util.BnToLong(bn.order)
+ self._p = ssl_util.BnToLong(bn.p)
+ self._p_bn = BigNum.FromLongNumber(self._p)
+ if not self._p_bn.IsPrime():
+ raise ValueError(
+ 'Wrong curve parameters: p must be a prime. p: {}'.format(self._p)
+ )
+ self._a = ssl_util.BnToLong(bn.a)
+ self._b = ssl_util.BnToLong(bn.b)
+ self._p_sub_one_div_by_two = (self._p - 1) >> 1
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl_util.ssl
+
+ def __del__(self):
+ self.ssl.EC_GROUP_free(self._group)
+
+ def GetPointByHashingToCurve(self, m: Union[int, bytes]) -> ECPoint:
+ """Hashes m into the elliptic curve."""
+ return ECPoint.FromPoint(self.group, *self.HashToCurve(m))
+
+ def GetPointFromLong(self, m_long: int) -> ECPoint:
+ """Converts the given compressed point (m_long) into ECPoint."""
+ return ECPoint.FromLongOrBytes(self.group, m_long)
+
+ def GetPointFromBytes(self, m_bytes: bytes) -> ECPoint:
+ """Converts the given compressed point (m_bytes) into ECPoint."""
+ return ECPoint.FromLongOrBytes(self.group, m_bytes)
+
+ def GetPointAtInfinity(self) -> ECPoint:
+ """Gets a point at the infinity."""
+ return ECPoint.GetPointAtInfinity(self.group)
+
+ def GetRandomGenerator(self):
+ ssl_point = ssl_util.ssl.EC_GROUP_get0_generator(self.group)
+ generator = ECPoint(
+ self.group, ssl_util.ssl.EC_POINT_dup(ssl_point, self.group)
+ )
+ generator *= BigNum.FromLongNumber(self.order).GenerateRandWithStart(
+ BigNum.One()
+ )
+ return generator
+
+ def ComputeYSquare(self, x: int):
+ """Returns y^2 calculated with x^3 + ax + b."""
+ return (x**3 + self._a * x + self._b) % self._p
+
+ def HashToCurve(self, m: Union[int, bytes]):
+ """ "Hash m to a point on the elliptic curve y^2 = x^3 + ax + b.
+
+ To hash m to a point on the curve, the algorithm first computes an integer
+ hash value x = h(m) and determines whether x is the abscissa of a point on
+ the elliptic curve y^2 = x^3 + ax + b. If not, set x = h(x) and try again.
+
+ Security:
+ The number of operations required to hash a message m depends on m, which
+ could lead to a timing attack.
+
+ Args:
+ m: long, int or str input
+
+ Returns:
+ A point (x, y) on this elliptic curve.
+ """
+ x = ssl_util.RandomOracle(m, self._p, hash_type=self._hash_type)
+ y2 = self.ComputeYSquare(x)
+
+ # y2 is a quadratic residue if y2^(p-1)/2 = 1
+ if 1 == ssl_util.ModExp(y2, self._p_sub_one_div_by_two, self._p):
+ y2_bn = ssl_util.BigNum.FromLongNumber(y2).Mutable()
+ y2_bn.IModSqrt(self._p_bn)
+ if y2_bn.IsBitSet(0):
+ return (x, y2_bn.ModNegate(self._p_bn).GetAsLong())
+ return (x, y2_bn.GetAsLong())
+ else:
+ return self.HashToCurve(x)
+
+ def __eq__(self, other):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return self._p == other._p and self._a == other._a and self._b == other._b
+ raise ValueError(
+ 'Cannot compare EllipticCurve with type {}'.format(type(other))
+ )
+ # pylint: enable=protected-access
+
+ @property
+ def group(self):
+ return self._group
+
+ @property
+ def order(self):
+ return self._order
+
+
+class ECKey(object):
+ """Class representing the elliptic curve key."""
+
+ def __init__(
+ self,
+ curve_id: Union[int, SupportedCurve],
+ priv_key_bytes: Optional[bytes] = None,
+ hash_type: Optional[HashType] = None,
+ ):
+ if isinstance(curve_id, SupportedCurve):
+ curve_id = curve_id.id
+ self._curve_id = curve_id
+ self._key = ssl_util.ssl.EC_KEY_new_by_curve_name(curve_id)
+ if priv_key_bytes:
+ ssl_util.ssl.EC_KEY_set_private_key(
+ self._key, ssl_util.BytesToBn(priv_key_bytes)
+ )
+ else:
+ if 1 != ssl_util.ssl.EC_KEY_generate_key(self._key):
+ raise Exception('EC key generation failed.')
+ self._Check()
+ self._priv_key_bn = ssl_util.ssl.EC_KEY_get0_private_key(self._key)
+ self._priv_key_bytes = ssl_util.BnToBytes(self._priv_key_bn)
+ self._priv_key_bignum = BigNum.FromBytes(self._priv_key_bytes)
+ self._ec = EllipticCurve(curve_id, hash_type=hash_type)
+ self._decrypt_key = self._priv_key_bignum.ModInverse(
+ BigNum.FromLongNumber(self._ec.order)
+ )
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl_util.ssl
+
+ def __del__(self):
+ self.ssl.EC_KEY_free(self._key)
+
+ def _Check(self):
+ if 0 == ssl_util.ssl.EC_KEY_check_key(self._key):
+ raise ValueError('The ECKey checks has failed.')
+
+ @property
+ def priv_key_bytes(self):
+ return self._priv_key_bytes
+
+ @property
+ def priv_key_bn(self):
+ return self._priv_key_bn
+
+ @property
+ def priv_key_bignum(self):
+ return self._priv_key_bignum
+
+ @property
+ def decrypt_key_bignum(self):
+ return self._decrypt_key
+
+ @property
+ def elliptic_curve(self):
+ return self._ec
+
+ @property
+ def curve_id(self):
+ return self._curve_id
diff --git a/private_join_and_compute/py/crypto_util/elliptic_curve_test.py b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py
new file mode 100644
index 0000000..c3dfebc
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/elliptic_curve_test.py
@@ -0,0 +1,122 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test class for elliptic_curve module."""
+
+import os
+import random
+import unittest
+from unittest import mock
+
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util import ssl_util
+from private_join_and_compute.py.crypto_util.elliptic_curve import ECKey
+from private_join_and_compute.py.crypto_util.elliptic_curve import ECPoint
+from private_join_and_compute.py.crypto_util.ssl_util import BigNum
+from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
+from private_join_and_compute.py.crypto_util.supported_curves import SupportedCurve
+from private_join_and_compute.py.crypto_util.supported_hashes import HashType
+
+
+# Equivalent to C++ curve NID_X9_62_prime256v1
+TEST_CURVE = SupportedCurve.SECP256R1
+TEST_CURVE_ID = TEST_CURVE.id
+
+
+class EllipticCurveTest(unittest.TestCase):
+
+ def setUp(self):
+ super(EllipticCurveTest, self).setUp()
+
+ def testEcKey(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_key_same = ECKey(TEST_CURVE_ID, ec_key.priv_key_bytes)
+ self.assertEqual(
+ ssl_util.BnToBytes(ec_key.priv_key_bn),
+ ssl_util.BnToBytes(ec_key_same.priv_key_bn),
+ )
+ self.assertEqual(ec_key.curve_id, ec_key_same.curve_id)
+ self.assertEqual(ec_key.elliptic_curve, ec_key_same.elliptic_curve)
+
+ @mock.patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.RandomOracle',
+ lambda x, bit_length, hash_type=None: 2 * x,
+ )
+ def testHashToPoint(self):
+ t = random.getrandbits(160)
+ ec_key = ECKey(TEST_CURVE_ID)
+ x, y = ec_key.elliptic_curve.HashToCurve(t)
+ ECPoint.FromPoint(ec_key.elliptic_curve.group, x, y).CheckValidity()
+
+ def testEcPointsMultiplicationWithAddition(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ ec_point_sum = ec_point + ec_point + ec_point
+ with TempBNs(x=3) as bn:
+ ec_point_mul = ec_point * bn.x
+ self.assertEqual(ec_point_sum, ec_point_mul)
+ self.assertNotEqual(ec_point, ec_point_mul)
+
+ def testEcPointsInPlaceMult(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ with TempBNs(x=3) as bn:
+ ec_point *= bn.x
+ self.assertNotEqual(
+ ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point
+ )
+
+ def testEcPointsInPlaceAdd(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ ec_point += ec_key.elliptic_curve.GetPointByHashingToCurve(11)
+ self.assertNotEqual(
+ ec_key.elliptic_curve.GetPointByHashingToCurve(10), ec_point
+ )
+
+ def testEcCurveOrder(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ ec_point1 = ec_point * BigNum.FromLongNumber(3)
+ ec_point2 = ec_point * BigNum.FromLongNumber(
+ 3 + ec_key.elliptic_curve.order
+ )
+ self.assertEqual(ec_point1, ec_point2)
+
+ def testDecryptKey(self):
+ ec_key = ECKey(TEST_CURVE_ID)
+ ec_point = ec_key.elliptic_curve.GetPointByHashingToCurve(10)
+ self.assertEqual(
+ ec_point, ec_point * ec_key.priv_key_bn * ec_key.decrypt_key_bignum
+ )
+
+ @mock.patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.BigNum'
+ '.GenerateRandWithStart'
+ )
+ def testGetRandomGenerator(self, gen_rand):
+ gen_rand.return_value = BigNum.FromLongNumber(2)
+ ec_key = ECKey(TEST_CURVE_ID)
+ g1 = ec_key.elliptic_curve.GetRandomGenerator()
+ self.assertFalse(g1.IsAtInfinity())
+ self.assertTrue(g1.IsOnCurve())
+ gen_rand.return_value = BigNum.FromLongNumber(4)
+ g2 = ec_key.elliptic_curve.GetRandomGenerator()
+ self.assertFalse(g2.IsAtInfinity())
+ self.assertTrue(g2.IsOnCurve())
+ self.assertEqual(g2, g1 + g1)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/ssl_util.py b/private_join_and_compute/py/crypto_util/ssl_util.py
new file mode 100644
index 0000000..548deb8
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/ssl_util.py
@@ -0,0 +1,1098 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Make available access to openssl library and bn functions."""
+
+import ctypes.util
+from functools import total_ordering
+import hashlib
+import math
+from typing import Union
+
+from absl import logging
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util.supported_hashes import HashType
+import six
+
+ssl = None
+
+try:
+ ssl_libpath = ctypes.util.find_library('crypto')
+ ssl = ctypes.cdll.LoadLibrary(ssl_libpath)
+except (OSError, IOError) as e:
+ logging.fatal('Could not load the ssl library.\n%s', e)
+
+ssl.ERR_error_string_n.restype = ctypes.c_void_p
+ssl.ERR_error_string_n.argtypes = [
+ ctypes.c_long,
+ ctypes.c_char_p,
+ ctypes.c_size_t,
+]
+ssl.ERR_get_error.restype = ctypes.c_long
+ssl.ERR_get_error.argtypes = []
+
+ssl.BN_new.restype = ctypes.c_void_p
+ssl.BN_new.argtypes = []
+ssl.BN_free.argtypes = [ctypes.c_void_p]
+ssl.BN_num_bits.restype = ctypes.c_int
+ssl.BN_num_bits.argtypes = [ctypes.c_void_p]
+ssl.BN_bin2bn.restype = ctypes.c_void_p
+ssl.BN_bin2bn.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]
+ssl.BN_bn2bin.restype = ctypes.c_int
+ssl.BN_bn2bin.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_CTX_new.restype = ctypes.c_void_p
+ssl.BN_CTX_new.argtypes = []
+ssl.BN_CTX_free.restype = ctypes.c_int
+ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]
+ssl.BN_mod_exp.restype = ctypes.c_int
+ssl.BN_mod_exp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mod_mul.restype = ctypes.c_int
+ssl.BN_mod_mul.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_CTX_new.argtypes = []
+ssl.BN_CTX_free.argtypes = [ctypes.c_void_p]
+ssl.BN_generate_prime_ex.restype = ctypes.c_int
+ssl.BN_generate_prime_ex.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_int,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_is_prime_ex.restype = ctypes.c_int
+ssl.BN_is_prime_ex.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mul.restype = ctypes.c_int
+ssl.BN_mul.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_div.restype = ctypes.c_int
+ssl.BN_div.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_exp.restype = ctypes.c_int
+ssl.BN_exp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.RAND_seed.restype = ctypes.c_int
+ssl.RAND_seed.argtypes = [ctypes.c_void_p, ctypes.c_int]
+ssl.BN_gcd.restype = ctypes.c_int
+ssl.BN_gcd.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mod_inverse.restype = ctypes.c_void_p
+ssl.BN_mod_inverse.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_mod_sqrt.restype = ctypes.c_void_p
+ssl.BN_mod_sqrt.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_add.restype = ctypes.c_int
+ssl.BN_add.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_sub.restype = ctypes.c_int
+ssl.BN_sub.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_nnmod.restype = ctypes.c_int
+ssl.BN_nnmod.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_rand_range.restype = ctypes.c_int
+ssl.BN_rand_range.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_lshift.restype = ctypes.c_int
+ssl.BN_lshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ssl.BN_rshift.restype = ctypes.c_int
+ssl.BN_rshift.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
+ssl.BN_cmp.restype = ctypes.c_int
+ssl.BN_cmp.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_is_bit_set.restype = ctypes.c_int
+ssl.BN_is_bit_set.argtypes = [ctypes.c_void_p, ctypes.c_int]
+
+ssl.EVP_PKEY_new.argtypes = []
+ssl.EVP_PKEY_new.restype = ctypes.c_void_p
+
+ssl.EC_KEY_new.restype = ctypes.c_void_p
+ssl.EC_KEY_new.argtypes = []
+ssl.EC_KEY_free.argtypes = [ctypes.c_void_p]
+ssl.EC_KEY_new_by_curve_name.restype = ctypes.c_void_p
+ssl.EC_KEY_new_by_curve_name.argtypes = [ctypes.c_int]
+ssl.EC_KEY_generate_key.restype = ctypes.c_int
+ssl.EC_KEY_generate_key.argtypes = [ctypes.c_void_p]
+ssl.EC_KEY_set_asn1_flag.restype = None
+ssl.EC_KEY_set_asn1_flag.argtypes = [ctypes.c_void_p, ctypes.c_int]
+
+ssl.EC_KEY_get0_public_key.restype = ctypes.c_void_p
+ssl.EC_KEY_get0_public_key.argtypes = [ctypes.c_void_p]
+
+ssl.EC_KEY_set_public_key.restype = ctypes.c_int
+ssl.EC_KEY_set_public_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_KEY_get0_private_key.restype = ctypes.c_void_p
+ssl.EC_KEY_get0_private_key.argtypes = [ctypes.c_void_p]
+
+ssl.EC_KEY_set_private_key.restype = ctypes.c_int
+ssl.EC_KEY_set_private_key.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_KEY_check_key.restype = ctypes.c_int
+ssl.EC_KEY_check_key.argtypes = [ctypes.c_void_p]
+
+ssl.EVP_PKEY_free.argtypes = [ctypes.c_void_p]
+ssl.EVP_PKEY_free.restype = None
+
+ssl.EVP_PKEY_get1_EC_KEY.restype = ctypes.c_void_p
+ssl.EVP_PKEY_get1_EC_KEY.argtypes = [ctypes.c_void_p]
+
+ssl.EC_GROUP_free.argtypes = [ctypes.c_void_p]
+ssl.EC_GROUP_get_order.restype = ctypes.c_int
+ssl.EC_GROUP_get_order.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_GROUP_new_by_curve_name.restype = ctypes.c_void_p
+ssl.EC_GROUP_new_by_curve_name.argtypes = [ctypes.c_int]
+ssl.EC_GROUP_get0_generator.restype = ctypes.c_void_p
+ssl.EC_GROUP_get0_generator.argtypes = [ctypes.c_void_p]
+
+ssl.EC_POINT_new.argtypes = [ctypes.c_void_p]
+ssl.EC_POINT_new.restype = ctypes.c_void_p
+ssl.EC_POINT_dup.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.EC_POINT_dup.restype = ctypes.c_void_p
+
+ssl.EC_POINT_free.argtypes = [ctypes.c_void_p]
+
+ssl.EC_POINT_mul.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_mul.restype = ctypes.c_int
+
+ssl.EC_POINT_add.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_add.restype = ctypes.c_int
+
+ssl.EC_POINT_point2oct.restype = ctypes.c_int
+ssl.EC_POINT_point2oct.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_oct2point.restype = ctypes.c_int
+ssl.EC_POINT_oct2point.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_void_p,
+]
+
+ssl.EC_POINT_is_on_curve.restype = ctypes.c_int
+ssl.EC_POINT_is_on_curve.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int
+ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.EC_POINT_set_to_infinity.restype = ctypes.c_int
+ssl.EC_POINT_set_to_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_POINT_cmp.restype = ctypes.c_int
+ssl.EC_POINT_cmp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.PEM_write_PUBKEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.PEM_write_PUBKEY.restypes = ctypes.c_int
+
+ssl.PEM_write_PrivateKey.restype = ctypes.c_int
+ssl.PEM_write_PrivateKey.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.PEM_read_PrivateKey.restype = ctypes.c_void_p
+ssl.PEM_read_PrivateKey.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.EVP_PKEY_set1_EC_KEY.restype = ctypes.c_int
+ssl.EVP_PKEY_set1_EC_KEY.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+
+ssl.EC_GROUP_get_curve_GFp.restype = ctypes.c_int
+ssl.EC_GROUP_get_curve_GFp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.EC_POINT_set_affine_coordinates_GFp.restype = ctypes.c_int
+ssl.EC_POINT_set_affine_coordinates_GFp.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+
+ssl.BN_MONT_CTX_new.restype = ctypes.c_void_p
+ssl.BN_MONT_CTX_new.argtypes = []
+ssl.BN_MONT_CTX_set.restype = ctypes.c_int
+ssl.BN_MONT_CTX_set.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_MONT_CTX_free.argtypes = [ctypes.c_void_p]
+ssl.BN_mod_mul_montgomery.restype = ctypes.c_int
+ssl.BN_mod_mul_montgomery.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_to_montgomery.restype = ctypes.c_int
+ssl.BN_to_montgomery.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_from_montgomery.restype = ctypes.c_int
+ssl.BN_from_montgomery.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+]
+ssl.BN_copy.restype = ctypes.c_void_p
+ssl.BN_copy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+ssl.BN_dup.restype = ctypes.c_void_p
+ssl.BN_dup.argtypes = [ctypes.c_void_p]
+
+pointer = ctypes.pointer
+cast = ctypes.cast
+
+
+class SSLProxy(object):
+ """Wrapper (a pass-through with error checking) for the loaded ssl library.
+
+ This class checks the ssl methods returning pointers does not return None and
+ also checks methods returning 0 on failure. In case of a failure, it prints
+ OpenSSL error messages.
+ """
+
+ def __init__(self, ssl_lib):
+ self._ssl = ssl_lib
+ self._cache = {}
+ # Functions without a return value or having a return value that is already
+ # explicitly checked in the code.
+ self._funcs_to_skip = {
+ 'BN_free',
+ 'BN_CTX_free',
+ 'BN_cmp',
+ 'BN_num_bits',
+ 'BN_bn2bin',
+ 'EC_POINT_is_at_infinity',
+ 'EC_POINT_cmp',
+ 'EC_POINT_free',
+ 'EC_KEY_free',
+ 'BN_MONT_CTX_free',
+ 'BN_is_bit_set',
+ 'EC_GROUP_free',
+ 'BN_is_prime_ex',
+ 'EC_POINT_point2oct',
+ }
+
+ def _DebugInfo(self):
+ """Returns the last error message from the OpenSSL library."""
+ err = ctypes.create_string_buffer(256)
+ self._ssl.ERR_error_string_n(self._ssl.ERR_get_error(), err, 256)
+ return '\nOpenSSL Error: {}'.format(err.value)
+
+ def __getattr__(self, name):
+ if name in self._funcs_to_skip:
+ return getattr(self._ssl, name)
+ if name not in self._cache:
+
+ def WrapperFunc(*args):
+ func = getattr(self._ssl, name)
+ ret = func(*args)
+ if func.restype is ctypes.c_void_p:
+ assert ret is not None, 'ret is None{}'.format(self._DebugInfo())
+ elif func.restype is ctypes.c_int:
+ assert 1 == ret, 'ret is not 1, ret: {}{}'.format(
+ ret, self._DebugInfo()
+ )
+ return ret
+
+ self._cache[name] = WrapperFunc
+ return self._cache[name]
+
+
+ssl = SSLProxy(ssl)
+
+
+def LongtoBn(bn_r: int, a: int) -> int:
+ """Converts a to BigNum and stores in preallocated bn_r."""
+ bytes_a = converters.LongToBytes(a)
+ return ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r)
+
+
+def BnToLong(bn_a: int) -> int:
+ """Converts BigNum to long."""
+ num_bits_in_a = ssl.BN_num_bits(bn_a)
+ num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0))
+ bytes_a = ctypes.create_string_buffer(num_bytes_in_a)
+ ssl.BN_bn2bin(bn_a, bytes_a)
+ return converters.BytesToLong(bytes_a.raw)
+
+
+def BnToBytes(bn_a: int) -> bytes:
+ """Converts BigNum to long."""
+ num_bits_in_a = ssl.BN_num_bits(bn_a)
+ num_bytes_in_a = int(math.ceil(num_bits_in_a / 8.0))
+ bytes_a = ctypes.create_string_buffer(num_bytes_in_a)
+ ssl.BN_bn2bin(bn_a, bytes_a)
+ return bytes_a.raw
+
+
+def BytesToBn(bytes_a: bytes) -> int:
+ """Converts BigNum to long."""
+ bn_r = ssl.BN_new()
+ ssl.BN_bin2bn(bytes_a, len(bytes_a), bn_r)
+ return bn_r
+
+
+def GetRandomInRange(long_start: int, long_end: int) -> int:
+ """ "Returns a random in the range [long_start, long_end)."""
+ with TempBNs(rand=None, interval=(long_end - long_start)) as bn:
+ ssl.BN_rand_range(bn.rand, bn.interval)
+ return BnToLong(bn.rand) + long_start
+
+
+def ModExp(g: int, x: int, n: int) -> int:
+ """Computes g^x mod n."""
+ with TempBNs(r=None, g=g, x=x, n=n) as bn:
+ ssl.BN_mod_exp(bn.r, bn.g, bn.x, bn.n, OpenSSLHelper().ctx)
+ return BnToLong(bn.r)
+
+
+def ModInverse(x: int, n: int) -> int:
+ """Computes 1/x mod n."""
+ with TempBNs(r=None, x=x, n=n) as bn:
+ ssl.BN_mod_inverse(bn.r, bn.x, bn.n, OpenSSLHelper().ctx)
+ return BnToLong(bn.r)
+
+
+class TempBNs(object):
+ """Class for creating temporary openssl bignums by using 'with' clause."""
+
+ # Disable pytype attribute checking.
+ _HAS_DYNAMIC_ATTRIBUTES = True
+
+ def __init__(self, **kwargs):
+ r"""Initializes and assigns all temporary bignums.
+
+ Usage:
+ with TempBNs(x=5, y=[10,11]) as bn:
+ # bn.x is the temporary bignum holding the value 5 within this scope.
+ # bn.y is the temporary list of bignum holding the value 10 and 11
+ # within this scope.
+
+ or it can be used for assigning temporary results into bignums as follows:
+ with TempBNs(result=None, x=5) as bn:
+ # bn.result is an empty temporary bignum within this scope.
+ # bn.x is the same as before.
+
+ or bytes can be given as well as longs:
+ with TempBNs(x=5, y=['\001', '\002']) as bn:
+ # bn.x is the temporary bignum holding the value 5 within this scope.
+ # bn.y is the temporary list of bignum holding the value 1 and 2 within
+ # this scope.
+
+ Args:
+ **kwargs: key (variable), value (int or long) pairs.
+ """
+ self._args = []
+ for key, value in kwargs.items():
+ assert not hasattr(self, key), '{} already exists.'.format(key)
+ if isinstance(value, list):
+ assert value, 'Cannot declare empty list in TempBNs.'
+ for v in value:
+ self._args.append(ssl.BN_new())
+ self._BytesOrLongToBn(self._args[-1], v)
+ setattr(self, key, self._args[-len(value) :])
+ else:
+ self._args.append(ssl.BN_new())
+ setattr(self, key, self._args[-1])
+ if value:
+ self._BytesOrLongToBn(self._args[-1], value)
+
+ @classmethod
+ def _BytesOrLongToBn(cls, bn, val) -> int:
+ if isinstance(val, int):
+ LongtoBn(bn, val)
+ if isinstance(val, str):
+ ssl.BN_bin2bn(val, len(val), bn)
+
+ def __enter__(self, *args):
+ return self
+
+ def __exit__(self, some_type, value, traceback):
+ for bn in self._args:
+ ssl.BN_free(bn)
+
+
+def RandomOracle(
+ x: Union[int, bytes],
+ max_value: int,
+ hash_type: Union[type(None), HashType] = None,
+) -> int:
+ """A random oracle function mapping x deterministically into a large domain.
+
+ The random oracle is similar to the example given in the last paragraph of
+ Chapter 6 of [1] where the output is expanded by successively hashing the
+ concatenation of the input with a fixed sized counter starting from 1.
+
+ [1] Bellare, Mihir, and Phillip Rogaway. "Random oracles are practical:
+ A paradigm for designing efficient protocols." Proceedings of the 1st ACM
+ conference on Computer and communications security. ACM, 1993.
+
+ Args:
+ x: long or string input
+ max_value: the max value of the output domain.
+ hash_type: the hash function to use, as a HashType. If 'None' is provided
+ this defaults to HashType.SHA512.
+
+ Returns:
+ a long value from the set [0, max_value).
+
+ Raises:
+ ValueError: if bit length of max_value is greater than
+ hash_type.bit_length * 254. Since the counter used for expanding the
+ output is expanded to 8 bit length (hard-coded), any counter value that is
+ greater than 256 would cause variable length inputs passed to the
+ underlying hash calls and might make this random oracle's output not
+ uniform across the output domain. The output length is increased by a
+ security value of hash_type.bit_length which reduces the bias of selecting
+ certain values more often than others when max_value is not a multiple of
+ 2.
+ """
+ if hash_type is None:
+ hash_type = HashType.SHA512
+ output_bit_length = max_value.bit_length() + hash_type.bit_length
+ iter_count = int(math.ceil(float(output_bit_length) / hash_type.bit_length))
+ if iter_count > 255:
+ raise ValueError(
+ 'The domain bit length must not be greater than H * 254. '
+ 'Given bit length: {}'.format(output_bit_length)
+ )
+ excess_bit_count = (iter_count * hash_type.bit_length) - output_bit_length
+ hash_output = 0
+ bytes_x = x if isinstance(x, bytes) else converters.LongToBytes(x)
+ for i in range(1, iter_count + 1):
+ hash_output <<= hash_type.bit_length
+ hash_output |= hash_type.hash(
+ six.ensure_binary(converters.LongToBytes(i) + bytes_x)
+ )
+ return (hash_output >> excess_bit_count) % max_value
+
+
+class PRNG(object):
+ """Hash based counter mode pseudorandom number generator.
+
+ The technique used in this class is same as the one used in RandomOracle
+ function.
+ """
+
+ def __init__(self, seed, counter_byte_len=4):
+ """Creates the PRNG with the given seed.
+
+ Args:
+ seed: at least 32 byte number or string.
+ counter_byte_len: the max number of counter bytes to use. After exceeding
+ the counter, this PRNG should not be used.
+
+ Raises:
+ ValueError: when the seed is not at least 32 bytes.
+ """
+ self.seed = (
+ seed if isinstance(seed, bytes) else converters.LongToBytes(seed)
+ )
+ if len(self.seed) < 32:
+ raise ValueError(
+ 'seed needs to be at least 32 bytes, the given bytes: {}'.format(
+ self.seed
+ )
+ )
+ self.cur_pad = 0
+ self.cur_bytes = b''
+ self.cur_byte_len = counter_byte_len
+ self.limit = 1 << (self.cur_byte_len * 8)
+
+ def _GetMore(self):
+ assert self.cur_pad < self.limit, 'Limit has been reached.'
+ hash_output = six.ensure_binary(
+ hashlib.sha512(
+ six.ensure_binary(self._PaddedCountBytes() + self.seed)
+ ).digest()
+ )
+ self.cur_pad += 1
+ self.cur_bytes += hash_output
+
+ def _PaddedCountBytes(self):
+ counter_bytes = converters.LongToBytes(self.cur_pad)
+ # Although we could use {:\x004}.format, Python seems to print space when
+ # doing this way for the null character.
+ return b'\000' * (self.cur_byte_len - len(counter_bytes)) + counter_bytes
+
+ def _GetNBitRand(self, n):
+ """Gets a random number in [0, 2^n) interval."""
+ byte_len = (n + 7) >> 3
+ excess_len = (8 - (n % 8)) % 8
+ while len(self.cur_bytes) < byte_len:
+ self._GetMore()
+ self.cur_bytes, rand = (
+ self.cur_bytes[byte_len:],
+ self.cur_bytes[:byte_len],
+ )
+ rand_num = converters.BytesToLong(rand) >> excess_len
+ return rand_num
+
+ def GetRand(self, upper_limit):
+ """Gets a random number in [0, upper_limit) interval."""
+ bit_len = (upper_limit - 1).bit_length()
+ rand_num = self._GetNBitRand(bit_len)
+ while rand_num >= upper_limit:
+ rand_num = self._GetNBitRand(bit_len)
+ return rand_num
+
+
+class OpenSSLHelper(object):
+ """A singleton wrapper class for openssl ctx and seeding its rand.
+
+ Context is used for caching already allocated big nums. Each openssl operation
+ requires a context to be passed to Get temporary big nums avoiding allocating
+ new big nums for these temporary nums thus making big num operations use
+ memory resources more efficiently. Usage in openssl library:
+
+ BN_CTX_start(ctx)
+ ....
+ temp = BN_CTX_get(ctx)
+ ....
+ BN_CTX_end(ctx)
+ Please note that BN_CTX_start and BN_CTX_end is not implemented here as this
+ is only passed to various openssl big num operations.
+ """
+
+ _instance = None
+
+ def __new__(cls, *args, **kwargs):
+ if not cls._instance:
+ cls._instance = super(OpenSSLHelper, cls).__new__(cls, *args, **kwargs)
+ return cls._instance
+
+ def __init__(self):
+ self._ctx = ssl.BN_CTX_new()
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl
+
+ def __del__(self):
+ # clean up
+ self.ssl.BN_CTX_free(self._ctx)
+
+ @property
+ def ctx(self):
+ return self._ctx
+
+
+@total_ordering
+class BigNum(object):
+ """A wrapper class for openssl bn numbers.
+
+ Used for arithmetic operations on long numbers.
+ """
+
+ _ZERO = None
+ _ONE = None
+ _TWO = None
+
+ def __init__(self, bn_num):
+ self._bn_num = bn_num
+ self._helper = OpenSSLHelper()
+ self.immutable = True
+ # So that garbage collection doesn't collect ssl before this object.
+ self.ssl = ssl
+
+ @classmethod
+ def Zero(cls):
+ if not cls._ZERO:
+ cls._ZERO = cls.FromLongNumber(0)
+ return cls._ZERO
+
+ @classmethod
+ def One(cls):
+ if not cls._ONE:
+ cls._ONE = cls.FromLongNumber(1)
+ return cls._ONE
+
+ @classmethod
+ def Two(cls):
+ if not cls._TWO:
+ cls._TWO = cls.FromLongNumber(2)
+ return cls._TWO
+
+ @classmethod
+ def FromLongNumber(cls, long_number: int) -> 'BigNum':
+ """Returns a BigNum constructed from the given long number."""
+ bytes_num = converters.LongToBytes(long_number)
+ return cls.FromBytes(bytes_num)
+
+ @classmethod
+ def FromBytes(cls, number_in_bytes):
+ """Returns a BigNum constructed from the given long number."""
+ bn_num = ssl.BN_new()
+ ssl.BN_bin2bn(number_in_bytes, len(number_in_bytes), bn_num)
+ return BigNum(bn_num)
+
+ @classmethod
+ def GenerateSafePrime(cls, prime_length):
+ """Returns a safe prime BigNum with the given bit-length."""
+ bn_prime_num = ssl.BN_new()
+ ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 1, None, None, None)
+ return BigNum(bn_prime_num)
+
+ @classmethod
+ def GeneratePrime(cls, prime_length: int) -> 'BigNum':
+ """Returns a prime BigNum with the given bit-length."""
+ bn_prime_num = ssl.BN_new()
+ ssl.BN_generate_prime_ex(bn_prime_num, prime_length, 0, None, None, None)
+ return BigNum(bn_prime_num)
+
+ def GeneratePrimeForSubGroup(self, prime_length: int) -> 'BigNum':
+ """Returns a prime BigNum, p, satisfying p = (self * k) + 1 for some k.
+
+ Args:
+ prime_length: the bit length of the returned prime.
+
+ Returns:
+ a prime BigNum, p = (self * k) + 1 for some k.
+ """
+ bn_prime_num = ssl.BN_new()
+ ssl.BN_generate_prime_ex(
+ bn_prime_num, prime_length, 0, self._bn_num, None, None
+ )
+ return BigNum(bn_prime_num)
+
+ def IsPrime(self, error_probability=1e-6):
+ """Returns True if this big num is prime, False otherwise."""
+ rounds = int(math.ceil(-math.log(error_probability) / math.log(4)))
+ return ssl.BN_is_prime_ex(self._bn_num, rounds, self._helper.ctx, None) != 0
+
+ def IsSafePrime(self, error_probability=1e-6):
+ """Returns True if this big num is a safe prime, False otherwise."""
+ return self.IsPrime(error_probability) and (
+ (self - self.One()) / self.Two()
+ ).IsPrime(error_probability)
+
+ def IsBitSet(self, n):
+ """Returns True if the n-th bit is set, False otherwise."""
+ return ssl.BN_is_bit_set(self._bn_num, n)
+
+ def BitLength(self):
+ return ssl.BN_num_bits(self._bn_num)
+
+ def Clone(self):
+ """Clones this big num."""
+ return BigNum(ssl.BN_dup(self._bn_num))
+
+ def Mutable(self):
+ """Sets this BigNum to mutable so that it can be changed."""
+ self.immutable = False
+ return self
+
+ def __hash__(self):
+ return hash((self._bn_num, self.immutable))
+
+ def __del__(self):
+ self.ssl.BN_free(self._bn_num)
+
+ def __add__(self, other):
+ return self._ComputeResult(ssl.BN_add, None, other)
+
+ def __iadd__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_add, None, other)
+
+ def __sub__(self, other):
+ return self._ComputeResult(ssl.BN_sub, None, other)
+
+ def __isub__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_sub, None, other)
+
+ def __mul__(self, other):
+ return self._ComputeResult(ssl.BN_mul, self._helper.ctx, other)
+
+ def __imul__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_mul, self._helper.ctx, other)
+
+ def __mod__(self, modulus):
+ return self._ComputeResult(ssl.BN_nnmod, self._helper.ctx, modulus)
+
+ def __imod__(self, modulus):
+ return self._ComputeResultInPlace(ssl.BN_nnmod, self._helper.ctx, modulus)
+
+ def __pow__(self, other):
+ return self._ComputeResult(ssl.BN_exp, self._helper.ctx, other)
+
+ def __ipow__(self, other):
+ return self._ComputeResultInPlace(ssl.BN_exp, self._helper.ctx, other)
+
+ def __rshift__(self, n):
+ bn_num = ssl.BN_new()
+ ssl.BN_rshift(bn_num, self._bn_num, n)
+ return BigNum(bn_num)
+
+ def __irshift__(self, n):
+ ssl.BN_rshift(self._bn_num, self._bn_num, n)
+ return self
+
+ def __lshift__(self, n):
+ bn_num = ssl.BN_new()
+ ssl.BN_lshift(bn_num, self._bn_num, n)
+ return BigNum(bn_num)
+
+ def __ilshift__(self, n):
+ ssl.BN_lshift(self._bn_num, self._bn_num, n)
+ return self
+
+ def __div__(self, other):
+ return self._Div(BigNum(ssl.BN_new()), self, other)
+
+ def __truediv__(self, other):
+ return self._Div(BigNum(ssl.BN_new()), self, other)
+
+ def __idiv__(self, other):
+ return self._Div(self, self, other)
+
+ def _Div(self, big_result, big_num, other_big_num):
+ """Divides two bignums.
+
+ Args:
+ big_result: The bignum where the result is stored.
+ big_num: The numerator.
+ other_big_num: The denominator.
+
+ Returns:
+ big_result.
+
+ Raises:
+ ValueError: If the remainder is non-zero.
+ """
+ if isinstance(other_big_num, self.__class__):
+ bn_remainder = ssl.BN_new()
+ ssl.BN_div(
+ big_result._bn_num,
+ bn_remainder,
+ big_num._bn_num,
+ other_big_num._bn_num,
+ self._helper.ctx,
+ )
+ try:
+ if ssl.BN_cmp(bn_remainder, self.Zero()._bn_num) != 0:
+ raise ValueError(
+ 'There is a remainder in division of {} and {}'.format(
+ big_num.GetAsLong(), other_big_num.GetAsLong()
+ )
+ )
+ return big_result
+ finally:
+ ssl.BN_free(bn_remainder)
+
+ def ModMul(self, other, modulus):
+ """Modular multiplies this with other based on the modulus.
+
+ For efficiency, please use Montgomery multiplication module if this is done
+ multiple times with the same modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ a new BigNum holding the result.
+ """
+ return self._ComputeResult(ssl.BN_mod_mul, self._helper.ctx, other, modulus)
+
+ def IModMul(self, other, modulus):
+ """Modular multiplies this with other based on the modulus.
+
+ Stores the result in this BigNum.
+ For efficiency, please use Montgomery multiplication module if this is done
+ multiple times with the same modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ self
+ """
+ return self._ComputeResultInPlace(
+ ssl.BN_mod_mul, self._helper.ctx, other, modulus
+ )
+
+ def ModExp(self, other, modulus):
+ """Modular exponentiates this with other based on the modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ a new BigNum holding the result.
+ """
+ return self._ComputeResult(ssl.BN_mod_exp, self._helper.ctx, other, modulus)
+
+ def IModExp(self, other, modulus):
+ """Modular exponentiates this with other based on the modulus.
+
+ Args:
+ other: the other BigNum
+ modulus: the modulus of the operation
+
+ Returns:
+ self
+ """
+ return self._ComputeResultInPlace(
+ ssl.BN_mod_exp, self._helper.ctx, other, modulus
+ )
+
+ def GCD(self, other):
+ """Gets gcd as a BigNum."""
+ return self._ComputeResult(ssl.BN_gcd, self._helper.ctx, other)
+
+ def ModInverse(self, modulus):
+ """Gets the inverse of this BigNum in mod modulus."""
+ try:
+ return self._ComputeResult(ssl.BN_mod_inverse, self._helper.ctx, modulus)
+ except AssertionError as a:
+ raise ValueError(
+ 'This big num {} and modulus {} are not relatively '
+ 'primes.\nThe Assertion Error: {}'.format(
+ self.GetAsLong(), modulus.GetAsLong(), a
+ )
+ )
+
+ def ModSqrt(self, modulus):
+ """Gets the sqrt of this BigNum in mod modulus.
+
+ Args:
+ modulus: the modulus of the operation
+
+ Returns:
+ a new BigNum holding the result.
+ """
+ big_num_result = self._ComputeResult(
+ ssl.BN_mod_sqrt, self._helper.ctx, modulus
+ )
+ return big_num_result
+
+ def IModSqrt(self, modulus):
+ """Gets the sqrt of this BigNum in mod modulus.
+
+ Args:
+ modulus: the modulus of the operation
+
+ Returns:
+ self
+ """
+ return self._ComputeResultInPlace(
+ ssl.BN_mod_sqrt, self._helper.ctx, modulus
+ )
+
+ def GenerateRand(self):
+ """Generates a cryptographically strong pseudo-random between 0 & self.
+
+ Returns:
+ A BigNum in [0, self._big_num) range.
+ """
+ bn_rand = ssl.BN_new()
+ ssl.BN_rand_range(bn_rand, self._bn_num)
+ return BigNum(bn_rand)
+
+ def GenerateRandWithStart(self, start_big_num):
+ """Generates a cryptographically strong pseudo-random between start & self.
+
+ Args:
+ start_big_num: start BigNum value of the interval.
+
+ Returns:
+ A BigNum in [start, self._big_num) range.
+ """
+ return (self - start_big_num).GenerateRand() + start_big_num
+
+ def ModNegate(self, modulus):
+ return modulus - (self % modulus)
+
+ def AddOne(self):
+ return self + self.One()
+
+ def SubtractOne(self):
+ return self - self.One()
+
+ def __str__(self):
+ return str(self.GetAsLong())
+
+ def __eq__(self, other):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return ssl.BN_cmp(self._bn_num, other._bn_num) == 0
+ raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
+ # pylint: enable=protected-access
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __lt__(self, other):
+ # pylint: disable=protected-access
+ if isinstance(other, self.__class__):
+ return ssl.BN_cmp(self._bn_num, other._bn_num) == -1
+ raise ValueError('Cannot compare BigNum with type {}'.format(type(other)))
+ # pylint: enable=protected-access
+
+ def _ComputeResult(self, func, ctx, *args):
+ return self._ComputeResultIntoBigNum(
+ BigNum(ssl.BN_new()), func, ctx, self, *args
+ )
+
+ def _ComputeResultInPlace(self, func, ctx, *args):
+ if self.immutable:
+ raise ValueError(
+ 'This operation will change this immutable BigNum. Call '
+ 'Mutable method to change it.'
+ )
+ return self._ComputeResultIntoBigNum(self, func, ctx, self, *args)
+
+ @classmethod
+ def _ComputeResultIntoBigNum(cls, big_num_result, func, ctx, *args):
+ # pylint: disable=protected-access
+ if all(isinstance(big_num, cls) for big_num in args):
+ args = [big_num._bn_num for big_num in args]
+ if ctx:
+ args.append(ctx)
+ func(big_num_result._bn_num, *args)
+ return big_num_result
+ return NotImplemented
+ # pylint: enable=protected-access
+
+ def GetAsLong(self):
+ """Gets the long number in this BigNum."""
+ return converters.BytesToLong(self.GetAsBytes())
+
+ def GetAsBytes(self):
+ """Gets the long number as bytes in this BigNum."""
+ num_bits = ssl.BN_num_bits(self._bn_num)
+ num_bytes = int(math.ceil(num_bits / 8.0))
+ bytes_num = ctypes.create_string_buffer(num_bytes)
+ ssl.BN_bn2bin(self._bn_num, bytes_num)
+ return bytes_num.raw
+
+
+class BigNumCache(object):
+ """A singleton cache holding BigNum representations of small numbers."""
+
+ _instance = None
+
+ def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
+ if not cls._instance:
+ cls._instance = super(BigNumCache, cls).__new__(cls)
+ return cls._instance
+
+ def __init__(self, max_count: int):
+ self._cache = {}
+ self._max_count = max_count
+
+ def Get(self, num: int) -> BigNum:
+ """Gets the BigNum from the cache or creates a new BigNum.
+
+ If max_count is reached, a new BigNum is created and returned without
+ storing in the cache.
+ Args:
+ num: the long or integer to convert to BigNum.
+
+ Returns:
+ a BigNum for the given num.
+ """
+ if num not in self._cache:
+ if len(self._cache) >= self._max_count:
+ return BigNum.FromLongNumber(num)
+ self._cache[num] = BigNum.FromLongNumber(num)
+ return self._cache[num]
diff --git a/private_join_and_compute/py/crypto_util/ssl_util_test.py b/private_join_and_compute/py/crypto_util/ssl_util_test.py
new file mode 100644
index 0000000..ec9d24e
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/ssl_util_test.py
@@ -0,0 +1,543 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Test class for ssl_util module."""
+
+import os
+import unittest
+from unittest import mock
+from unittest.mock import call
+from unittest.mock import patch
+
+from private_join_and_compute.py.crypto_util import converters
+from private_join_and_compute.py.crypto_util import ssl_util
+from private_join_and_compute.py.crypto_util.ssl_util import PRNG
+from private_join_and_compute.py.crypto_util.ssl_util import TempBNs
+
+
+class SSLUtilTest(unittest.TestCase):
+
+ def setUp(self):
+ self.test_path = os.path.join(
+ os.getcwd(), 'privacy/blinders/testing/data/random_oracle'
+ )
+
+ def testRandomOracleRaisesValueErrorForVeryLargeDomains(self):
+ self.assertRaises(ValueError, ssl_util.RandomOracle, 1, 1 << 130048)
+
+ def _GenericRandomTestForCasesThatShouldReturnOneNum(
+ self, expected_value, rand_func, *args
+ ):
+ # There is at least %50 chance one iteration would catch the error if
+ # rand_func also returns something outside the interval. Doing the same test
+ # 20 times would increase the overall chance to %99.9999 in the worst case
+ # scenario (i.e., the rand_func may return only one other element except the
+ # the expected value).
+ for _ in range(20):
+ actual_value = rand_func(*args)
+ self.assertEqual(
+ actual_value,
+ expected_value,
+ 'The generated rand is {} but should be {} instead.'.format(
+ actual_value, expected_value
+ ),
+ )
+
+ def testGetRandomInRangeSingleNumber(self):
+ self._GenericRandomTestForCasesThatShouldReturnOneNum(
+ 2**30 - 1, ssl_util.GetRandomInRange, 2**30 - 1, 2**30
+ )
+
+ def testGetRandomInRangeMultipleNumbers(self):
+ rand = ssl_util.GetRandomInRange(11111111111, 11111111111111111111111)
+ self.assertTrue(11111111111 <= rand < 11111111111111111111111) # pylint: disable=g-generic-assert
+
+ def testModExp(self):
+ self.assertEqual(1, ssl_util.ModExp(3, 4, 80))
+
+ def testModInverse(self):
+ self.assertEqual(5, ssl_util.ModInverse(2, 9))
+
+ def testGetRandomInRangeReturnOnlyOneValueWhenIntervalIsOne(self):
+ random = ssl_util.GetRandomInRange(99999999999999998, 99999999999999999)
+ self.assertEqual(99999999999999998, random)
+
+ def testGetRandomInRangeReturnsAValueInRange(self):
+ random = ssl_util.GetRandomInRange(99999999999999998, 100000000000000000000)
+ self.assertLessEqual(99999999999999998, random)
+ self.assertLess(random, 100000000000000000000)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForValues(self, mocked_ssl):
+ with TempBNs(x=10, y=20) as bn:
+ self.assertEqual(10, ssl_util.BnToLong(bn.x))
+ self.assertEqual(20, ssl_util.BnToLong(bn.y))
+ x_addr = bn.x
+ y_addr = bn.y
+ self.assertEqual(2, mocked_ssl.BN_free.call_count)
+ mocked_ssl.BN_free.assert_any_call(x_addr)
+ mocked_ssl.BN_free.assert_any_call(y_addr)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForLists(self, mocked_ssl):
+ with TempBNs(x=10, y=[20, 30], z=40) as bn:
+ self.assertEqual(10, ssl_util.BnToLong(bn.x))
+ self.assertEqual(20, ssl_util.BnToLong(bn.y[0]))
+ self.assertEqual(30, ssl_util.BnToLong(bn.y[1]))
+ self.assertEqual(40, ssl_util.BnToLong(bn.z))
+ addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
+ self.assertEqual(4, mocked_ssl.BN_free.call_count)
+ for addr in addrs:
+ mocked_ssl.BN_free.assert_any_call(addr)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForBytes(self, mocked_ssl):
+ with TempBNs(x='\001', y=['\002', '\003'], z='\004') as bn:
+ self.assertEqual(1, ssl_util.BnToLong(bn.x))
+ self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
+ self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
+ self.assertEqual(4, ssl_util.BnToLong(bn.z))
+ addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
+ self.assertEqual(4, mocked_ssl.BN_free.call_count)
+ for addr in addrs:
+ mocked_ssl.BN_free.assert_any_call(addr)
+
+ @patch(
+ 'private_join_and_compute.py.crypto_util.ssl_util.ssl', wraps=ssl_util.ssl
+ )
+ def testTempBNsForBytesOrLong(self, mocked_ssl):
+ with TempBNs(x=1, y=['\002', 3], z='\004') as bn:
+ self.assertEqual(1, ssl_util.BnToLong(bn.x))
+ self.assertEqual(2, ssl_util.BnToLong(bn.y[0]))
+ self.assertEqual(3, ssl_util.BnToLong(bn.y[1]))
+ self.assertEqual(4, ssl_util.BnToLong(bn.z))
+ addrs = [bn.x, bn.y[0], bn.y[1], bn.z]
+ self.assertEqual(4, mocked_ssl.BN_free.call_count)
+ for addr in addrs:
+ mocked_ssl.BN_free.assert_any_call(addr)
+
+ def testTempBNsRaisesAssertionErrorWhenAListIsEmpty(self):
+ self.assertRaises(AssertionError, TempBNs, x=10, y=[20, 30], z=[])
+
+ def testTempBNsRaisesAssertionErrorWhenAlreadySetKeyUsed(self):
+ self.assertRaises(AssertionError, TempBNs, _args=10)
+
+ def testBigNumInitializes(self):
+ big_num = ssl_util.BigNum.FromLongNumber(1)
+ self.assertEqual(1, big_num.GetAsLong())
+
+ def testOpenSSLHelperIsSingleton(self):
+ helper1 = ssl_util.OpenSSLHelper()
+ helper2 = ssl_util.OpenSSLHelper()
+ self.assertIs(helper1, helper2)
+
+ def testBigNumGeneratesSafePrime(self):
+ big_prime = ssl_util.BigNum.GenerateSafePrime(100)
+ self.assertTrue(
+ big_prime.IsPrime()
+ and (
+ big_prime.SubtractOne() / ssl_util.BigNum.FromLongNumber(2)
+ ).IsPrime()
+ )
+ self.assertEqual(100, big_prime.BitLength())
+
+ def testBigNumIsSafePrime(self):
+ prime = ssl_util.BigNum.FromLongNumber(23)
+ self.assertTrue(prime.IsSafePrime())
+ prime = ssl_util.BigNum.FromLongNumber(29)
+ self.assertFalse(prime.IsSafePrime())
+
+ def testBigNumGeneratesPrime(self):
+ big_prime = ssl_util.BigNum.GeneratePrime(100)
+ self.assertTrue(big_prime.IsPrime())
+ self.assertEqual(100, big_prime.BitLength())
+
+ def testBigNumGeneratesPrimeForSubGroup(self):
+ prime = ssl_util.BigNum.GeneratePrime(50)
+ big_prime = prime.GeneratePrimeForSubGroup(100)
+ self.assertTrue(big_prime.IsPrime())
+ self.assertEqual(ssl_util.BigNum.One(), big_prime % prime)
+ self.assertEqual(100, big_prime.BitLength())
+
+ def testBigNumBitLength(self):
+ big_prime = ssl_util.BigNum.FromLongNumber(15)
+ self.assertEqual(4, big_prime.BitLength())
+ big_prime = ssl_util.BigNum.FromLongNumber(16)
+ self.assertEqual(5, big_prime.BitLength())
+
+ def testBigNumAdds(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 + big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(5, big_num3.GetAsLong())
+
+ def testBigNumAddsInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 += big_num2
+ self.assertEqual(5, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumSubtracts(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(4)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 - big_num2
+ self.assertEqual(4, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(1, big_num3.GetAsLong())
+
+ def testBigNumSubtractsInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(4).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 -= big_num2
+ self.assertEqual(1, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumOperationsInPlaceRaisesValueErrorOnImmutableBigNums(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ self.assertRaises(ValueError, big_num1.__iadd__, big_num2)
+
+ def testBigNumMultiplies(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 * big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(6, big_num3.GetAsLong())
+
+ def testBigNumMultipliesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 *= big_num2
+ self.assertEqual(6, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumMods(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(5)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1 % big_num2
+ self.assertEqual(5, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(2, big_num3.GetAsLong())
+
+ def testBigNumModsInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(5).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 %= big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumExponentiates(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num3 = big_num1**big_num2
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(8, big_num3.GetAsLong())
+
+ def testBigNumExponentiatesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ big_num1 **= big_num2
+ self.assertEqual(8, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+
+ def testBigNumRShifts(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num1 = big_num >> 1
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(4, big_num.GetAsLong())
+
+ def testBigNumRShiftsInPlace(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num >>= 1
+ self.assertEqual(2, big_num.GetAsLong())
+
+ def testBigNumLShifts(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num1 = big_num << 1
+ self.assertEqual(8, big_num1.GetAsLong())
+ self.assertEqual(4, big_num.GetAsLong())
+
+ def testBigNumLShiftsInPlace(self):
+ big_num = ssl_util.BigNum.FromLongNumber(4)
+ big_num <<= 1
+ self.assertEqual(8, big_num.GetAsLong())
+
+ def testBigNumDivides(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(6)
+ big_num2 = ssl_util.BigNum.FromLongNumber(2)
+ self.assertEqual(3, (big_num1 / big_num2).GetAsLong())
+ self.assertEqual(6, big_num1.GetAsLong())
+ self.assertEqual(2, big_num2.GetAsLong())
+
+ def testBigNumDividesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(6)
+ big_num2 = ssl_util.BigNum.FromLongNumber(2)
+ big_num1 /= big_num2
+ self.assertEqual(3, big_num1.GetAsLong())
+ self.assertEqual(2, big_num2.GetAsLong())
+
+ def testBigNumDivisionByZeroRaisesAssertionError(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(6)
+ big_num2 = ssl_util.BigNum.FromLongNumber(0)
+ self.assertRaises(AssertionError, big_num1.__div__, big_num2)
+
+ def testBigNumDivisionRaisesValueErrorWhenThereIsARemainder(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(5)
+ big_num2 = ssl_util.BigNum.FromLongNumber(2)
+ self.assertRaises(ValueError, big_num1.__div__, big_num2)
+
+ def testBigNumModMultiplies(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(5)
+ big_num3 = big_num1.ModMul(big_num2, mod_big_num)
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(5, mod_big_num.GetAsLong())
+ self.assertEqual(1, big_num3.GetAsLong())
+
+ def testBigNumModMultipliesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(5)
+ big_num1.IModMul(big_num2, mod_big_num)
+ self.assertEqual(1, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(5, mod_big_num.GetAsLong())
+
+ def testBigNumModExponentiates(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2)
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(7)
+ big_num3 = big_num1.ModExp(big_num2, mod_big_num)
+ self.assertEqual(2, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(7, mod_big_num.GetAsLong())
+ self.assertEqual(1, big_num3.GetAsLong())
+
+ def testBigNumModExponentiatesInPlace(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(2).Mutable()
+ big_num2 = ssl_util.BigNum.FromLongNumber(3)
+ mod_big_num = ssl_util.BigNum.FromLongNumber(7)
+ big_num1.IModExp(big_num2, mod_big_num)
+ self.assertEqual(1, big_num1.GetAsLong())
+ self.assertEqual(3, big_num2.GetAsLong())
+ self.assertEqual(7, mod_big_num.GetAsLong())
+
+ def testBigNumGCD(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(20)
+ big_num3 = ssl_util.BigNum.FromLongNumber(15)
+ big_num4 = big_num2.GCD(big_num1)
+ big_num5 = big_num2.GCD(big_num3)
+ self.assertEqual(11, big_num1.GetAsLong())
+ self.assertEqual(20, big_num2.GetAsLong())
+ self.assertEqual(15, big_num3.GetAsLong())
+ self.assertEqual(1, big_num4.GetAsLong())
+ self.assertEqual(5, big_num5.GetAsLong())
+
+ def testBigNumModInverse(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num_mod = ssl_util.BigNum.FromLongNumber(20)
+ big_num_result = big_num1.ModInverse(big_num_mod)
+ self.assertEqual(11, big_num1.GetAsLong())
+ self.assertEqual(20, big_num_mod.GetAsLong())
+ self.assertEqual(11, big_num_result.GetAsLong())
+
+ def testBigNumModSqrt(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num_mod = ssl_util.BigNum.FromLongNumber(19)
+ big_num_result = big_num1.ModSqrt(big_num_mod)
+ self.assertEqual(11, big_num1.GetAsLong())
+ self.assertEqual(19, big_num_mod.GetAsLong())
+ self.assertEqual(7, big_num_result.GetAsLong())
+
+ def testBigNumModInverseInvalidForNotRelativelyPrimes(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(10)
+ big_num_mod = ssl_util.BigNum.FromLongNumber(20)
+ self.assertRaises(ValueError, big_num1.ModInverse, big_num_mod)
+ self.assertEqual(10, big_num1.GetAsLong())
+ self.assertEqual(20, big_num_mod.GetAsLong())
+
+ def testBigNumNegates(self):
+ big_num = ssl_util.BigNum.FromLongNumber(10)
+ big_num = big_num.ModNegate(ssl_util.BigNum.FromLongNumber(6))
+ self.assertEqual(2, big_num.GetAsLong())
+
+ def testBigNumAddsOne(self):
+ big_num = ssl_util.BigNum.FromLongNumber(10)
+ self.assertEqual(11, big_num.AddOne().GetAsLong())
+
+ def testBigNumSubtractOne(self):
+ big_num = ssl_util.BigNum.FromLongNumber(10)
+ self.assertEqual(9, big_num.SubtractOne().GetAsLong())
+
+ def testBigNumGeneratesRandsBetweenZeroAndGivenBigNum(self):
+ big_num = ssl_util.BigNum.FromLongNumber(3)
+ big_rand = big_num.GenerateRand()
+ self.assertTrue(0 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert
+
+ def testBigNumGeneratesZeroForRandWhenTheUpperBoundIsOne(self):
+ big_num = ssl_util.BigNum.FromLongNumber(1)
+ self._GenericRandomTestForCasesThatShouldReturnOneNum(
+ ssl_util.BigNum.Zero(), big_num.GenerateRand
+ )
+
+ def testBigNumGeneratesRandsBetweenStartAndGivenBigNum(self):
+ big_num = ssl_util.BigNum.FromLongNumber(3)
+ big_rand = big_num.GenerateRandWithStart(ssl_util.BigNum.FromLongNumber(1))
+ self.assertTrue(1 <= big_rand.GetAsLong() < 3) # pylint: disable=g-generic-assert
+
+ def testBigNumGeneratesSingleRandWhenIntervalIsOne(self):
+ start = ssl_util.BigNum.FromLongNumber(2**30 - 1)
+ end = ssl_util.BigNum.FromLongNumber(2**30)
+ self._GenericRandomTestForCasesThatShouldReturnOneNum(
+ start, end.GenerateRandWithStart, start
+ )
+
+ def testBigNumIsBitSet(self):
+ big_num = ssl_util.BigNum.FromLongNumber(11)
+ self.assertTrue(big_num.IsBitSet(0))
+ self.assertTrue(big_num.IsBitSet(1))
+ self.assertFalse(big_num.IsBitSet(2))
+ self.assertTrue(big_num.IsBitSet(3))
+
+ def testBigNumEq(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(11)
+ self.assertEqual(big_num1, big_num2)
+
+ def testBigNumNeq(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(12)
+ self.assertNotEqual(big_num1, big_num2)
+
+ def testBigNumGt(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(12)
+ self.assertGreater(big_num2, big_num1)
+
+ def testBigNumGtEq(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ big_num2 = ssl_util.BigNum.FromLongNumber(11)
+ big_num3 = ssl_util.BigNum.FromLongNumber(12)
+ self.assertGreaterEqual(big_num2, big_num1)
+ self.assertGreaterEqual(big_num3, big_num2)
+
+ def testBigNumComparisonWithOtherTypesRaisesValueError(self):
+ big_num1 = ssl_util.BigNum.FromLongNumber(11)
+ self.assertRaises(ValueError, big_num1.__lt__, 11)
+
+ def testClonesCreatesANewBigNum(self):
+ big_num = ssl_util.BigNum.FromLongNumber(0).Mutable()
+ clone_big_num = big_num.Clone()
+ big_num += ssl_util.BigNum.One()
+ self.assertEqual(ssl_util.BigNum.Zero(), clone_big_num)
+ self.assertEqual(ssl_util.BigNum.One(), big_num)
+
+ def testBigNumCacheIsSingleton(self):
+ cache1 = ssl_util.BigNumCache(10)
+ cache2 = ssl_util.BigNumCache(11)
+ self.assertIs(cache1, cache2)
+
+ def testBigNumCacheReturnsTheSameCachedBigNum(self):
+ cache = ssl_util.BigNumCache(10)
+ self.assertIs(cache.Get(1), cache.Get(1))
+
+ def testBigNumCacheReturnsDifferentBigNumWhenCacheIsFull(self):
+ cache = ssl_util.BigNumCache(10)
+ for i in range(10):
+ cache.Get(i)
+ self.assertIsNot(cache.Get(11), cache.Get(11))
+
+ def testStringRepresentation(self):
+ big_num = ssl_util.BigNum.FromLongNumber(11)
+ self.assertEqual('11', '{}'.format(big_num))
+
+
+class _HashMock(object):
+
+ def __init__(self):
+ self.with_patch = patch('hashlib.sha512')
+
+ def __enter__(self):
+ hashlib_mock = self.with_patch.__enter__()
+ sha512_mock = mock.MagicMock()
+ hashlib_mock.return_value = sha512_mock
+ return sha512_mock, hashlib_mock
+
+ def __exit__(self, t, value, traceback):
+ self.with_patch.__exit__(t, value, traceback)
+
+
+class PRNGTest(unittest.TestCase):
+
+ def testPRNG(self):
+ with _HashMock() as (hash_mock, hashlib_mock):
+ hash_mock.digest.return_value = b'\x7f' + b'\x01' * 64
+ prng = PRNG(b'\x02' * 32)
+ self.assertEqual(0, prng.GetRand(2))
+ self.assertEqual(1, prng.GetRand(256))
+ self.assertEqual(2, prng.GetRand(257))
+ self.assertEqual(128, prng.GetRand(32768))
+ self.assertEqual(257, prng.GetRand(65536))
+ hash_mock.digest.assert_called_once_with()
+ hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x02' * 32)
+
+ def testGetNBitRandReturnsAtLeastUpperLimit(self):
+ with _HashMock() as (hash_mock, hashlib_mock):
+ hash_mock.digest.return_value = b'\x81\x82\xff\x05' + b'\x00' * 60
+ prng = PRNG(b'\x00' * 32)
+ self.assertEqual(5, prng.GetRand(129))
+ hash_mock.digest.assert_called_once_with()
+ hashlib_mock.assert_called_once_with(b'\x00' * 4 + b'\x00' * 32)
+
+ def testRaisesValueErrorWhenSeedIsNotAtLeastFourBytes(self):
+ self.assertRaises(ValueError, PRNG, b'\x00' * 31)
+
+ def testRaisesValueErrorWhenMaxNumberOfHashingIsDone(self):
+ prng = PRNG(b'\x00' * 32, 1)
+ upper_limit = 1 << 512
+ for _ in range(256):
+ prng.GetRand(upper_limit)
+ self.assertRaises(AssertionError, prng.GetRand, 2)
+ self.assertEqual(0, prng.GetRand(1))
+
+ def testGetsMoreBytesWithHashingUntilSufficientBytesArePresent(self):
+ with _HashMock() as (hash_mock, _):
+ hash_mock.digest.side_effect = [
+ b'\x80' + b'\x00' * 63,
+ b'\x00' * 64,
+ b'\x00' * 64,
+ ]
+ prng = PRNG(b'\x00' * 32, 1)
+ upper_limit = 1 << 1025
+ self.assertEqual(1 << 1024, prng.GetRand(upper_limit))
+ hash_mock.digest.assert_has_calls([call(), call(), call()])
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/private_join_and_compute/py/crypto_util/supported_curves.py b/private_join_and_compute/py/crypto_util/supported_curves.py
new file mode 100644
index 0000000..414389c
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/supported_curves.py
@@ -0,0 +1,32 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""A list of supported elliptic curves."""
+
+
+class SupportedCurve:
+ """A SupportedCurve helper class.
+
+ The class encapsulates a curve name as well as the curve ID, as encoded by
+ the OpenSSL enum in openssl/ec.h.
+ """
+
+ def __init__(self, curve_name: str, curve_id: int):
+ self.curve_name = curve_name
+ self.id = curve_id
+
+
+SupportedCurve.SECP256R1 = SupportedCurve('secp256r1', 415)
+SupportedCurve.SECP384R1 = SupportedCurve('secp384r1', 715)
diff --git a/private_join_and_compute/py/crypto_util/supported_hashes.py b/private_join_and_compute/py/crypto_util/supported_hashes.py
new file mode 100644
index 0000000..76d843a
--- /dev/null
+++ b/private_join_and_compute/py/crypto_util/supported_hashes.py
@@ -0,0 +1,37 @@
+# Copyright 2019 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""A list of supported hash functions."""
+
+import hashlib
+
+
+class HashType:
+ """A wrapper around a hash function."""
+
+ def __init__(self, bit_length: int, name: str):
+ self.bit_length = bit_length
+ self.name = name
+
+ def hash(self, data: bytes) -> int:
+ """Hashes a sequence of bytes to an integer."""
+ hasher = hashlib.new(self.name)
+ hasher.update(data)
+ return int(hasher.hexdigest(), 16)
+
+
+HashType.SHA256 = HashType(256, 'sha256')
+HashType.SHA384 = HashType(384, 'sha384')
+HashType.SHA512 = HashType(512, 'sha512')