summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavid Gross <dgross@google.com>2019-05-30 09:06:47 -0700
committerDavid Gross <dgross@google.com>2019-05-30 12:31:41 -0700
commit00fa8d35cbd38d5ea7cab48dc96dcb9851578aab (patch)
tree8c56e470d0f3ee0f177198055827e801922e5b6a
parent7d1ea8b70c42e0a4182a78630f032af0cee0bfb0 (diff)
downloadml-00fa8d35cbd38d5ea7cab48dc96dcb9851578aab.tar.gz
Partially recover from a driver crash
If a driver crashes, every object associated with that driver becomes "dead", and any method invocation on such an object fails with a transport error. In the NNAPI, this is a problem for IDevice and IPreparedModel objects. Without some mechanism to recover from a driver crash, all further uses of an IDevice or IPreparedModel will fail -- e.g., it's impossible to execute an already-compiled model, and it's impossible to create a new compiled model. The only way to recover from this is to restart the application. This fix addresses the first part of this problem. All references to IDevice in the runtime go through VersionedIDevice, so it sufficies to replace the IDevice reference in a VersionedIDevice when the IDevice dies. Therefore, it is now possible to create a new compiled model after a driver crash (the crash will appear to be a transient error). A previously-compiled model is still dead, and this fix does not address that problem. When we attempt to replace the IDevice, we use tryGetService() rather than getService(): Rather than waiting for the driver to become available, we recover it if it is available, and otherwise retain the behavior prior to this change -- i.e., the attempt to use the IDevice fails, and the runtime employs a fallback path if possible. This way we avoid a potentially long wait for the driver to come back up (up to 5 seconds, by default, per init start_period behavior). As an alternative approach, it might be possible to handle recovery by means of a death recipient, rather than during a VersionedIDevice method call. However, that alternative approach would probably result in more transient failures because of a crash, because the recovery would then be asynchronous with respect to calls that are vulnerable to a dead driver. Bug: 118623798 Test: NeuralNetworksTest_static Test: NeuralNetworksTest_mt_static Test: Ran NeuralNetworksTest_static --gtest_filter=TrivialTest.AddTwo --gtest_repeat=-1 and killed driver during the running; verified that there are no failures (we use the CPU fallback path) and that we eventually recover from the driver death (saw in the logcat that we run on device, then attempt recovery and fail several times and so run on CPU, then succeed in recovery and go back to running on device). Test: Modified VersionedIDevice::recoverable<> so that the first time we find a dead object, we sleep 20 seconds, allowing time for another thread to recover from the driver crash, so that the sleeper needs to tolerate the recovery already having happened. Ran NeuralNetworksTest_mt_static --gtest_filter=GeneratedTests.add --gtest_repeat=-1 and killed driver during the running; verified that there are no failures (we use the CPU fallback path) and that we took the recovery path (by observing that the sleep happened and by inspecting the logcat). Test: Modified NeuralNetworksTest_static TrivialTest.AddTwo to use introspection/control interface to force a particular driver; set debug.nn.partition to 2 to turn off CPU fallback; ran NeuralNetworksTest_static --gtest_filter=TrivialTest.AddTwo --gtest_repeat=-1 and killed driver during the running; verified that there are several failures (as we attempt recovery and fail several times) but that we eventually recover from the driver death (saw in the logcat that we went through the recovery path and that we go back to using the driver). Test: Modified each sample-* driver to sleep(10) when it begins its asynchronous execution; ran NeuralNetworksTest_static --gtest_filter=GeneratedTests.add with useCpuOnly = 0, computeMode = ComputeMode::ASYNC, allowSyncExecHal = 0 and killed driver and confirmed (1) that the runtime was not blocked and (2) that an appropriate log message was recorded. See http://ag/6575732. Test: Modified each sample-* driver to do asynchronous prepareModel and to sleep(10) when it begins its asynchronous preparation; ran NeuralNetworksTest_static --gtest_filter=GeneratedTests.add with useCpuOnly = 0, computeMode = ComputeMode::ASYNC, allowSyncExecHal = 0 and killed driver and confirmed (1) that the runtime was not blocked and (2) that an appropriate log message was recorded. See http://ag/6575732. Test: Modified each sample-* driver to return an error for launching an asynchronous call (tested execution and prepareModel separately), but not make the corresponding call to callback->notify; ran NeuralNetworksTest_static --gtest_filter=GeneratedTests.add with useCpuOnly = 0, computeMode = ComputeMode::ASYNC, allowSyncExecHal = 0 and confirmed that the execution succeeded and that appropriate messages were logged (preparation or execution failure followed by CPU fallback). See http://ag/7669359. Change-Id: I55b779bc2a38243d5df122433672a9f2e073c8b4
-rw-r--r--nn/runtime/Manager.cpp2
-rw-r--r--nn/runtime/VersionedInterfaces.cpp363
-rw-r--r--nn/runtime/VersionedInterfaces.h233
3 files changed, 459 insertions, 139 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index c5d7b1fad..b479d3964 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -96,7 +96,7 @@ class DriverDevice : public Device {
};
DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device)
- : mName(std::move(name)), mInterface(VersionedIDevice::create(device)) {}
+ : mName(std::move(name)), mInterface(VersionedIDevice::create(mName, device)) {}
// TODO: handle errors from initialize correctly
bool DriverDevice::initialize() {
diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp
index d0439b213..0a359582d 100644
--- a/nn/runtime/VersionedInterfaces.cpp
+++ b/nn/runtime/VersionedInterfaces.cpp
@@ -25,6 +25,7 @@
#include <android-base/scopeguard.h>
#include <android-base/thread_annotations.h>
#include <functional>
+#include <type_traits>
namespace android {
namespace nn {
@@ -40,6 +41,10 @@ void sendFailureMessage(const sp<IPreparedModelCallback>& cb) {
cb->notify(ErrorStatus::GENERAL_FAILURE, nullptr);
}
+void sendFailureMessage(const sp<PreparedModelCallback>& cb) {
+ sendFailureMessage(static_cast<sp<IPreparedModelCallback>>(cb));
+}
+
void sendFailureMessage(const sp<IExecutionCallback>& cb) {
cb->notify(ErrorStatus::GENERAL_FAILURE);
}
@@ -219,18 +224,33 @@ bool VersionedIPreparedModel::operator!=(nullptr_t) const {
return mPreparedModelV1_0 != nullptr;
}
-std::shared_ptr<VersionedIDevice> VersionedIDevice::create(sp<V1_0::IDevice> device) {
+std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName,
+ sp<V1_0::IDevice> device) {
+ auto core = Core::create(std::move(device));
+ if (!core.has_value()) {
+ LOG(ERROR) << "VersionedIDevice::create -- Failed to create Core.";
+ return nullptr;
+ }
+
+ // return a valid VersionedIDevice object
+ return std::make_shared<VersionedIDevice>(std::move(serviceName), std::move(core.value()));
+}
+
+VersionedIDevice::VersionedIDevice(std::string serviceName, Core core)
+ : mServiceName(std::move(serviceName)), mCore(std::move(core)) {}
+
+std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) {
// verify input
if (!device) {
- LOG(ERROR) << "VersionedIDevice::create -- passed invalid device object.";
- return nullptr;
+ LOG(ERROR) << "VersionedIDevice::Core::create -- passed invalid device object.";
+ return {};
}
// create death handler object
sp<IDeviceDeathHandler> deathHandler = new (std::nothrow) IDeviceDeathHandler();
if (!deathHandler) {
- LOG(ERROR) << "VersionedIDevice::create -- Failed to create IDeviceDeathHandler.";
- return nullptr;
+ LOG(ERROR) << "VersionedIDevice::Core::create -- Failed to create IDeviceDeathHandler.";
+ return {};
}
// linkToDeath registers a callback that will be invoked on service death to
@@ -239,60 +259,192 @@ std::shared_ptr<VersionedIDevice> VersionedIDevice::create(sp<V1_0::IDevice> dev
// providing the response.
const Return<bool> ret = device->linkToDeath(deathHandler, 0);
if (!ret.isOk() || ret != true) {
- LOG(ERROR) << "VersionedIDevice::create -- Failed to register a death recipient for the "
- "IDevice object.";
- return nullptr;
+ LOG(ERROR)
+ << "VersionedIDevice::Core::create -- Failed to register a death recipient for the "
+ "IDevice object.";
+ return {};
}
- // return a valid VersionedIDevice object
- return std::make_shared<VersionedIDevice>(std::move(device), std::move(deathHandler));
+ // return a valid Core object
+ return Core(std::move(device), std::move(deathHandler));
}
// HIDL guarantees all V1_1 interfaces inherit from their corresponding V1_0 interfaces.
-VersionedIDevice::VersionedIDevice(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
+VersionedIDevice::Core::Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler)
: mDeviceV1_0(std::move(device)),
mDeviceV1_1(V1_1::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
mDeviceV1_2(V1_2::IDevice::castFrom(mDeviceV1_0).withDefault(nullptr)),
mDeathHandler(std::move(deathHandler)) {}
-VersionedIDevice::~VersionedIDevice() {
- // It is safe to ignore any errors resulting from this unlinkToDeath call
- // because the VersionedIDevice object is already being destroyed and its
- // underlying IDevice object is no longer being used by the NN runtime.
- mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
+VersionedIDevice::Core::~Core() {
+ if (mDeathHandler != nullptr) {
+ CHECK(mDeviceV1_0 != nullptr);
+ // It is safe to ignore any errors resulting from this unlinkToDeath call
+ // because the VersionedIDevice::Core object is already being destroyed and
+ // its underlying IDevice object is no longer being used by the NN runtime.
+ mDeviceV1_0->unlinkToDeath(mDeathHandler).isOk();
+ }
+}
+
+VersionedIDevice::Core::Core(Core&& other) noexcept
+ : mDeviceV1_0(std::move(other.mDeviceV1_0)),
+ mDeviceV1_1(std::move(other.mDeviceV1_1)),
+ mDeviceV1_2(std::move(other.mDeviceV1_2)),
+ mDeathHandler(std::move(other.mDeathHandler)) {
+ other.mDeathHandler = nullptr;
+}
+
+VersionedIDevice::Core& VersionedIDevice::Core::operator=(Core&& other) noexcept {
+ if (this != &other) {
+ mDeviceV1_0 = std::move(other.mDeviceV1_0);
+ mDeviceV1_1 = std::move(other.mDeviceV1_1);
+ mDeviceV1_2 = std::move(other.mDeviceV1_2);
+ mDeathHandler = std::move(other.mDeathHandler);
+ other.mDeathHandler = nullptr;
+ }
+ return *this;
+}
+
+template <typename T_IDevice>
+std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> VersionedIDevice::Core::getDeviceAndDeathHandler()
+ const {
+ return {getDevice<T_IDevice>(), mDeathHandler};
+}
+
+template <typename T_IDevice, typename T_Callback>
+Return<ErrorStatus> callProtected(
+ const char* context, const std::function<Return<ErrorStatus>(const sp<T_IDevice>&)>& fn,
+ const sp<T_IDevice>& device, const sp<T_Callback>& callback,
+ const sp<IDeviceDeathHandler>& deathHandler) {
+ const auto scoped = deathHandler->protectCallback(callback);
+ Return<ErrorStatus> ret = fn(device);
+ // Suppose there was a transport error. We have the following cases:
+ // 1. Either not due to a dead device, or due to a device that was
+ // already dead at the time of the call to protectCallback(). In
+ // this case, the callback was never signalled.
+ // 2. Due to a device that died after the call to protectCallback() but
+ // before fn() completed. In this case, the callback was (or will
+ // be) signalled by the deathHandler.
+ // Furthermore, what if there was no transport error, but the ErrorStatus is
+ // other than NONE? We'll conservatively signal the callback anyway, just in
+ // case the driver was sloppy and failed to do so.
+ if (!ret.isOk() || ret != ErrorStatus::NONE) {
+ // What if the deathHandler has signalled or will signal the callback?
+ // This is fine -- we're permitted to signal multiple times; and we're
+ // sending the same signal that the deathHandler does.
+ //
+ // What if the driver signalled the callback? Then this signal is
+ // ignored.
+
+ if (ret.isOk()) {
+ LOG(ERROR) << context << " returned " << toString(static_cast<ErrorStatus>(ret));
+ } else {
+ LOG(ERROR) << context << " failure: " << ret.description();
+ }
+ sendFailureMessage(callback);
+ }
+ callback->wait();
+ return ret;
+}
+template <typename T_Return, typename T_IDevice>
+Return<T_Return> callProtected(const char*,
+ const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
+ const sp<T_IDevice>& device, const std::nullptr_t&,
+ const sp<IDeviceDeathHandler>&) {
+ return fn(device);
+}
+
+template <typename T_Return, typename T_IDevice, typename T_Callback>
+Return<T_Return> VersionedIDevice::recoverable(
+ const char* context, const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
+ const T_Callback& callback) const EXCLUDES(mMutex) {
+ CHECK_EQ(callback == nullptr, (std::is_same_v<T_Callback, std::nullptr_t>));
+
+ sp<T_IDevice> device;
+ sp<IDeviceDeathHandler> deathHandler;
+ std::tie(device, deathHandler) = getDeviceAndDeathHandler<T_IDevice>();
+
+ Return<T_Return> ret = callProtected(context, fn, device, callback, deathHandler);
+
+ if (ret.isDeadObject()) {
+ {
+ std::unique_lock lock(mMutex);
+ // It's possible that another device has already done the recovery.
+ // It's harmless but wasteful for us to do so in this case.
+ auto pingReturn = mCore.getDevice<T_IDevice>()->ping();
+ if (pingReturn.isDeadObject()) {
+ VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context << ") -- Recovering "
+ << mServiceName;
+ sp<V1_0::IDevice> recoveredDevice = V1_0::IDevice::tryGetService(mServiceName);
+ if (recoveredDevice == nullptr) {
+ VLOG(DRIVER) << "VersionedIDevice::recoverable got a null IDEVICE for "
+ << mServiceName;
+ return ret;
+ }
+
+ auto core = Core::create(std::move(recoveredDevice));
+ if (!core.has_value()) {
+ LOG(ERROR) << "VersionedIDevice::recoverable -- Failed to create Core.";
+ return ret;
+ }
+
+ mCore = std::move(core.value());
+ } else {
+ VLOG(DRIVER) << "VersionedIDevice::recoverable(" << context
+ << ") -- Someone else recovered " << mServiceName;
+ // Might still have a transport error, which we need to check
+ // before pingReturn goes out of scope.
+ (void)pingReturn.isOk();
+ }
+ std::tie(device, deathHandler) = mCore.getDeviceAndDeathHandler<T_IDevice>();
+ }
+ ret = callProtected(context, fn, device, callback, deathHandler);
+ // It's possible that the device died again, but we're only going to
+ // attempt recovery once per call to recoverable().
+ }
+ return ret;
}
std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() {
const std::pair<ErrorStatus, Capabilities> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
std::pair<ErrorStatus, Capabilities> result;
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_2");
- Return<void> ret = mDeviceV1_2->getCapabilities_1_2(
- [&result](ErrorStatus error, const Capabilities& capabilities) {
- result = std::make_pair(error, capabilities);
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getCapabilities_1_2(
+ [&result](ErrorStatus error, const Capabilities& capabilities) {
+ result = std::make_pair(error, capabilities);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getCapabilities_1_2 failure: " << ret.description();
return {ErrorStatus::GENERAL_FAILURE, {}};
}
- } else if (mDeviceV1_1 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities_1_1");
- Return<void> ret = mDeviceV1_1->getCapabilities_1_1(
- [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) {
- // Time taken to convert capabilities is trivial
- result = std::make_pair(error, convertToV1_2(capabilities));
+ Return<void> ret = recoverable<void, V1_1::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_1::IDevice>& device) {
+ return device->getCapabilities_1_1(
+ [&result](ErrorStatus error, const V1_1::Capabilities& capabilities) {
+ // Time taken to convert capabilities is trivial
+ result = std::make_pair(error, convertToV1_2(capabilities));
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getCapabilities_1_1 failure: " << ret.description();
return kFailure;
}
- } else if (mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_0::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_INITIALIZATION, "getCapabilities");
- Return<void> ret = mDeviceV1_0->getCapabilities(
- [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) {
- // Time taken to convert capabilities is trivial
- result = std::make_pair(error, convertToV1_2(capabilities));
+ Return<void> ret = recoverable<void, V1_0::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_0::IDevice>& device) {
+ return device->getCapabilities(
+ [&result](ErrorStatus error, const V1_0::Capabilities& capabilities) {
+ // Time taken to convert capabilities is trivial
+ result = std::make_pair(error, convertToV1_2(capabilities));
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getCapabilities failure: " << ret.description();
@@ -309,18 +461,21 @@ std::pair<ErrorStatus, Capabilities> VersionedIDevice::getCapabilities() {
std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() {
const std::pair<ErrorStatus, hidl_vec<Extension>> kFailure = {ErrorStatus::GENERAL_FAILURE, {}};
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedExtensions");
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
std::pair<ErrorStatus, hidl_vec<Extension>> result;
- Return<void> ret = mDeviceV1_2->getSupportedExtensions(
- [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
- result = std::make_pair(error, extensions);
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getSupportedExtensions(
+ [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
+ result = std::make_pair(error, extensions);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
return kFailure;
}
return result;
- } else if (mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_0::IDevice>() != nullptr) {
return {ErrorStatus::NONE, {/* No extensions. */}};
} else {
LOG(ERROR) << "Device not available!";
@@ -354,11 +509,14 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
return std::make_pair(status, std::move(remappedSupported));
};
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations_1_2");
- Return<void> ret = mDeviceV1_2->getSupportedOperations_1_2(
- model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
- result = std::make_pair(error, supported);
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&model, &result](const sp<V1_2::IDevice>& device) {
+ return device->getSupportedOperations_1_2(
+ model, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
+ result = std::make_pair(error, supported);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedOperations_1_2 failure: " << ret.description();
@@ -367,7 +525,7 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
return result;
}
- if (mDeviceV1_1 != nullptr) {
+ if (getDevice<V1_1::IDevice>() != nullptr) {
const bool compliant = compliantWithV1_1(model);
if (compliant || slicer) {
V1_1::Model model11;
@@ -383,9 +541,13 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
}
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION,
"getSupportedOperations_1_1");
- Return<void> ret = mDeviceV1_1->getSupportedOperations_1_1(
- model11, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
- result = std::make_pair(error, supported);
+ Return<void> ret = recoverable<void, V1_1::IDevice>(
+ __FUNCTION__, [&model11, &result](const sp<V1_1::IDevice>& device) {
+ return device->getSupportedOperations_1_1(
+ model11,
+ [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
+ result = std::make_pair(error, supported);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedOperations_1_1 failure: " << ret.description();
@@ -398,7 +560,7 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
return result;
}
- if (mDeviceV1_0 != nullptr) {
+ if (getDevice<V1_0::IDevice>() != nullptr) {
const bool compliant = compliantWithV1_0(model);
if (compliant || slicer) {
V1_0::Model model10;
@@ -413,9 +575,13 @@ std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
std::tie(model10, submodelOperationIndexToModelOperationIndex) = *slice10;
}
NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedOperations");
- Return<void> ret = mDeviceV1_0->getSupportedOperations(
- model10, [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
- result = std::make_pair(error, supported);
+ Return<void> ret = recoverable<void, V1_0::IDevice>(
+ __FUNCTION__, [&model10, &result](const sp<V1_0::IDevice>& device) {
+ return device->getSupportedOperations(
+ model10,
+ [&result](ErrorStatus error, const hidl_vec<bool>& supported) {
+ result = std::make_pair(error, supported);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getSupportedOperations failure: " << ret.description();
@@ -443,12 +609,16 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
return kFailure;
}
- const auto scoped = mDeathHandler->protectCallback(callback);
-
// If 1.2 device, try preparing model
- if (mDeviceV1_2 != nullptr) {
- const Return<ErrorStatus> ret = mDeviceV1_2->prepareModel_1_2(model, preference, modelCache,
- dataCache, token, callback);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
+ __FUNCTION__,
+ [&model, &preference, &modelCache, &dataCache, &token,
+ &callback](const sp<V1_2::IDevice>& device) {
+ return device->prepareModel_1_2(model, preference, modelCache, dataCache, token,
+ callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel_1_2 failure: " << ret.description();
return kFailure;
@@ -462,7 +632,7 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
// If 1.1 device, try preparing model (requires conversion)
- if (mDeviceV1_1 != nullptr) {
+ if (getDevice<V1_1::IDevice>() != nullptr) {
bool compliant = false;
V1_1::Model model11;
{
@@ -477,8 +647,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
}
if (compliant) {
- const Return<ErrorStatus> ret =
- mDeviceV1_1->prepareModel_1_1(model11, preference, callback);
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_1::IDevice>(
+ __FUNCTION__,
+ [&model11, &preference, &callback](const sp<V1_1::IDevice>& device) {
+ return device->prepareModel_1_1(model11, preference, callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel_1_1 failure: " << ret.description();
return kFailure;
@@ -493,14 +667,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
makeVersionedIPreparedModel(callback->getPreparedModel())};
}
- // TODO: partition the model such that v1.2 ops are not passed to v1.1
- // device
LOG(ERROR) << "Could not handle prepareModel_1_1!";
return kFailure;
}
// If 1.0 device, try preparing model (requires conversion)
- if (mDeviceV1_0 != nullptr) {
+ if (getDevice<V1_0::IDevice>() != nullptr) {
bool compliant = false;
V1_0::Model model10;
{
@@ -515,7 +687,12 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
}
}
if (compliant) {
- const Return<ErrorStatus> ret = mDeviceV1_0->prepareModel(model10, callback);
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_0::IDevice>(
+ __FUNCTION__,
+ [&model10, &callback](const sp<V1_0::IDevice>& device) {
+ return device->prepareModel(model10, callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModel failure: " << ret.description();
return kFailure;
@@ -529,8 +706,6 @@ std::pair<ErrorStatus, std::shared_ptr<VersionedIPreparedModel>> VersionedIDevic
makeVersionedIPreparedModel(callback->getPreparedModel())};
}
- // TODO: partition the model such that v1.1 ops are not passed to v1.0
- // device
LOG(ERROR) << "Could not handle prepareModel!";
return kFailure;
}
@@ -553,11 +728,13 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
return kFailure;
}
- const auto scoped = mDeathHandler->protectCallback(callback);
-
- if (mDeviceV1_2 != nullptr) {
- const Return<ErrorStatus> ret =
- mDeviceV1_2->prepareModelFromCache(modelCache, dataCache, token, callback);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ const Return<ErrorStatus> ret = recoverable<ErrorStatus, V1_2::IDevice>(
+ __FUNCTION__,
+ [&modelCache, &dataCache, &token, &callback](const sp<V1_2::IDevice>& device) {
+ return device->prepareModelFromCache(modelCache, dataCache, token, callback);
+ },
+ callback);
if (!ret.isOk()) {
LOG(ERROR) << "prepareModelFromCache failure: " << ret.description();
return kFailure;
@@ -571,7 +748,7 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
return {callback->getStatus(), makeVersionedIPreparedModel(callback->getPreparedModel())};
}
- if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) {
+ if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
LOG(ERROR) << "prepareModelFromCache called on V1_1 or V1_0 device";
return kFailure;
}
@@ -581,12 +758,13 @@ VersionedIDevice::prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
}
DeviceStatus VersionedIDevice::getStatus() {
- if (mDeviceV1_0 == nullptr) {
+ if (getDevice<V1_0::IDevice>() == nullptr) {
LOG(ERROR) << "Device not available!";
return DeviceStatus::UNKNOWN;
}
- Return<DeviceStatus> ret = mDeviceV1_0->getStatus();
+ Return<DeviceStatus> ret = recoverable<DeviceStatus, V1_0::IDevice>(
+ __FUNCTION__, [](const sp<V1_0::IDevice>& device) { return device->getStatus(); });
if (!ret.isOk()) {
LOG(ERROR) << "getStatus failure: " << ret.description();
@@ -598,11 +776,11 @@ DeviceStatus VersionedIDevice::getStatus() {
int64_t VersionedIDevice::getFeatureLevel() {
constexpr int64_t kFailure = -1;
- if (mDeviceV1_2 != nullptr) {
+ if (getDevice<V1_2::IDevice>() != nullptr) {
return __ANDROID_API_Q__;
- } else if (mDeviceV1_1 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr) {
return __ANDROID_API_P__;
- } else if (mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_0::IDevice>() != nullptr) {
return __ANDROID_API_O_MR1__;
} else {
LOG(ERROR) << "Device not available!";
@@ -614,10 +792,12 @@ int32_t VersionedIDevice::getType() const {
constexpr int32_t kFailure = -1;
std::pair<ErrorStatus, DeviceType> result;
- if (mDeviceV1_2 != nullptr) {
- Return<void> ret =
- mDeviceV1_2->getType([&result](ErrorStatus error, DeviceType deviceType) {
- result = std::make_pair(error, deviceType);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getType([&result](ErrorStatus error, DeviceType deviceType) {
+ result = std::make_pair(error, deviceType);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getType failure: " << ret.description();
@@ -634,17 +814,20 @@ std::pair<ErrorStatus, hidl_string> VersionedIDevice::getVersionString() {
const std::pair<ErrorStatus, hidl_string> kFailure = {ErrorStatus::GENERAL_FAILURE, ""};
std::pair<ErrorStatus, hidl_string> result;
- if (mDeviceV1_2 != nullptr) {
- Return<void> ret = mDeviceV1_2->getVersionString(
- [&result](ErrorStatus error, const hidl_string& version) {
- result = std::make_pair(error, version);
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getVersionString(
+ [&result](ErrorStatus error, const hidl_string& version) {
+ result = std::make_pair(error, version);
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getVersion failure: " << ret.description();
return kFailure;
}
return result;
- } else if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
return {ErrorStatus::NONE, "UNKNOWN"};
} else {
LOG(ERROR) << "Could not handle getVersionString";
@@ -657,17 +840,21 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi
0, 0};
std::tuple<ErrorStatus, uint32_t, uint32_t> result;
- if (mDeviceV1_2 != nullptr) {
- Return<void> ret = mDeviceV1_2->getNumberOfCacheFilesNeeded(
- [&result](ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
- result = {error, numModelCache, numDataCache};
+ if (getDevice<V1_2::IDevice>() != nullptr) {
+ Return<void> ret = recoverable<void, V1_2::IDevice>(
+ __FUNCTION__, [&result](const sp<V1_2::IDevice>& device) {
+ return device->getNumberOfCacheFilesNeeded([&result](ErrorStatus error,
+ uint32_t numModelCache,
+ uint32_t numDataCache) {
+ result = {error, numModelCache, numDataCache};
+ });
});
if (!ret.isOk()) {
LOG(ERROR) << "getNumberOfCacheFilesNeeded failure: " << ret.description();
return kFailure;
}
return result;
- } else if (mDeviceV1_1 != nullptr || mDeviceV1_0 != nullptr) {
+ } else if (getDevice<V1_1::IDevice>() != nullptr || getDevice<V1_0::IDevice>() != nullptr) {
return {ErrorStatus::NONE, 0, 0};
} else {
LOG(ERROR) << "Could not handle getNumberOfCacheFilesNeeded";
@@ -676,11 +863,11 @@ std::tuple<ErrorStatus, uint32_t, uint32_t> VersionedIDevice::getNumberOfCacheFi
}
bool VersionedIDevice::operator==(nullptr_t) const {
- return mDeviceV1_0 == nullptr;
+ return getDevice<V1_0::IDevice>() == nullptr;
}
bool VersionedIDevice::operator!=(nullptr_t) const {
- return mDeviceV1_0 != nullptr;
+ return getDevice<V1_0::IDevice>() != nullptr;
}
} // namespace nn
diff --git a/nn/runtime/VersionedInterfaces.h b/nn/runtime/VersionedInterfaces.h
index 81a66013f..f0aaf6d6e 100644
--- a/nn/runtime/VersionedInterfaces.h
+++ b/nn/runtime/VersionedInterfaces.h
@@ -20,9 +20,14 @@
#include "HalInterfaces.h"
#include <android-base/macros.h>
+#include <cstddef>
+#include <functional>
#include <memory>
+#include <optional>
+#include <shared_mutex>
#include <string>
#include <tuple>
+#include <utility>
#include "Callbacks.h"
namespace android {
@@ -53,6 +58,9 @@ class VersionedIPreparedModel;
class VersionedIDevice {
DISALLOW_IMPLICIT_CONSTRUCTORS(VersionedIDevice);
+ // forward declaration of nested class
+ class Core;
+
public:
/**
* Create a VersionedIDevice object.
@@ -60,40 +68,26 @@ class VersionedIDevice {
* Prefer using this function over the constructor, as it adds more
* protections.
*
- * This call linksToDeath a hidl_death_recipient that can
- * proactively handle the case when the service containing the IDevice
- * object crashes.
- *
+ * @param serviceName The name of the service that provides "device".
* @param device A device object that is at least version 1.0 of the IDevice
* interface.
* @return A valid VersionedIDevice object, otherwise nullptr.
*/
- static std::shared_ptr<VersionedIDevice> create(sp<V1_0::IDevice> device);
+ static std::shared_ptr<VersionedIDevice> create(std::string serviceName,
+ sp<V1_0::IDevice> device);
/**
* Constructor for the VersionedIDevice object.
*
- * VersionedIDevice is constructed with the V1_0::IDevice object, which
- * represents a device that is at least v1.0 of the interface. The
- * constructor downcasts to the latest version of the IDevice interface, and
- * will default to using the latest version of all IDevice interface
- * methods automatically.
- *
- * @param device A device object that is at least version 1.0 of the IDevice
- * interface.
- * @param deathHandler A hidl_death_recipient that will proactively handle
- * the case when the service containing the IDevice
- * object crashes.
- */
- VersionedIDevice(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler);
-
- /**
- * Destructor for the VersionedIDevice object.
+ * VersionedIDevice will default to using the latest version of all IDevice
+ * interface methods automatically.
*
- * This destructor unlinksToDeath this object's hidl_death_recipient as it
- * no longer needs to handle the case where the IDevice's service crashes.
+ * @param serviceName The name of the service that provides core.getDevice<V1_0::IDevice>().
+ * @param core An object that encapsulates a V1_0::IDevice, any appropriate downcasts to
+ * newer interfaces, and a hidl_death_recipient that will proactively handle
+ * the case when the service containing the IDevice object crashes.
*/
- ~VersionedIDevice();
+ VersionedIDevice(std::string serviceName, Core core);
/**
* Gets the capabilities of a driver.
@@ -426,6 +420,13 @@ class VersionedIDevice {
std::tuple<ErrorStatus, uint32_t, uint32_t> getNumberOfCacheFilesNeeded();
/**
+ * Returns the name of the service that implements the driver
+ *
+ * @return serviceName The name of the service.
+ */
+ std::string getServiceName() const { return mServiceName; }
+
+ /**
* Returns whether this handle to an IDevice object is valid or not.
*
* @return bool true if V1_0::IDevice (which could be V1_1::IDevice) is
@@ -443,33 +444,165 @@ class VersionedIDevice {
private:
/**
- * All versions of IDevice are necessary because the driver could be v1.0,
- * v1.1, or a later version. All these pointers logically represent the same
- * object.
- *
- * The general strategy is: HIDL returns a V1_0 device object, which
- * (if not nullptr) could be v1.0, v1.1, or a greater version. The V1_0
- * object is then "dynamically cast" to a V1_1 object. If successful,
- * mDeviceV1_1 will point to the same object as mDeviceV1_0; otherwise,
- * mDeviceV1_1 will be nullptr.
- *
- * In general:
- * * If the device is truly v1.0, mDeviceV1_0 will point to a valid object
- * and mDeviceV1_1 will be nullptr.
- * * If the device is truly v1.1 or later, both mDeviceV1_0 and mDeviceV1_1
- * will point to the same valid object.
- *
- * Idiomatic usage: if mDeviceV1_1 is non-null, do V1_1 dispatch; otherwise,
- * do V1_0 dispatch.
- */
- sp<V1_0::IDevice> mDeviceV1_0;
- sp<V1_1::IDevice> mDeviceV1_1;
- sp<V1_2::IDevice> mDeviceV1_2;
-
- /**
- * HIDL callback to be invoked if the service for mDeviceV1_0 crashes.
+ * This is a utility class for VersionedIDevice that encapsulates a
+ * V1_0::IDevice, any appropriate downcasts to newer interfaces, and a
+ * hidl_death_recipient that will proactively handle the case when the
+ * service containing the IDevice object crashes.
+ *
+ * This is a convenience class to help VersionedIDevice recover from an
+ * IDevice object crash: It bundles together all the data that needs to
+ * change when recovering from a crash, and simplifies the process of
+ * instantiating that data (at VersionedIDevice creation time) and
+ * re-instantiating that data (at crash recovery time).
*/
- const sp<IDeviceDeathHandler> mDeathHandler;
+ class Core {
+ public:
+ /**
+ * Constructor for the Core object.
+ *
+ * Core is constructed with a V1_0::IDevice object, which represents a
+ * device that is at least v1.0 of the interface. The constructor
+ * downcasts to the latest version of the IDevice interface, allowing
+ * VersionedIDevice to default to using the latest version of all
+ * IDevice interface methods automatically.
+ *
+ * @param device A device object that is at least version 1.0 of the IDevice
+ * interface.
+ * @param deathHandler A hidl_death_recipient that will proactively handle
+ * the case when the service containing the IDevice
+ * object crashes.
+ */
+ Core(sp<V1_0::IDevice> device, sp<IDeviceDeathHandler> deathHandler);
+
+ /**
+ * Destructor for the Core object.
+ *
+ * This destructor unlinksToDeath this object's hidl_death_recipient as it
+ * no longer needs to handle the case where the IDevice's service crashes.
+ */
+ ~Core();
+
+ // Support move but not copy
+ Core(Core&&) noexcept;
+ Core& operator=(Core&&) noexcept;
+ Core(const Core&) = delete;
+ Core& operator=(const Core&) = delete;
+
+ /**
+ * Create a Core object.
+ *
+ * Prefer using this function over the constructor, as it adds more
+ * protections.
+ *
+ * This call linksToDeath a hidl_death_recipient that can
+ * proactively handle the case when the service containing the IDevice
+ * object crashes.
+ *
+ * @param device A device object that is at least version 1.0 of the IDevice
+ * interface.
+ * @return A valid Core object, otherwise nullopt.
+ */
+ static std::optional<Core> create(sp<V1_0::IDevice> device);
+
+ /**
+ * Returns sp<*::IDevice> that is a downcast of the sp<V1_0::IDevice>
+ * passed to the constructor. This will be nullptr if that IDevice is
+ * not actually of the specified downcast type.
+ */
+ template <typename T_IDevice>
+ sp<T_IDevice> getDevice() const;
+ template <>
+ sp<V1_0::IDevice> getDevice() const {
+ return mDeviceV1_0;
+ }
+ template <>
+ sp<V1_1::IDevice> getDevice() const {
+ return mDeviceV1_1;
+ }
+ template <>
+ sp<V1_2::IDevice> getDevice() const {
+ return mDeviceV1_2;
+ }
+
+ /**
+ * Returns sp<*::IDevice> (as per getDevice()) and the
+ * hidl_death_recipient that will proactively handle the case when the
+ * service containing the IDevice object crashes.
+ */
+ template <typename T_IDevice>
+ std::pair<sp<T_IDevice>, sp<IDeviceDeathHandler>> getDeviceAndDeathHandler() const;
+
+ private:
+ /**
+ * All versions of IDevice are necessary because the driver could be v1.0,
+ * v1.1, or a later version. All these pointers logically represent the same
+ * object.
+ *
+ * The general strategy is: HIDL returns a V1_0 device object, which
+ * (if not nullptr) could be v1.0, v1.1, or a greater version. The V1_0
+ * object is then "dynamically cast" to a V1_1 object. If successful,
+ * mDeviceV1_1 will point to the same object as mDeviceV1_0; otherwise,
+ * mDeviceV1_1 will be nullptr.
+ *
+ * In general:
+ * * If the device is truly v1.0, mDeviceV1_0 will point to a valid object
+ * and mDeviceV1_1 will be nullptr.
+ * * If the device is truly v1.1 or later, both mDeviceV1_0 and mDeviceV1_1
+ * will point to the same valid object.
+ *
+ * Idiomatic usage: if mDeviceV1_1 is non-null, do V1_1 dispatch; otherwise,
+ * do V1_0 dispatch.
+ */
+ sp<V1_0::IDevice> mDeviceV1_0;
+ sp<V1_1::IDevice> mDeviceV1_1;
+ sp<V1_2::IDevice> mDeviceV1_2;
+
+ /**
+ * HIDL callback to be invoked if the service for mDeviceV1_0 crashes.
+ *
+ * nullptr if this Core instance is a move victim and hence has no
+ * callback to be unlinked.
+ */
+ sp<IDeviceDeathHandler> mDeathHandler;
+ };
+
+ // This method retrieves the appropriate mCore.mDevice* field, under a read lock.
+ template <typename T_IDevice>
+ sp<T_IDevice> getDevice() const EXCLUDES(mMutex) {
+ std::shared_lock lock(mMutex);
+ return mCore.getDevice<T_IDevice>();
+ }
+
+ // This method retrieves the appropriate mCore.mDevice* fields, under a read lock.
+ template <typename T_IDevice>
+ auto getDeviceAndDeathHandler() const EXCLUDES(mMutex) {
+ std::shared_lock lock(mMutex);
+ return mCore.getDeviceAndDeathHandler<T_IDevice>();
+ }
+
+ // This method calls the function fn in a manner that supports recovering
+ // from a driver crash: If the driver implementation is dead because the
+ // driver crashed either before the call to fn or during the call to fn, we
+ // will attempt to obtain a new instance of the same driver and call fn
+ // again.
+ //
+ // If a callback is provided, this method protects it against driver death
+ // and waits for it (callback->wait()).
+ template <typename T_Return, typename T_IDevice, typename T_Callback = std::nullptr_t>
+ Return<T_Return> recoverable(const char* context,
+ const std::function<Return<T_Return>(const sp<T_IDevice>&)>& fn,
+ const T_Callback& callback = nullptr) const EXCLUDES(mMutex);
+
+ // The name of the service that implements the driver.
+ const std::string mServiceName;
+
+ // Guards access to mCore.
+ mutable std::shared_mutex mMutex;
+
+ // Data that can be rewritten during driver recovery. Guarded againt
+ // synchronous access by a mutex: Any number of concurrent read accesses is
+ // permitted, but a write access excludes all other accesses.
+ mutable Core mCore GUARDED_BY(mMutex);
};
/** This class wraps an IPreparedModel object of any version. */