diff options
author | Michael Butler <butlermichael@google.com> | 2019-08-09 16:44:42 -0700 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2019-10-28 16:50:18 -0700 |
commit | 869066c5fcc4aa785a857fbf8d4d5aa62338d2b1 (patch) | |
tree | bee69c7169f34a6cae421f1700643f79496eb846 | |
parent | a82ce77d0365f09adffd71430e2f0d5cfa15686e (diff) | |
download | ml-869066c5fcc4aa785a857fbf8d4d5aa62338d2b1.tar.gz |
Cleanup Manager's Device class
This CL includes the following changes:
* Change instances of "const char*" to "const std::string&"
* Make DriverDevice members constant
* Change listManifestByInterface to getAllHalInstanceNames
Elaborating on the last point, previously the NNAPI runtime would get
a handle to the defaultServiceManager1_2, then call its method
listManifestByInterface in order to get the names of all HAL instances.
getAllHalInstanceNames is a new, simpler HIDL call that performs both
of these actions and directly returns the names.
Fixes: 76116804
Test: mma
Test: NeuralNetworksTest_static
Test: CtsNNAPITestCases
Change-Id: Icf98281b033fb7bcd8b1d1bb4d02d398b0979081
Merged-In: Icf98281b033fb7bcd8b1d1bb4d02d398b0979081
(cherry picked from commit 1cc1fe6c49c1c5da85227b4bbe5201f7f30caf53)
-rw-r--r-- | nn/runtime/ExecutionPlan.cpp | 5 | ||||
-rw-r--r-- | nn/runtime/Manager.cpp | 198 | ||||
-rw-r--r-- | nn/runtime/Manager.h | 10 | ||||
-rw-r--r-- | nn/runtime/NeuralNetworks.cpp | 20 | ||||
-rw-r--r-- | nn/runtime/TypeManager.cpp | 2 | ||||
-rw-r--r-- | nn/runtime/VersionedInterfaces.cpp | 3 | ||||
-rw-r--r-- | nn/runtime/test/TestIntrospectionControl.cpp | 6 | ||||
-rw-r--r-- | nn/runtime/test/TestPartitioning.cpp | 18 |
8 files changed, 135 insertions, 127 deletions
diff --git a/nn/runtime/ExecutionPlan.cpp b/nn/runtime/ExecutionPlan.cpp index 901305216..2c640c35b 100644 --- a/nn/runtime/ExecutionPlan.cpp +++ b/nn/runtime/ExecutionPlan.cpp @@ -71,8 +71,9 @@ int compile(const Device& device, const ModelBuilder& model, int executionPrefer *preparedModel = nullptr; std::optional<CacheToken> cacheToken; - if (device.isCachingSupported() && token->ok() && token->updateFromString(device.getName()) && - token->updateFromString(device.getVersionString()) && + if (device.isCachingSupported() && token->ok() && + token->updateFromString(device.getName().c_str()) && + token->updateFromString(device.getVersionString().c_str()) && token->update(&executionPreference, sizeof(executionPreference)) && token->finish()) { cacheToken.emplace(token->getCacheToken()); } diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp index 610072016..229fdc083 100644 --- a/nn/runtime/Manager.cpp +++ b/nn/runtime/Manager.cpp @@ -52,28 +52,35 @@ const Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX // A Device with actual underlying driver class DriverDevice : public Device { public: - DriverDevice(std::string name, const sp<V1_0::IDevice>& device); - - // Returns true if succesfully initialized. - bool initialize(); - - const char* getName() const override { return mName.c_str(); } - const char* getVersionString() const override { return mVersionString.c_str(); } - int64_t getFeatureLevel() const override { return mInterface->getFeatureLevel(); } - int32_t getType() const override { return mInterface->getType(); } - hidl_vec<Extension> getSupportedExtensions() const override; + // Create a DriverDevice from a name and an IDevice. + // Returns nullptr on failure. + static std::shared_ptr<DriverDevice> create(std::string name, sp<V1_0::IDevice> device); + + // Prefer using DriverDevice::create + DriverDevice(std::string name, std::string versionString, + std::shared_ptr<VersionedIDevice> device, Capabilities capabilities, + std::vector<Extension> supportedExtensions, + std::pair<uint32_t, uint32_t> numCacheFiles); + + const std::string& getName() const override { return kName; } + const std::string& getVersionString() const override { return kVersionString; } + int64_t getFeatureLevel() const override { return kInterface->getFeatureLevel(); } + int32_t getType() const override { return kInterface->getType(); } + const std::vector<Extension>& getSupportedExtensions() const override { + return kSupportedExtensions; + } void getSupportedOperations(const MetaModel& metaModel, hidl_vec<bool>* supportedOperations) const override; PerformanceInfo getPerformance(OperandType type) const override; PerformanceInfo getRelaxedFloat32toFloat16PerformanceScalar() const override { - return mCapabilities.relaxedFloat32toFloat16PerformanceScalar; + return kCapabilities.relaxedFloat32toFloat16PerformanceScalar; } PerformanceInfo getRelaxedFloat32toFloat16PerformanceTensor() const override { - return mCapabilities.relaxedFloat32toFloat16PerformanceTensor; + return kCapabilities.relaxedFloat32toFloat16PerformanceTensor; } bool isCachingSupported() const override { // Caching is supported if either of numModelCache or numDataCache is greater than 0. - return mNumCacheFiles.first > 0 || mNumCacheFiles.second > 0; + return kNumCacheFiles.first > 0 || kNumCacheFiles.second > 0; } std::pair<int, std::shared_ptr<PreparedModel>> prepareModel( @@ -88,12 +95,12 @@ class DriverDevice : public Device { std::pair<int, std::shared_ptr<PreparedModel>> prepareModelFromCacheInternal( const std::string& cacheDir, const CacheToken& token) const; - std::string mName; - std::string mVersionString; - const std::shared_ptr<VersionedIDevice> mInterface; - Capabilities mCapabilities; - hidl_vec<Extension> mSupportedExtensions; - std::pair<uint32_t, uint32_t> mNumCacheFiles; + const std::string kName; + const std::string kVersionString; + const std::shared_ptr<VersionedIDevice> kInterface; + const Capabilities kCapabilities; + const std::vector<Extension> kSupportedExtensions; + const std::pair<uint32_t, uint32_t> kNumCacheFiles; #ifdef NN_DEBUGGABLE // For debugging: behavior of IDevice::getSupportedOperations for SampleDriver. @@ -124,70 +131,79 @@ class DriverPreparedModel : public PreparedModel { const std::shared_ptr<VersionedIPreparedModel> mPreparedModel; }; -DriverDevice::DriverDevice(std::string name, const sp<V1_0::IDevice>& device) - : mName(std::move(name)), mInterface(VersionedIDevice::create(mName, device)) {} - -bool DriverDevice::initialize() { +DriverDevice::DriverDevice(std::string name, std::string versionString, + std::shared_ptr<VersionedIDevice> device, Capabilities capabilities, + std::vector<Extension> supportedExtensions, + std::pair<uint32_t, uint32_t> numCacheFiles) + : kName(std::move(name)), + kVersionString(std::move(versionString)), + kInterface(std::move(device)), + kCapabilities(std::move(capabilities)), + kSupportedExtensions(std::move(supportedExtensions)), + kNumCacheFiles(numCacheFiles) { #ifdef NN_DEBUGGABLE static const char samplePrefix[] = "sample"; - - mSupported = (mName.substr(0, sizeof(samplePrefix) - 1) == samplePrefix) - ? getProp("debug.nn.sample.supported") - : 0; + if (kName.substr(0, sizeof(samplePrefix) - 1) == samplePrefix) { + mSupported = getProp("debug.nn.sample.supported"); + } #endif // NN_DEBUGGABLE +} - ErrorStatus status = ErrorStatus::GENERAL_FAILURE; - - if (mInterface == nullptr) { - LOG(ERROR) << "DriverDevice contains invalid interface object."; - return false; +std::shared_ptr<DriverDevice> DriverDevice::create(std::string name, sp<V1_0::IDevice> device) { + CHECK(device != nullptr); + std::shared_ptr<VersionedIDevice> versionedDevice = + VersionedIDevice::create(name, std::move(device)); + if (versionedDevice == nullptr) { + LOG(ERROR) << "DriverDevice::create failed to create VersionedIDevice object for service " + << name; + return nullptr; } - std::tie(status, mCapabilities) = mInterface->getCapabilities(); - if (status != ErrorStatus::NONE) { - LOG(ERROR) << "IDevice::getCapabilities returned the error " << toString(status); - return false; + auto [capabilitiesStatus, capabilities] = versionedDevice->getCapabilities(); + if (capabilitiesStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getCapabilities returned the error " + << toString(capabilitiesStatus); + return nullptr; } - VLOG(MANAGER) << "Capab " << toString(mCapabilities); + VLOG(MANAGER) << "Capab " << toString(capabilities); - std::tie(status, mVersionString) = mInterface->getVersionString(); + const auto [versionStatus, versionString] = versionedDevice->getVersionString(); // TODO(miaowang): add a validation test case for in case of error. - if (status != ErrorStatus::NONE) { - LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(status); - return false; + if (versionStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(versionStatus); + return nullptr; } - std::tie(status, mSupportedExtensions) = mInterface->getSupportedExtensions(); - if (status != ErrorStatus::NONE) { - LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " << toString(status); - return false; + const auto [extensionsStatus, supportedExtensions] = versionedDevice->getSupportedExtensions(); + if (extensionsStatus != ErrorStatus::NONE) { + LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " + << toString(extensionsStatus); + return nullptr; } - std::tie(status, mNumCacheFiles.first, mNumCacheFiles.second) = - mInterface->getNumberOfCacheFilesNeeded(); - if (status != ErrorStatus::NONE) { + const auto [cacheFilesStatus, numModelCacheFiles, numDataCacheFiles] = + versionedDevice->getNumberOfCacheFilesNeeded(); + if (cacheFilesStatus != ErrorStatus::NONE) { LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned the error " - << toString(status); - return false; + << toString(cacheFilesStatus); + return nullptr; } // The following limit is enforced by VTS constexpr uint32_t maxNumCacheFiles = static_cast<uint32_t>(Constant::MAX_NUMBER_OF_CACHE_FILES); - if (mNumCacheFiles.first > maxNumCacheFiles || mNumCacheFiles.second > maxNumCacheFiles) { + if (numModelCacheFiles > maxNumCacheFiles || numDataCacheFiles > maxNumCacheFiles) { LOG(ERROR) << "IDevice::getNumberOfCacheFilesNeeded returned invalid number of cache files: " "numModelCacheFiles = " - << mNumCacheFiles.first << ", numDataCacheFiles = " << mNumCacheFiles.second + << numModelCacheFiles << ", numDataCacheFiles = " << numDataCacheFiles << ", maxNumCacheFiles = " << maxNumCacheFiles; - return false; + return nullptr; } - return true; -} - -hidl_vec<Extension> DriverDevice::getSupportedExtensions() const { - return mSupportedExtensions; + return std::make_shared<DriverDevice>( + std::move(name), versionString, std::move(versionedDevice), std::move(capabilities), + supportedExtensions, std::make_pair(numModelCacheFiles, numDataCacheFiles)); } void DriverDevice::getSupportedOperations(const MetaModel& metaModel, @@ -195,7 +211,7 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel, // Query the driver for what it can do. ErrorStatus status = ErrorStatus::GENERAL_FAILURE; hidl_vec<bool> supportedOperations; - std::tie(status, supportedOperations) = mInterface->getSupportedOperations(metaModel); + std::tie(status, supportedOperations) = kInterface->getSupportedOperations(metaModel); const Model& hidlModel = metaModel.getModel(); if (status != ErrorStatus::NONE) { @@ -222,7 +238,7 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel, return; } - const uint32_t baseAccumulator = std::hash<std::string>{}(mName); + const uint32_t baseAccumulator = std::hash<std::string>{}(kName); for (size_t operationIndex = 0; operationIndex < outSupportedOperations->size(); operationIndex++) { if (!(*outSupportedOperations)[operationIndex]) { @@ -256,7 +272,7 @@ void DriverDevice::getSupportedOperations(const MetaModel& metaModel, } PerformanceInfo DriverDevice::getPerformance(OperandType type) const { - return lookup(mCapabilities.operandPerformance, type); + return lookup(kCapabilities.operandPerformance, type); } // Opens cache file by filename and sets the handle to the opened fd. Returns false on fail. The @@ -326,7 +342,7 @@ static bool getCacheHandles(const std::string& cacheDir, const CacheToken& token static std::pair<int, std::shared_ptr<PreparedModel>> prepareModelCheck( ErrorStatus status, const std::shared_ptr<VersionedIPreparedModel>& preparedModel, - const char* prepareName, const char* serviceName) { + const char* prepareName, const std::string& serviceName) { if (status != ErrorStatus::NONE) { LOG(ERROR) << prepareName << " on " << serviceName << " failed: " << "prepareReturnStatus=" << toString(status); @@ -348,7 +364,7 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelInterna hidl_vec<hidl_handle> modelCache, dataCache; if (!maybeToken.has_value() || - !getCacheHandles(cacheDir, *maybeToken, mNumCacheFiles, + !getCacheHandles(cacheDir, *maybeToken, kNumCacheFiles, /*createIfNotExist=*/true, &modelCache, &dataCache)) { modelCache.resize(0); dataCache.resize(0); @@ -357,9 +373,9 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelInterna static const CacheToken kNullToken{}; const CacheToken token = maybeToken.value_or(kNullToken); const auto [status, preparedModel] = - mInterface->prepareModel(model, preference, modelCache, dataCache, token); + kInterface->prepareModel(model, preference, modelCache, dataCache, token); - return prepareModelCheck(status, preparedModel, "prepareModel", getName()); + return prepareModelCheck(status, preparedModel, "prepareModel", kName); } std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelFromCacheInternal( @@ -369,15 +385,15 @@ std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModelFromCac VLOG(COMPILATION) << "prepareModelFromCache"; hidl_vec<hidl_handle> modelCache, dataCache; - if (!getCacheHandles(cacheDir, token, mNumCacheFiles, + if (!getCacheHandles(cacheDir, token, kNumCacheFiles, /*createIfNotExist=*/false, &modelCache, &dataCache)) { return {ANEURALNETWORKS_OP_FAILED, nullptr}; } const auto [status, preparedModel] = - mInterface->prepareModelFromCache(modelCache, dataCache, token); + kInterface->prepareModelFromCache(modelCache, dataCache, token); - return prepareModelCheck(status, preparedModel, "prepareModelFromCache", getName()); + return prepareModelCheck(status, preparedModel, "prepareModelFromCache", kName); } std::pair<int, std::shared_ptr<PreparedModel>> DriverDevice::prepareModel( @@ -575,11 +591,13 @@ class CpuDevice : public Device { return instance; } - const char* getName() const override { return kName.c_str(); } - const char* getVersionString() const override { return kVersionString.c_str(); } + const std::string& getName() const override { return kName; } + const std::string& getVersionString() const override { return kVersionString; } int64_t getFeatureLevel() const override { return kFeatureLevel; } int32_t getType() const override { return ANEURALNETWORKS_DEVICE_CPU; } - hidl_vec<Extension> getSupportedExtensions() const override { return {/* No extensions. */}; } + const std::vector<Extension>& getSupportedExtensions() const override { + return kSupportedExtensions; + } void getSupportedOperations(const MetaModel& metaModel, hidl_vec<bool>* supportedOperations) const override; PerformanceInfo getPerformance(OperandType) const override { return kPerformance; } @@ -604,6 +622,7 @@ class CpuDevice : public Device { // Since the performance is a ratio compared to the CPU performance, // by definition the performance of the CPU is 1.0. const PerformanceInfo kPerformance = {.execTime = 1.0f, .powerUsage = 1.0f}; + const std::vector<Extension> kSupportedExtensions{/* No extensions. */}; }; // A special abstracted PreparedModel for the CPU, constructed by CpuDevice. @@ -755,42 +774,33 @@ std::shared_ptr<Device> DeviceManager::getCpuDevice() { std::shared_ptr<Device> DeviceManager::forTest_makeDriverDevice(const std::string& name, const sp<V1_0::IDevice>& device) { - auto driverDevice = std::make_shared<DriverDevice>(name, device); - CHECK(driverDevice->initialize()); + const auto driverDevice = DriverDevice::create(name, device); + CHECK(driverDevice != nullptr); return driverDevice; } void DeviceManager::findAvailableDevices() { - using ::android::hidl::manager::V1_2::IServiceManager; VLOG(MANAGER) << "findAvailableDevices"; - sp<IServiceManager> manager = hardware::defaultServiceManager1_2(); - if (manager == nullptr) { - LOG(ERROR) << "Unable to open defaultServiceManager"; - return; + // register driver devices + const auto names = hardware::getAllHalInstanceNames(V1_0::IDevice::descriptor); + for (const auto& name : names) { + VLOG(MANAGER) << "Found interface " << name; + sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); + if (device == nullptr) { + LOG(ERROR) << "Got a null IDEVICE for " << name; + continue; + } + registerDevice(name, device); } - manager->listManifestByInterface( - V1_0::IDevice::descriptor, [this](const hidl_vec<hidl_string>& names) { - for (const auto& name : names) { - VLOG(MANAGER) << "Found interface " << name.c_str(); - sp<V1_0::IDevice> device = V1_0::IDevice::getService(name); - if (device == nullptr) { - LOG(ERROR) << "Got a null IDEVICE for " << name.c_str(); - continue; - } - registerDevice(name.c_str(), device); - } - }); - // register CPU fallback device mDevices.push_back(CpuDevice::get()); mDevicesCpuOnly.push_back(CpuDevice::get()); } -void DeviceManager::registerDevice(const char* name, const sp<V1_0::IDevice>& device) { - auto d = std::make_shared<DriverDevice>(name, device); - if (d->initialize()) { +void DeviceManager::registerDevice(const std::string& name, const sp<V1_0::IDevice>& device) { + if (const auto d = DriverDevice::create(name, device)) { mDevices.push_back(d); } } diff --git a/nn/runtime/Manager.h b/nn/runtime/Manager.h index 82042d108..605931a60 100644 --- a/nn/runtime/Manager.h +++ b/nn/runtime/Manager.h @@ -68,11 +68,11 @@ class Device { virtual ~Device() = default; // Introspection methods returning device information - virtual const char* getName() const = 0; - virtual const char* getVersionString() const = 0; + virtual const std::string& getName() const = 0; + virtual const std::string& getVersionString() const = 0; virtual int64_t getFeatureLevel() const = 0; virtual int32_t getType() const = 0; - virtual hal::hidl_vec<hal::Extension> getSupportedExtensions() const = 0; + virtual const std::vector<hal::Extension>& getSupportedExtensions() const = 0; // See the MetaModel class in MetaModel.h for more details. virtual void getSupportedOperations(const MetaModel& metaModel, @@ -142,7 +142,7 @@ class DeviceManager { } // Register a test device. - void forTest_registerDevice(const char* name, const sp<hal::V1_0::IDevice>& device) { + void forTest_registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device) { registerDevice(name, device); } @@ -166,7 +166,7 @@ class DeviceManager { DeviceManager(); // Adds a device for the manager to use. - void registerDevice(const char* name, const sp<hal::V1_0::IDevice>& device); + void registerDevice(const std::string& name, const sp<hal::V1_0::IDevice>& device); void findAvailableDevices(); diff --git a/nn/runtime/NeuralNetworks.cpp b/nn/runtime/NeuralNetworks.cpp index 17667fd50..f4bce637b 100644 --- a/nn/runtime/NeuralNetworks.cpp +++ b/nn/runtime/NeuralNetworks.cpp @@ -582,7 +582,7 @@ int ANeuralNetworksDevice_getName(const ANeuralNetworksDevice* device, const cha return ANEURALNETWORKS_UNEXPECTED_NULL; } const Device* d = reinterpret_cast<const Device*>(device); - *name = d->getName(); + *name = d->getName().c_str(); return ANEURALNETWORKS_NO_ERROR; } @@ -592,7 +592,7 @@ int ANeuralNetworksDevice_getVersion(const ANeuralNetworksDevice* device, const return ANEURALNETWORKS_UNEXPECTED_NULL; } const Device* d = reinterpret_cast<const Device*>(device); - *version = d->getVersionString(); + *version = d->getVersionString().c_str(); return ANEURALNETWORKS_NO_ERROR; } @@ -1186,16 +1186,12 @@ int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* devic return ANEURALNETWORKS_UNEXPECTED_NULL; } - Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(device)); - hidl_vec<Extension> supportedExtensions = d->getSupportedExtensions(); - - *isExtensionSupported = false; - for (const Extension& supportedExtension : supportedExtensions) { - if (supportedExtension.name == extensionName) { - *isExtensionSupported = true; - break; - } - } + const Device* d = reinterpret_cast<const Device*>(device); + const auto& supportedExtensions = d->getSupportedExtensions(); + *isExtensionSupported = std::any_of(supportedExtensions.begin(), supportedExtensions.end(), + [extensionName](const auto& supportedExtension) { + return supportedExtension.name == extensionName; + }); return ANEURALNETWORKS_NO_ERROR; } diff --git a/nn/runtime/TypeManager.cpp b/nn/runtime/TypeManager.cpp index 854e28ab9..40b0e3d44 100644 --- a/nn/runtime/TypeManager.cpp +++ b/nn/runtime/TypeManager.cpp @@ -181,7 +181,7 @@ bool TypeManager::isExtensionsUseAllowed(const AppPackageInfo& appPackageInfo, void TypeManager::findAvailableExtensions() { for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) { - for (const Extension extension : device->getSupportedExtensions()) { + for (const Extension& extension : device->getSupportedExtensions()) { registerExtension(extension, device->getName()); } } diff --git a/nn/runtime/VersionedInterfaces.cpp b/nn/runtime/VersionedInterfaces.cpp index ba6e2af7c..16ebe0ae8 100644 --- a/nn/runtime/VersionedInterfaces.cpp +++ b/nn/runtime/VersionedInterfaces.cpp @@ -306,6 +306,8 @@ std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExec std::shared_ptr<VersionedIDevice> VersionedIDevice::create(std::string serviceName, sp<V1_0::IDevice> device) { + CHECK(device != nullptr) << "VersionedIDevice::create passed invalid device object."; + auto core = Core::create(std::move(device)); if (!core.has_value()) { LOG(ERROR) << "VersionedIDevice::create failed to create Core."; @@ -320,7 +322,6 @@ VersionedIDevice::VersionedIDevice(std::string serviceName, Core core) : mServiceName(std::move(serviceName)), mCore(std::move(core)) {} std::optional<VersionedIDevice::Core> VersionedIDevice::Core::create(sp<V1_0::IDevice> device) { - // verify input CHECK(device != nullptr) << "VersionedIDevice::Core::create passed invalid device object."; // create death handler object diff --git a/nn/runtime/test/TestIntrospectionControl.cpp b/nn/runtime/test/TestIntrospectionControl.cpp index 9d0cbe6c3..76a6f5ee2 100644 --- a/nn/runtime/test/TestIntrospectionControl.cpp +++ b/nn/runtime/test/TestIntrospectionControl.cpp @@ -229,9 +229,9 @@ TEST_F(IntrospectionControlTest, SimpleAddModel) { // Verify that the mCompilation is actually using the "test-all" device. CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(mCompilation); - const char* deviceNameBuffer = + const std::string& deviceNameBuffer = c->forTest_getExecutionPlan().forTest_simpleGetDevice()->getName(); - EXPECT_TRUE(driverName.compare(deviceNameBuffer) == 0); + EXPECT_EQ(driverName, deviceNameBuffer); float input1[2] = {1.0f, 2.0f}; float input2[2] = {3.0f, 4.0f}; @@ -655,7 +655,7 @@ TEST_P(TimingTest, Test) { switch (kDriverKind) { case DriverKind::CPU: { // There should be only one driver -- the CPU - const char* name = DeviceManager::get()->getDrivers()[0]->getName(); + const std::string& name = DeviceManager::get()->getDrivers()[0]->getName(); ASSERT_TRUE(selectDeviceByName(name)); break; } diff --git a/nn/runtime/test/TestPartitioning.cpp b/nn/runtime/test/TestPartitioning.cpp index b01e58015..2c774a6e8 100644 --- a/nn/runtime/test/TestPartitioning.cpp +++ b/nn/runtime/test/TestPartitioning.cpp @@ -1254,7 +1254,7 @@ TEST_F(PartitioningTest, SimpleModel) { ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(planA.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(planA.forTest_simpleGetDevice()->getName(), "good"); + ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "good"); // Simple partition (two devices are each capable of everything, none better than CPU). // No need to compare the original model to the model from the plan -- we @@ -1342,7 +1342,7 @@ TEST_F(PartitioningTest, SliceModel) { ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(planA.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(planA.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(planA.forTest_simpleGetDevice()->getName(), "V1_2"); + ASSERT_EQ(planA.forTest_simpleGetDevice()->getName(), "V1_2"); // Compound partition (V1_0, V1_1, V1_2 devices are available, in decreasing // order of performance; model is distributed across all three devices). @@ -1442,7 +1442,7 @@ TEST_F(PartitioningTest, SliceModelToEmpty) { ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(plan.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "V1_2"); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "V1_2"); } TEST_F(PartitioningTest, Cpu) { @@ -1682,7 +1682,7 @@ TEST_F(PartitioningTest, OemOperations) { const auto& planBestOEM = compilationBestOEM.getExecutionPlan(); ASSERT_EQ(planBestOEM.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); ASSERT_NE(planBestOEM.forTest_simpleGetDevice().get(), nullptr); - ASSERT_STREQ(planBestOEM.forTest_simpleGetDevice()->getName(), "goodOEM"); + ASSERT_EQ(planBestOEM.forTest_simpleGetDevice()->getName(), "goodOEM"); // Verify that we get an error if no driver can run an OEM operation. const auto devicesNoOEM = makeDevices({{"noOEM", 0.5, ~0U, PartitioningDriver::OEMNo}}); @@ -1724,7 +1724,7 @@ TEST_F(PartitioningTest, RelaxedFP) { ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan), ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), expectDevice); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), expectDevice); }; ASSERT_NO_FATAL_FAILURE(TrivialTest(false, "f32")); @@ -1772,7 +1772,7 @@ TEST_F(PartitioningTest, Perf) { ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan), ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "good"); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "good"); } { @@ -1790,7 +1790,7 @@ TEST_F(PartitioningTest, Perf) { ASSERT_EQ(model.partitionTheWork(devices, ExecutePreference::PREFER_LOW_POWER, &plan), ANEURALNETWORKS_NO_ERROR); ASSERT_EQ(plan.forTest_getKind(), ExecutionPlan::Kind::SIMPLE); - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), "base"); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), "base"); } }; @@ -1853,7 +1853,7 @@ class CacheTest : public PartitioningTest { // Find the cache info for the device. const uint8_t* token = nullptr; if (plan.forTest_getKind() == ExecutionPlan::Kind::SIMPLE) { - ASSERT_STREQ(plan.forTest_simpleGetDevice()->getName(), deviceName); + ASSERT_EQ(plan.forTest_simpleGetDevice()->getName(), deviceName); token = plan.forTest_simpleGetCacheToken(); } else if (plan.forTest_getKind() == ExecutionPlan::Kind::COMPOUND) { const auto& steps = plan.forTest_compoundGetSteps(); @@ -1861,7 +1861,7 @@ class CacheTest : public PartitioningTest { for (const auto& step : steps) { // In general, two or more partitions can be on the same device. However, this will // not happen on the test models with only 2 operations. - if (strcmp(step->getDevice()->getName(), deviceName) == 0) { + if (step->getDevice()->getName() == deviceName) { ASSERT_FALSE(found); token = step->forTest_getCacheToken(); found = true; |