diff options
author | Johannes Reifferscheid <jreiffers@google.com> | 2024-05-10 02:26:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2024-05-10 02:36:59 -0700 |
commit | da17b4ad6bd7d1347bad36376f218ae7f7511007 (patch) | |
tree | 859c3b39bdcbd5dae77f1da89b7af129661e9959 | |
parent | b0413907c4db1ea36afe4ccfeacadb9509b44e68 (diff) | |
download | tensorflow-da17b4ad6bd7d1347bad36376f218ae7f7511007.tar.gz |
Support scatter with unsigned indices.
PiperOrigin-RevId: 632428123
-rw-r--r-- | third_party/xla/xla/service/gpu/fusions/BUILD | 1 | ||||
-rw-r--r-- | third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc | 60 | ||||
-rw-r--r-- | third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc | 36 |
3 files changed, 66 insertions, 31 deletions
diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 25431941a41..e78029820bd 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -429,6 +429,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index c78a6b4f057..ded3ca4b6e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -18,11 +18,9 @@ limitations under the License. #include <optional> #include <vector> -#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -31,10 +29,12 @@ limitations under the License. #include "mlir/IR/AffineExpr.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -52,16 +52,13 @@ namespace xla { namespace gpu { namespace { +namespace ma = mlir::arith; + using llvm::SmallVector; using mlir::Location; using mlir::OpBuilder; using mlir::Value; using mlir::ValueRange; -using mlir::arith::AddIOp; -using mlir::arith::AndIOp; -using mlir::arith::CmpIOp; -using mlir::arith::CmpIPredicate; -using mlir::arith::ConstantIndexOp; using mlir::func::ReturnOp; using mlir::tensor::InsertOp; using mlir_converter::ApplyAffineMap; @@ -208,7 +205,6 @@ absl::Status MlirScatterFusion::EmitEntryFunction( b.setInsertionPointToStart(entry_function.addEntryBlock()); SmallVector<Value> result_tensors{entry_function.getArguments().back()}; - auto c0 = b.create<ConstantIndexOp>(0); auto scatter_result = EmitThreadLoopNest( b, result_tensors, thread_id_to_update_map, @@ -224,39 +220,41 @@ absl::Status MlirScatterFusion::EmitEntryFunction( // Extract slice offsets from scatter_indices operand, compute if the // whole slice of scatter_update operand will fit into the output. - mlir::Value is_in_bounds = - b.create<mlir::arith::ConstantIntOp>(1, b.getI1Type()); + mlir::Value in_bounds = b.create<ma::ConstantIntOp>(1, b.getI1Type()); SmallVector<Value, 4> indices{ llvm::ArrayRef(update_tensor_indices).drop_front()}; - for (int i = 0; i < scatter_operand->shape().rank(); ++i) { - Value extracted_index = c0; - if (i < scatter_indices->shape().dimensions(1)) { - SmallVector<Value, 4> indices_tensor_indices = { - update_tensor_indices.front(), b.create<ConstantIndexOp>(i)}; - extracted_index = ProvideParameter( - root_computation, scatter, kScatterIndicesIndex, - indices_tensor_indices, call_targets, entry_function, b); - if (extracted_index.getType() != b.getIndexType()) { - extracted_index = b.create<mlir::arith::IndexCastOp>( - b.getIndexType(), extracted_index); - } + for (int i = 0; i < scatter_indices->shape().dimensions(1); ++i) { + SmallVector<Value, 4> indices_tensor_indices = { + update_tensor_indices.front(), b.create<ma::ConstantIndexOp>(i)}; + auto index = ProvideParameter( + root_computation, scatter, kScatterIndicesIndex, + indices_tensor_indices, call_targets, entry_function, b); + auto index_ty = mlir::cast<mlir::IntegerType>(index.getType()); + if (index_ty.isUnsigned()) { + auto int_ty = b.getIntegerType(index_ty.getWidth()); + index = b.create<mlir::UnrealizedConversionCastOp>(int_ty, index) + .getResult(0); + index = b.create<ma::IndexCastUIOp>(b.getIndexType(), index); + } else { + index = b.create<ma::IndexCastOp>(b.getIndexType(), index); + auto c0 = b.create<ma::ConstantIndexOp>(0); + in_bounds = b.create<ma::AndIOp>( + in_bounds, + b.create<ma::CmpIOp>(ma::CmpIPredicate::sge, index, c0)); } - is_in_bounds = b.create<AndIOp>( - is_in_bounds, - b.create<CmpIOp>(CmpIPredicate::sge, extracted_index, c0)); - Value ub = b.create<ConstantIndexOp>( + Value ub = b.create<ma::ConstantIndexOp>( scatter_operand->shape().dimensions(i) - scatter_update->shape().dimensions(i + 1)); - is_in_bounds = b.create<AndIOp>( - is_in_bounds, - b.create<CmpIOp>(CmpIPredicate::sle, extracted_index, ub)); - indices[i] = b.create<AddIOp>(extracted_index, indices[i]); + in_bounds = b.create<ma::AndIOp>( + in_bounds, + b.create<ma::CmpIOp>(ma::CmpIPredicate::sle, index, ub)); + indices[i] = b.create<ma::AddIOp>(index, indices[i]); } // Call scatter's computation if is_in_bounds. Value output_tensor = output_tensors.front(); Value predicated_update = b.create<scf::IfOp>( - is_in_bounds, + in_bounds, [&](OpBuilder& then_builder, Location then_loc) -> void { Value updated_output = EmitScatterComputation( scatter, indices, update_elem, output_tensor, diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc index 12fca854ae5..713314969c8 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -223,6 +223,42 @@ TEST_F(MlirScatterFusionTest, Scatter_UniqueIndices) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } +TEST_F(MlirScatterFusionTest, Scatter_Unsigned) { + auto kHloString = R"( + HloModule module + + add { + %p0 = f32[] parameter(0) + %p1 = f32[] parameter(1) + ROOT %sum = f32[] add(%p0, %p1) + } + scatter { + %operand = f32[10,5] parameter(0) + %indices = u32[24,1] parameter(1) + %update = f32[24,2,3] parameter(2) + + ROOT %scatter = f32[10,5] scatter(%operand, %indices, %update), + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, + to_apply=add + } + ENTRY entry { + %c1 = f32[] constant(1) + %c1_tensor = f32[10,5] broadcast(%c1), dimensions={} + %indices = u32[24,1] parameter(0) + %update = f32[24, 2, 3] parameter(1) + ROOT %fusion = f32[10, 5] fusion(%c1_tensor, %indices, %update), + kind=kLoop, calls=scatter + } + )"; + TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK: func.func @fused_computation( + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + TEST_F(MlirScatterFusionTest, Scatter_Add) { auto kHloString = R"( HloModule module |