aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohannes Reifferscheid <jreiffers@google.com>2024-05-10 02:26:21 -0700
committerTensorFlower Gardener <gardener@tensorflow.org>2024-05-10 02:36:59 -0700
commitda17b4ad6bd7d1347bad36376f218ae7f7511007 (patch)
tree859c3b39bdcbd5dae77f1da89b7af129661e9959
parentb0413907c4db1ea36afe4ccfeacadb9509b44e68 (diff)
downloadtensorflow-da17b4ad6bd7d1347bad36376f218ae7f7511007.tar.gz
Support scatter with unsigned indices.
PiperOrigin-RevId: 632428123
-rw-r--r--third_party/xla/xla/service/gpu/fusions/BUILD1
-rw-r--r--third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc60
-rw-r--r--third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc36
3 files changed, 66 insertions, 31 deletions
diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD
index 25431941a41..e78029820bd 100644
--- a/third_party/xla/xla/service/gpu/fusions/BUILD
+++ b/third_party/xla/xla/service/gpu/fusions/BUILD
@@ -429,6 +429,7 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
+ "@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
index c78a6b4f057..ded3ca4b6e4 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc
@@ -18,11 +18,9 @@ limitations under the License.
#include <optional>
#include <vector>
-#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
@@ -31,10 +29,12 @@ limitations under the License.
#include "mlir/IR/AffineExpr.h" // from @llvm-project
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
+#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/ValueRange.h" // from @llvm-project
+#include "mlir/Support/LLVM.h" // from @llvm-project
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
@@ -52,16 +52,13 @@ namespace xla {
namespace gpu {
namespace {
+namespace ma = mlir::arith;
+
using llvm::SmallVector;
using mlir::Location;
using mlir::OpBuilder;
using mlir::Value;
using mlir::ValueRange;
-using mlir::arith::AddIOp;
-using mlir::arith::AndIOp;
-using mlir::arith::CmpIOp;
-using mlir::arith::CmpIPredicate;
-using mlir::arith::ConstantIndexOp;
using mlir::func::ReturnOp;
using mlir::tensor::InsertOp;
using mlir_converter::ApplyAffineMap;
@@ -208,7 +205,6 @@ absl::Status MlirScatterFusion::EmitEntryFunction(
b.setInsertionPointToStart(entry_function.addEntryBlock());
SmallVector<Value> result_tensors{entry_function.getArguments().back()};
- auto c0 = b.create<ConstantIndexOp>(0);
auto scatter_result = EmitThreadLoopNest(
b, result_tensors, thread_id_to_update_map,
@@ -224,39 +220,41 @@ absl::Status MlirScatterFusion::EmitEntryFunction(
// Extract slice offsets from scatter_indices operand, compute if the
// whole slice of scatter_update operand will fit into the output.
- mlir::Value is_in_bounds =
- b.create<mlir::arith::ConstantIntOp>(1, b.getI1Type());
+ mlir::Value in_bounds = b.create<ma::ConstantIntOp>(1, b.getI1Type());
SmallVector<Value, 4> indices{
llvm::ArrayRef(update_tensor_indices).drop_front()};
- for (int i = 0; i < scatter_operand->shape().rank(); ++i) {
- Value extracted_index = c0;
- if (i < scatter_indices->shape().dimensions(1)) {
- SmallVector<Value, 4> indices_tensor_indices = {
- update_tensor_indices.front(), b.create<ConstantIndexOp>(i)};
- extracted_index = ProvideParameter(
- root_computation, scatter, kScatterIndicesIndex,
- indices_tensor_indices, call_targets, entry_function, b);
- if (extracted_index.getType() != b.getIndexType()) {
- extracted_index = b.create<mlir::arith::IndexCastOp>(
- b.getIndexType(), extracted_index);
- }
+ for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) {
+ SmallVector<Value, 4> indices_tensor_indices = {
+ update_tensor_indices.front(), b.create<ma::ConstantIndexOp>(i)};
+ auto index = ProvideParameter(
+ root_computation, scatter, kScatterIndicesIndex,
+ indices_tensor_indices, call_targets, entry_function, b);
+ auto index_ty = mlir::cast<mlir::IntegerType>(index.getType());
+ if (index_ty.isUnsigned()) {
+ auto int_ty = b.getIntegerType(index_ty.getWidth());
+ index = b.create<mlir::UnrealizedConversionCastOp>(int_ty, index)
+ .getResult(0);
+ index = b.create<ma::IndexCastUIOp>(b.getIndexType(), index);
+ } else {
+ index = b.create<ma::IndexCastOp>(b.getIndexType(), index);
+ auto c0 = b.create<ma::ConstantIndexOp>(0);
+ in_bounds = b.create<ma::AndIOp>(
+ in_bounds,
+ b.create<ma::CmpIOp>(ma::CmpIPredicate::sge, index, c0));
}
- is_in_bounds = b.create<AndIOp>(
- is_in_bounds,
- b.create<CmpIOp>(CmpIPredicate::sge, extracted_index, c0));
- Value ub = b.create<ConstantIndexOp>(
+ Value ub = b.create<ma::ConstantIndexOp>(
scatter_operand->shape().dimensions(i) -
scatter_update->shape().dimensions(i + 1));
- is_in_bounds = b.create<AndIOp>(
- is_in_bounds,
- b.create<CmpIOp>(CmpIPredicate::sle, extracted_index, ub));
- indices[i] = b.create<AddIOp>(extracted_index, indices[i]);
+ in_bounds = b.create<ma::AndIOp>(
+ in_bounds,
+ b.create<ma::CmpIOp>(ma::CmpIPredicate::sle, index, ub));
+ indices[i] = b.create<ma::AddIOp>(index, indices[i]);
}
// Call scatter's computation if is_in_bounds.
Value output_tensor = output_tensors.front();
Value predicated_update =
b.create<scf::IfOp>(
- is_in_bounds,
+ in_bounds,
[&](OpBuilder& then_builder, Location then_loc) -> void {
Value updated_output = EmitScatterComputation(
scatter, indices, update_elem, output_tensor,
diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
index 12fca854ae5..713314969c8 100644
--- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc
@@ -223,6 +223,42 @@ TEST_F(MlirScatterFusionTest, Scatter_UniqueIndices) {
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
}
+TEST_F(MlirScatterFusionTest, Scatter_Unsigned) {
+ auto kHloString = R"(
+ HloModule module
+
+ add {
+ %p0 = f32[] parameter(0)
+ %p1 = f32[] parameter(1)
+ ROOT %sum = f32[] add(%p0, %p1)
+ }
+ scatter {
+ %operand = f32[10,5] parameter(0)
+ %indices = u32[24,1] parameter(1)
+ %update = f32[24,2,3] parameter(2)
+
+ ROOT %scatter = f32[10,5] scatter(%operand, %indices, %update),
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ to_apply=add
+ }
+ ENTRY entry {
+ %c1 = f32[] constant(1)
+ %c1_tensor = f32[10,5] broadcast(%c1), dimensions={}
+ %indices = u32[24,1] parameter(0)
+ %update = f32[24, 2, 3] parameter(1)
+ ROOT %fusion = f32[10, 5] fusion(%c1_tensor, %indices, %update),
+ kind=kLoop, calls=scatter
+ }
+ )";
+ TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"(
+ // CHECK: func.func @fused_computation(
+ )"));
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3}));
+}
+
TEST_F(MlirScatterFusionTest, Scatter_Add) {
auto kHloString = R"(
HloModule module