author     George Karpenkov <cheshire@google.com>  2024-05-10 06:35:26 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2024-05-10 06:52:48 -0700
commit     b0c214c5a7ce533d152650b13830f63f56b7e868
tree       1e0c081cb802307d6b1472e9e811527b12234b1a
parent     f9659ecee7e4f9c8186fa2fe6e22c4b551c0b230

[XLA] [NFC] Upstream ReduceWindowRewriter pass

PiperOrigin-RevId: 632478633
-rw-r--r--  third_party/xla/xla/service/BUILD                           |  40
-rw-r--r--  third_party/xla/xla/service/reduce_window_rewriter.cc       | 545
-rw-r--r--  third_party/xla/xla/service/reduce_window_rewriter.h        |  72
-rw-r--r--  third_party/xla/xla/service/reduce_window_rewriter_test.cc  | 184
4 files changed, 841 insertions, 0 deletions
diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 029895a4867..fe3f249ec44 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -7511,6 +7511,46 @@ cc_library(
)
cc_library(
+ name = "reduce_window_rewriter",
+ srcs = ["reduce_window_rewriter.cc"],
+ hdrs = ["reduce_window_rewriter.h"],
+ deps = [
+ ":hlo_pass",
+ "//xla:shape_util",
+ "//xla:status",
+ "//xla:status_macros",
+ "//xla:statusor",
+ "//xla:util",
+ "//xla:window_util",
+ "//xla:xla_data_proto_cc",
+ "//xla/hlo/ir:hlo",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/log",
+ "@com_google_absl//absl/log:check",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ "@local_tsl//tsl/platform:errors",
+ "@local_tsl//tsl/platform:statusor",
+ ],
+)
+
+xla_cc_test(
+ name = "reduce_window_rewriter_test",
+ srcs = ["reduce_window_rewriter_test.cc"],
+ deps = [
+ ":reduce_window_rewriter",
+ "//xla:test",
+ "//xla:xla_data_proto_cc",
+ "//xla/tests:hlo_test_base",
+ "@com_google_absl//absl/strings",
+ "@local_tsl//tsl/platform:test_main",
+ ],
+)
+
+cc_library(
name = "stochastic_convert_decomposer",
srcs = ["stochastic_convert_decomposer.cc"],
hdrs = ["stochastic_convert_decomposer.h"],
diff --git a/third_party/xla/xla/service/reduce_window_rewriter.cc b/third_party/xla/xla/service/reduce_window_rewriter.cc
new file mode 100644
index 00000000000..39b201e108c
--- /dev/null
+++ b/third_party/xla/xla/service/reduce_window_rewriter.cc
@@ -0,0 +1,545 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/reduce_window_rewriter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+
+static size_t FlattenShapeIndex(const ShapeIndex& shape_index) {
+ if (shape_index.empty()) {
+ return 0;
+ }
+ CHECK_EQ(shape_index.size(), 1);
+ return shape_index.back();
+}
+
+static Shape ShapeAtIndex(const Shape& shape, const ShapeIndex& shape_index) {
+ if (shape_index.empty()) {
+ return shape;
+ }
+ CHECK_EQ(shape_index.size(), 1);
+ return ShapeUtil::GetTupleElementShape(shape, shape_index.back());
+}
+
+static HloInstruction* GetAtIndex(HloInstruction* hlo,
+ const ShapeIndex& shape_index) {
+ if (shape_index.empty()) {
+ return hlo;
+ }
+ CHECK_EQ(shape_index.size(), 1);
+ return hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeAtIndex(hlo->shape(), shape_index), hlo, shape_index.back()));
+}
+
+// Transform reduce-win(x) ->
+//   if rank(x) == 1:
+//     then: reshape_r2_r1(reduce-win(reshape_r1_r2(x)))
+//     else: no change
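+//
+// For example (a sketch matching the EliminateR1 test in this change):
+//   %rw = f32[10] reduce-window(f32[10] %x, f32[] %init),
+//         window={size=5 pad=2_2}, to_apply=%binary_add
+// becomes
+//   %x2  = f32[10,1] reshape(%x)
+//   %rw2 = f32[10,1] reduce-window(%x2, %init),
+//          window={size=5x1 pad=2_2x0_0}, to_apply=%binary_add
+//   ROOT %r = f32[10] reshape(%rw2)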
+Status ReduceWindowRewriter::ReplaceReduceWindowWithReshape(
+ HloReduceWindowInstruction* reduce_window) {
+ VLOG(2) << "Converting R1 reduce window: " << reduce_window->ToString();
+
+ std::vector<Shape> r2_output_shapes;
+ ShapeUtil::ForEachSubshape(
+ reduce_window->shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (!ShapeUtil::IsLeafIndex(reduce_window->shape(), shape_index)) {
+ return;
+ }
+ Shape r2_output_shape = subshape;
+ ShapeUtil::AppendMajorDimension(1, &r2_output_shape);
+ UpdateLayout(&r2_output_shape);
+ r2_output_shapes.push_back(r2_output_shape);
+
+ VLOG(2) << "ReduceWindowRewriter: Converting R2 result to R1: "
+ << ShapeUtil::HumanStringWithLayout(r2_output_shape);
+ });
+
+ Window r2_window = reduce_window->window();
+ WindowDimension* dim = r2_window.add_dimensions();
+ dim->set_size(1);
+ dim->set_stride(1);
+ dim->set_base_dilation(1);
+ dim->set_window_dilation(1);
+
+ std::vector<HloInstruction*> r2_operands;
+ for (HloInstruction* operand : reduce_window->inputs()) {
+ Shape r2_input_shape = operand->shape();
+ ShapeUtil::AppendMajorDimension(1, &r2_input_shape);
+ UpdateLayout(&r2_input_shape);
+
+ VLOG(2) << "ReduceWindowRewriter: Converting R1 operand to R2: "
+ << ShapeUtil::HumanStringWithLayout(r2_input_shape);
+ HloInstruction* r2_operand = operand->parent()->AddInstruction(
+ HloInstruction::CreateReshape(r2_input_shape, operand));
+ VLOG(2) << "R2 new operand: " << r2_operand->ToString();
+ r2_operands.push_back(r2_operand);
+ }
+ HloInstruction* new_reduce_window = reduce_window->parent()->AddInstruction(
+ HloInstruction::CreateReduceWindow(
+ reduce_window->shape().IsTuple()
+ ? ShapeUtil::MakeTupleShape(r2_output_shapes)
+ : r2_output_shapes[0],
+ r2_operands, reduce_window->init_values(), r2_window,
+ reduce_window->to_apply()));
+
+ VLOG(2) << "R2 resulting reduce window: " << new_reduce_window->ToString();
+
+ std::vector<HloInstruction*> final_reshapes;
+ ShapeUtil::ForEachSubshape(
+ reduce_window->shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (!ShapeUtil::IsLeafIndex(reduce_window->shape(), shape_index)) {
+ return;
+ }
+ HloInstruction* final_reshape =
+ new_reduce_window->parent()->AddInstruction(
+ HloInstruction::CreateReshape(
+ subshape, GetAtIndex(new_reduce_window, shape_index)));
+ final_reshapes.push_back(final_reshape);
+ });
+ HloInstruction* result;
+ if (reduce_window->shape().IsTuple()) {
+ result = new_reduce_window->parent()->AddInstruction(
+ HloInstruction::CreateTuple(final_reshapes));
+ } else {
+ CHECK_EQ(final_reshapes.size(), 1);
+ result = final_reshapes[0];
+ }
+ TF_RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(result));
+ TF_RETURN_IF_ERROR(
+ new_reduce_window->parent()->RemoveInstruction(reduce_window));
+
+ return OkStatus();
+}
+
+absl::StatusOr<bool> ReduceWindowRewriter::TryOptimizeCumSumOrProd(
+ HloReduceWindowInstruction* reduce_window) {
+ const Shape& operand_shape = reduce_window->inputs().front()->shape();
+
+ // Try to find the scan axis. We expect all window dimensions to be trivial,
+ // except for one.
+ int64_t rank = operand_shape.rank();
+ const Window& window = reduce_window->window();
+ int64_t scan_dim_num = -1;
+ for (int i = 0; i < rank; ++i) {
+ const WindowDimension& window_dim = window.dimensions(i);
+ if (window_util::IsTrivialWindowDimension(window_dim)) {
+ continue;
+ }
+ if (scan_dim_num != -1) {
+ // At least two non-trivial dimensions exist, so no cigar.
+ return false;
+ }
+ scan_dim_num = i;
+ }
+
+ if (scan_dim_num == -1) {
+ return false;
+ }
+
+ const int64_t scan_length = operand_shape.dimensions(scan_dim_num);
+ absl::Span<HloInstruction* const> init_values = reduce_window->init_values();
+ const WindowDimension& scan_window_dim = window.dimensions(scan_dim_num);
+
+ bool forward_scan = (scan_window_dim.padding_low() == scan_length - 1 ||
+ scan_window_dim.padding_low() == scan_length) &&
+ scan_window_dim.padding_high() == 0;
+ bool reverse_scan = (scan_window_dim.padding_high() == scan_length - 1 ||
+ scan_window_dim.padding_high() == scan_length) &&
+ scan_window_dim.padding_low() == 0;
+ // We accept two values for low padding: the input length for exclusive scan,
+ // and scan_length - 1 for inclusive scan.
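+// For example, for a forward cumulative sum over f32[4], an inclusive scan
+// has window={size=4 pad=3_0} and an exclusive scan has window={size=4
+// pad=4_0}; reverse scans mirror these values in the high padding.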
+ if (scan_window_dim.stride() != 1 || scan_window_dim.size() != scan_length ||
+ (!forward_scan && !reverse_scan) || scan_window_dim.window_reversal() ||
+ scan_window_dim.base_dilation() != 1 ||
+ scan_window_dim.window_dilation() != 1) {
+ return false;
+ }
+ bool is_exclusive = forward_scan
+ ? (scan_window_dim.padding_low() == scan_length)
+ : (scan_window_dim.padding_high() == scan_length);
+
+ if (scan_length <= base_length_) {
+ return false;
+ }
+
+ if (reduce_window->to_apply()->root_instruction()->shape().IsTuple() &&
+ reduce_window->to_apply()->root_instruction()->opcode() !=
+ HloOpcode::kTuple) {
+ return false;
+ }
+
+ VLOG(2) << "Rewriting Scan: " << reduce_window->ToString();
+ HloComputation* parent = reduce_window->parent();
+ std::vector<HloInstruction*> sources(reduce_window->inputs().begin(),
+ reduce_window->inputs().end());
+
+ // Since we need to tile this dimension, it's convenient to have it logically
+ // last.
+ std::vector<int64_t> permutation(rank);
+ absl::c_iota(permutation, 0);
+ permutation[scan_dim_num] = rank - 1;
+ permutation[rank - 1] = scan_dim_num;
+ if (scan_dim_num != rank - 1) {
+ for (size_t i = 0; i < sources.size(); ++i) {
+ sources[i] = parent->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::PermuteDimensions(permutation, sources[i]->shape()),
+ sources[i], permutation));
+ }
+ }
+
+ // We don't actually need to match the computation - this transformation will
+ // work for any commutative/associative reducer, which is what we assume for
+ // ReduceWindow anyway.
+
+ // Break the scan into an "inner" and an "outer" scan - this is basically a
+ // tree reduction:
+ // (The explanation below assumes an R1 scan for simplicity. For Rk scan, all
+ // shapes have k-1 "batch" dimensions that need to be preserved.)
+ //
+ // 1) If necessary, pad input from {N} to {K}, where K is a multiple of 128.
+ // 2) Reshape from {K} to {K / 128, 128}.
+ // 3) Scan each 128 dimension.
+ // 4) Slice out the last column.
+ // 5) Exclusive scan across the last column.
+ // 6) Broadcast it back into {K / 128, 128}.
+ // 7) Add up the results of (3) and (6).
+ // 8) Reshape back into {K}
+ // 9) Slice off the padding.
+ //
+ // For example, consider a cumulative sum over an R1 of length 9, with a base
+ // case of 3 instead of 128. Let the input be:
+ // [0 1 2 3 4 5 6 7 8]
+ //
+ // We need no padding, so we go directly to (2):
+ // [0 1 2
+ // 3 4 5
+ // 6 7 8]
+ //
+ // The result of the scan in (3) is:
+ // [0 1 3
+ // 3 7 12
+ // 6 13 21]
+ //
+ // Slicing out the last column we get (4):
+ // [ 3
+ // 12
+ // 21]
+ //
+ // The exclusive scan of that column in (5) yields [0 3 15], which after
+ // broadcasting in (6) gives:
+ // [ 0 0 0
+ // 3 3 3
+ // 15 15 15]
+ //
+ // Finally, we add up the two scans (3) and (6), getting (7):
+ // [ 0 1 3
+ // 6 10 15
+ // 21 28 36]
+ //
+ // And reshape back into [0 1 3 6 10 15 21 28 36].
+ //
+ // For reverse scans, we proceed the same way as for forward scans, except:
+ // we perform a reverse scan at (3), slice out the first column at (4), and
+ // perform an exclusive reverse scan of the first column at (5).
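+ //
+ // As a further illustration of the reverse-scan path, a reverse inclusive
+ // scan of the same input [0 1 2 3 4 5 6 7 8] with base 3 would suffix-scan
+ // each row at (3):
+ // [ 3 3 2
+ // 12 9 5
+ // 21 15 8]
+ // then slice out the first column [3 12 21] at (4), exclusive-reverse-scan
+ // it into [33 21 0] at (5), and broadcast and add to obtain
+ // [36 36 35 33 30 26 21 15 8] after the final reshape.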
+
+ // Pad.
+ const int64_t padded_length = RoundUpTo(scan_length, base_length_);
+ if (scan_length != padded_length) {
+ for (size_t i = 0; i < sources.size(); ++i) {
+ auto* source = sources[i];
+ Shape padded_shape = source->shape();
+ padded_shape.set_dimensions(rank - 1, padded_length);
+
+ UpdateLayout(&padded_shape);
+ auto padding_config = MakeNoPaddingConfig(rank);
+ padding_config.mutable_dimensions(rank - 1)->set_edge_padding_high(
+ padded_length - scan_length);
+
+ sources[i] = parent->AddInstruction(HloInstruction::CreatePad(
+ padded_shape, source, init_values[i], padding_config));
+ }
+ }
+
+ // Reshape to R(k+1).
+ const int64_t num_columns = padded_length / base_length_;
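+ // For instance, in the f32[46592] test case in this change, with
+ // base_length_ = 128, padded_length == scan_length == 46592 and
+ // num_columns = 364, so each source below is reshaped to f32[364,128].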
+ std::vector<HloInstruction*> tiled_sources;
+ std::vector<Shape> tiled_shapes;
+ for (size_t i = 0; i < sources.size(); ++i) {
+ auto* source = sources[i];
+ Shape tiled_shape = source->shape();
+ tiled_shape.set_dimensions(rank - 1, num_columns);
+
+ UpdateLayout(&tiled_shape);
+ ShapeUtil::AppendMajorDimension(base_length_, &tiled_shape);
+ tiled_shapes.push_back(tiled_shape);
+ tiled_sources.push_back(parent->AddInstruction(
+ HloInstruction::CreateReshape(tiled_shape, source)));
+ }
+
+ // Outer scan.
+ Window outer_window =
+ window_util::MakeWindow(std::vector<int64_t>(rank + 1, 1));
+ outer_window.mutable_dimensions(rank)->set_size(base_length_);
+ if (forward_scan) {
+ outer_window.mutable_dimensions(rank)->set_padding_low(base_length_ - 1);
+ } else {
+ outer_window.mutable_dimensions(rank)->set_padding_high(base_length_ - 1);
+ }
+ auto outer_reduce_window =
+ parent->AddInstruction(HloInstruction::CreateReduceWindow(
+ reduce_window->shape().IsTuple()
+ ? ShapeUtil::MakeTupleShape(tiled_shapes)
+ : tiled_shapes[0],
+ tiled_sources, init_values, outer_window, reduce_window->to_apply()));
+
+ // Slice out the last (first if reverse scan) column.
+ std::vector<Shape> column_shapes;
+ std::vector<HloInstruction*> last_cols;
+ ShapeUtil::ForEachSubshape(
+ outer_reduce_window->shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (!ShapeUtil::IsLeafIndex(outer_reduce_window->shape(),
+ shape_index)) {
+ return;
+ }
+ Shape column_shape = subshape;
+ column_shape.set_dimensions(rank, 1);
+
+ UpdateLayout(&column_shape);
+ std::vector<int64_t> col_slice_starts(rank + 1, 0);
+ std::vector<int64_t> col_slice_limits(
+ SpanToVector(subshape.dimensions()));
+ if (forward_scan) {
+ col_slice_starts[rank] = base_length_ - 1;
+ } else {
+ col_slice_limits[rank] = 1;
+ }
+ auto last_col = parent->AddInstruction(HloInstruction::CreateSlice(
+ column_shape, GetAtIndex(outer_reduce_window, shape_index),
+ col_slice_starts, col_slice_limits,
+ std::vector<int64_t>(rank + 1, 1)));
+ column_shape.DeleteDimension(rank);
+ last_col = parent->AddInstruction(
+ HloInstruction::CreateReshape(column_shape, last_col));
+ last_cols.push_back(last_col);
+
+ column_shape.set_dimensions(rank - 1, num_columns + 1);
+ UpdateLayout(&column_shape);
+ column_shapes.push_back(column_shape);
+ });
+
+ // Inner scan.
+ Window inner_window = window_util::MakeWindow(std::vector<int64_t>(rank, 1));
+ inner_window.mutable_dimensions(rank - 1)->set_size(num_columns);
+ if (forward_scan) {
+ inner_window.mutable_dimensions(rank - 1)->set_padding_low(num_columns);
+ } else {
+ inner_window.mutable_dimensions(rank - 1)->set_padding_high(num_columns);
+ }
+ auto inner_reduce_window =
+ parent->AddInstruction(HloInstruction::CreateReduceWindow(
+ reduce_window->shape().IsTuple()
+ ? ShapeUtil::MakeTupleShape(column_shapes)
+ : column_shapes[0],
+ last_cols, init_values, inner_window, reduce_window->to_apply()));
+ std::vector<int64_t> exclusive_slice_starts(rank, 0);
+ std::vector<int64_t> exclusive_slice_limits =
+ SpanToVector(column_shapes[0].dimensions());
+ if (forward_scan) {
+ exclusive_slice_limits[rank - 1] = num_columns;
+ } else {
+ exclusive_slice_starts[rank - 1] = 1;
+ exclusive_slice_limits[rank - 1] = num_columns + 1;
+ }
+ std::vector<HloInstruction*> inner_scan_components;
+ ShapeUtil::ForEachSubshape(
+ inner_reduce_window->shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (!ShapeUtil::IsLeafIndex(inner_reduce_window->shape(),
+ shape_index)) {
+ return;
+ }
+ size_t idx = FlattenShapeIndex(shape_index);
+ auto last_col = last_cols[idx];
+ auto* inner_slice = parent->AddInstruction(HloInstruction::CreateSlice(
+ last_col->shape(), GetAtIndex(inner_reduce_window, shape_index),
+ exclusive_slice_starts, exclusive_slice_limits,
+ std::vector<int64_t>(rank, 1)));
+
+ std::vector<int64_t> rank_iota(rank);
+ absl::c_iota(rank_iota, 0);
+ auto* inner_scan_component =
+ parent->AddInstruction(HloInstruction::CreateBroadcast(
+ tiled_shapes[idx], inner_slice, rank_iota));
+ inner_scan_components.push_back(inner_scan_component);
+ });
+
+ // Combine inner and outer scans.
+ std::vector<HloInstruction*> map_operands;
+ ShapeUtil::ForEachSubshape(
+ outer_reduce_window->shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (!ShapeUtil::IsLeafIndex(outer_reduce_window->shape(),
+ shape_index)) {
+ return;
+ }
+ map_operands.push_back(GetAtIndex(outer_reduce_window, shape_index));
+ });
+ map_operands.insert(map_operands.end(), inner_scan_components.begin(),
+ inner_scan_components.end());
+
+ // Reshape back to Rk and slice out the padding.
+ std::vector<HloInstruction*> scans;
+ auto status = ShapeUtil::ForEachSubshapeWithStatus(
+ outer_reduce_window->shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) -> Status {
+ if (!ShapeUtil::IsLeafIndex(outer_reduce_window->shape(),
+ shape_index)) {
+ return OkStatus();
+ }
+ size_t idx = FlattenShapeIndex(shape_index);
+ auto source = sources[idx];
+ HloComputation* map_computation;
+ auto reduce_function_root =
+ reduce_window->to_apply()->root_instruction();
+ if (reduce_function_root->shape().IsTuple()) {
+ TF_RET_CHECK(reduce_function_root->opcode() == HloOpcode::kTuple);
+ // This corresponds to step 7: combining the inner scan with the outer
+ // scan using a map function.
+ auto* map_computation_root = reduce_function_root->operand(idx);
+ absl::flat_hash_map<const HloInstruction*,
+ std::unique_ptr<HloInstruction>>
+ replacements;
+ replacements[reduce_function_root] = nullptr;
+ map_computation = parent->parent()->AddEmbeddedComputation(
+ reduce_window->to_apply()->CloneWithReplacements(
+ &replacements,
+ /*extra_parameters=*/{}, nullptr, "clone",
+ map_computation_root));
+ } else {
+ map_computation = reduce_window->to_apply();
+ }
+ auto scan = parent->AddInstruction(HloInstruction::CreateMap(
+ ShapeAtIndex(outer_reduce_window->shape(), shape_index),
+ map_operands, map_computation));
+ scan = parent->AddInstruction(
+ HloInstruction::CreateReshape(source->shape(), scan));
+
+ // If necessary, transpose back to the original order.
+ if (scan_dim_num != rank - 1) {
+ scan = parent->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::PermuteDimensions(permutation, source->shape()), scan,
+ permutation));
+ }
+
+ // Remove the padding to the base length.
+ if (padded_length != scan_length) {
+ scan = parent->AddInstruction(HloInstruction::CreateSlice(
+ operand_shape, scan, std::vector<int64_t>(rank, 0),
+ operand_shape.dimensions(), std::vector<int64_t>(rank, 1)));
+ }
+
+ if (is_exclusive) {
+ auto padding_config = MakeNoPaddingConfig(rank);
+ if (forward_scan) {
+ padding_config.mutable_dimensions(scan_dim_num)
+ ->set_edge_padding_low(1);
+ } else {
+ padding_config.mutable_dimensions(scan_dim_num)
+ ->set_edge_padding_high(1);
+ }
+ scan = parent->AddInstruction(HloInstruction::CreatePad(
+ ShapeAtIndex(reduce_window->shape(), shape_index), scan,
+ init_values[idx], padding_config));
+ }
+ scans.push_back(scan);
+ return OkStatus();
+ });
+ TF_RETURN_IF_ERROR(status);
+
+ HloInstruction* scan;
+ if (reduce_window->shape().IsTuple()) {
+ scan = parent->AddInstruction(HloInstruction::CreateTuple(scans));
+ } else {
+ CHECK_EQ(scans.size(), 1);
+ scan = scans[0];
+ }
+ TF_RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(scan));
+ TF_RETURN_IF_ERROR(parent->RemoveInstruction(reduce_window));
+
+ return true;
+}
+
+absl::StatusOr<bool> ReduceWindowRewriter::Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) {
+ bool changed = false;
+ for (const auto& computation : module->computations(execution_threads)) {
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ HloReduceWindowInstruction* reduce_window =
+ DynCast<HloReduceWindowInstruction>(instruction);
+ if (!reduce_window) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(bool made_change,
+ TryOptimizeCumSumOrProd(reduce_window));
+ if (made_change) {
+ changed = true;
+ continue;
+ }
+
+ if (reduce_window->inputs().front()->shape().rank() != 1) {
+ continue;
+ }
+ TF_RETURN_IF_ERROR(ReplaceReduceWindowWithReshape(reduce_window));
+
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace xla
diff --git a/third_party/xla/xla/service/reduce_window_rewriter.h b/third_party/xla/xla/service/reduce_window_rewriter.h
new file mode 100644
index 00000000000..6fed774b27a
--- /dev/null
+++ b/third_party/xla/xla/service/reduce_window_rewriter.h
@@ -0,0 +1,72 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_
+#define XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_
+
+#include <cstdint>
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+
+namespace xla {
+
+// Rewrite ReduceWindow to be more performant in cases where it is written in
+// a quadratic way:
+//
+// 1) Work around unimplemented cases in the implementation of ReduceWindow.
+//
+// This rewrites all R1 ReduceWindow nodes. We reshape the operand to an
+// R2, perform the operation, and reshape back to R1. The reshapes correspond to
+// a bitcast if the tensor length is less than or equal to a passed parameter.
+// The motivation for this is to avoid use of overly large reductions and the
+// complexities and restrictions therein.
+//
+// 2) Rewrite ReduceWindow ops that represent a CumSum/CumProd into a
+// tree-reduction (see details in the implementation).
+// Note that this may itself generate R1 ReduceWindow ops, which means this pass
+// needs to be run to a fixed point.
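+//
+// A sketch of typical pipeline usage (the pipeline name here is
+// illustrative):
+//
+//   HloPassPipeline pipeline("scan-rewriter");
+//   pipeline.AddPass<ReduceWindowRewriter>(/*base_length=*/128);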
+class ReduceWindowRewriter : public HloModulePass {
+ public:
+ // `base_length` is the size of a reduce-window we are comfortable
+ // executing.
+ explicit ReduceWindowRewriter(int64_t base_length)
+ : base_length_(base_length) {}
+
+ absl::string_view name() const override { return "reduce-window-rewriter"; }
+
+ using HloPassInterface::Run;
+ absl::StatusOr<bool> Run(
+ HloModule* module,
+ const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+ Status ReplaceReduceWindowWithReshape(
+ HloReduceWindowInstruction* reduce_window);
+
+ absl::StatusOr<bool> TryOptimizeCumSumOrProd(
+ HloReduceWindowInstruction* reduce_window);
+
+ int64_t base_length_;
+};
+
+} // namespace xla
+
+#endif // XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_
diff --git a/third_party/xla/xla/service/reduce_window_rewriter_test.cc b/third_party/xla/xla/service/reduce_window_rewriter_test.cc
new file mode 100644
index 00000000000..b40314f6e4d
--- /dev/null
+++ b/third_party/xla/xla/service/reduce_window_rewriter_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/reduce_window_rewriter.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class ReduceWindowRewriterTest : public HloTestBase {
+ public:
+ void CheckReduceWindowRewrite(absl::string_view hlo,
+ std::optional<absl::string_view> expected) {
+ RunAndFilecheckHloRewrite(hlo, ReduceWindowRewriter{128}, expected);
+ }
+};
+
+TEST_F(ReduceWindowRewriterTest, EliminateR1) {
+ const char* hlo = R"(
+%binary_add {
+ %a = f32[] parameter(0)
+ %b = f32[] parameter(1)
+ ROOT %add = f32[] add(f32[] %a, f32[] %b)
+}
+
+ENTRY %EliminateR1 (input: f32[10]) -> f32[10] {
+ %input = f32[10]{0} parameter(0)
+ %constant = f32[] constant(0)
+ ROOT %reduce-window = f32[10]{0} reduce-window(f32[10]{0} %input, f32[] %constant), window={size=5 pad=2_2}, to_apply=%binary_add
+}
+)";
+
+ CheckReduceWindowRewrite(hlo, R"(
+// CHECK: [[reduce_window_1_0:%[^ ]+]] = f32[10,1]{0,1} reduce-window([[reshape_1:%[^ ]+]], [[constant_2:%[^ ]+]]), window={size=5x1 pad=2_2x0_0}, to_apply=[[binary_add_3:%[^ ]+]]
+// CHECK-NEXT: ROOT [[reshape_1_4:%[^ ]+]] = f32[10]{0} reshape([[reduce_window_1_0]])
+)");
+}
+
+TEST_F(ReduceWindowRewriterTest, EliminateR1Variadic) {
+ const char* hlo = R"(
+HloModule reduce-window
+
+add_float {
+ lhs.0 = f32[] parameter(0)
+ lhs.1 = f32[] parameter(1)
+ rhs.0 = f32[] parameter(2)
+ rhs.1 = f32[] parameter(3)
+ sum.0 = f32[] add(lhs.0, rhs.0)
+ sum.1 = f32[] add(lhs.1, rhs.1)
+ ROOT root = (f32[], f32[]) tuple(sum.0, sum.1)
+}
+
+ENTRY entry (arg: f32[10]) -> (f32[10], f32[10]) {
+ arg = f32[10]{0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce-window = (f32[10]{0}, f32[10]{0}) reduce-window(f32[10]{0} %arg, f32[10]{0} %arg, f32[] %constant, f32[] %constant), window={size=5 pad=2_2}, to_apply=%add_float
+})";
+
+ CheckReduceWindowRewrite(hlo, R"(
+// CHECK: ENTRY %entry (arg: f32[10]) -> (f32[10], f32[10]) {
+// CHECK-NEXT: [[arg_0:%[^ ]+]] = f32[10]{0} parameter(0)
+// CHECK-NEXT: [[reshape_1:%[^ ]+]] = f32[10,1]{0,1} reshape([[arg_0]])
+// CHECK-NEXT: [[reshape_1_2:%[^ ]+]] = f32[10,1]{0,1} reshape([[arg_0]])
+// CHECK-NEXT: [[constant_3:%[^ ]+]] = f32[] constant(0)
+// CHECK-NEXT: [[reduce_window_1_4:%[^ ]+]] = (f32[10,1]{0,1}, f32[10,1]{0,1}) reduce-window([[reshape_1]], [[reshape_1_2]], [[constant_3]], [[constant_3]]), window={size=5x1 pad=2_2x0_0}, to_apply=[[add_float_5:%[^ ]+]]
+// CHECK-NEXT: [[get_tuple_element_6:%[^ ]+]] = f32[10,1]{0,1} get-tuple-element([[reduce_window_1_4]]), index=0
+// CHECK-NEXT: [[reshape_2_7:%[^ ]+]] = f32[10]{0} reshape([[get_tuple_element_6]])
+// CHECK-NEXT: [[get_tuple_element_1_8:%[^ ]+]] = f32[10,1]{0,1} get-tuple-element([[reduce_window_1_4]]), index=1
+// CHECK-NEXT: [[reshape_3_9:%[^ ]+]] = f32[10]{0} reshape([[get_tuple_element_1_8]])
+// CHECK-NEXT: ROOT [[tuple_10:%[^ ]+]] = (f32[10]{0}, f32[10]{0}) tuple([[reshape_2_7]], [[reshape_3_9]])
+// CHECK-NEXT:}
+)");
+}
+
+TEST_F(ReduceWindowRewriterTest, OptimizeR1InclusiveScan) {
+ const char* hlo = R"(
+HloModule reduce-window
+
+add_float {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+ENTRY entry (arg: f32[46592]) -> f32[46592] {
+ arg = f32[46592]{0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce-window = f32[46592]{0} reduce-window(f32[46592]{0} %arg, f32[] %constant), window={size=46592 pad=46591_0}, to_apply=%add_float
+})";
+
+ CheckReduceWindowRewrite(hlo, R"(
+// CHECK: ENTRY %entry (arg: f32[46592]) -> f32[46592] {
+// CHECK-NEXT: [[arg_0:%[^ ]+]] = f32[46592]{0} parameter(0)
+// CHECK-NEXT: [[reshape_1:%[^ ]+]] = f32[364,128]{0,1} reshape([[arg_0]])
+// CHECK-NEXT: [[constant_2:%[^ ]+]] = f32[] constant(0)
+// CHECK-NEXT: [[reduce_window_1_3:%[^ ]+]] = f32[364,128]{0,1} reduce-window([[reshape_1]], [[constant_2]]), window={size=1x128 pad=0_0x127_0}, to_apply=[[add_float_4:%[^ ]+]]
+// CHECK-NEXT: [[slice_5:%[^ ]+]] = f32[364,1]{0,1} slice([[reduce_window_1_3]]), slice={[0:364], [127:128]}
+// CHECK-NEXT: [[reshape_1_6:%[^ ]+]] = f32[364]{0} reshape([[slice_5]])
+// CHECK-NEXT: [[reduce_window_2_7:%[^ ]+]] = f32[365]{0} reduce-window([[reshape_1_6]], [[constant_2]]), window={size=364 pad=364_0}, to_apply=[[add_float_4]]
+// CHECK-NEXT: [[slice_1_8:%[^ ]+]] = f32[364]{0} slice([[reduce_window_2_7]]), slice={[0:364]}
+// CHECK-NEXT: [[broadcast_9:%[^ ]+]] = f32[364,128]{0,1} broadcast([[slice_1_8]]), dimensions={0}
+// CHECK-NEXT: [[map_10:%[^ ]+]] = f32[364,128]{0,1} map([[reduce_window_1_3]], [[broadcast_9]]), dimensions={0,1}, to_apply=[[add_float_4]]
+// CHECK-NEXT: ROOT [[reshape_2_11:%[^ ]+]] = f32[46592]{0} reshape([[map_10]])
+// CHECK-NEXT:}
+)");
+}
+
+TEST_F(ReduceWindowRewriterTest, OptimizeR1InclusiveScanVariadic) {
+ const std::string hlo_string = R"(
+HloModule reduce-window
+
+MaxMin {
+ l.max = f32[] parameter(0)
+ l.min = f32[] parameter(1)
+ r.max = f32[] parameter(2)
+ r.min = f32[] parameter(3)
+ max = f32[] maximum(l.max, r.max)
+ min = f32[] minimum(l.min, r.min)
+ ROOT root = (f32[], f32[]) tuple(max, min)
+}
+
+ENTRY entry (arg_0: f32[46592], arg_1: f32[46592]) -> (f32[46592], f32[46592]) {
+ arg.0 = f32[46592]{0} parameter(0)
+ arg.1 = f32[46592]{0} parameter(1)
+ init_ninf = f32[] constant(-inf)
+ init_inf = f32[] constant(inf)
+ ROOT reduce-window = (f32[46592]{0}, f32[46592]{0}) reduce-window(f32[46592]{0} %arg.0, f32[46592]{0} %arg.1, f32[] %init_ninf, f32[] %init_inf), window={size=46592 pad=46591_0}, to_apply=%MaxMin
+}
+)";
+
+ CheckReduceWindowRewrite(hlo_string, R"(
+// CHECK: ENTRY %entry (arg.0: f32[46592], arg.1: f32[46592]) -> (f32[46592], f32[46592]) {
+// CHECK-NEXT: [[arg_0_0:%[^ ]+]] = f32[46592]{0} parameter(0)
+// CHECK-NEXT: [[reshape_1:%[^ ]+]] = f32[364,128]{0,1} reshape([[arg_0_0]])
+// CHECK-NEXT: [[arg_1_2:%[^ ]+]] = f32[46592]{0} parameter(1)
+// CHECK-NEXT: [[reshape_1_3:%[^ ]+]] = f32[364,128]{0,1} reshape([[arg_1_2]])
+// CHECK-NEXT: [[init_ninf_4:%[^ ]+]] = f32[] constant(-inf)
+// CHECK-NEXT: [[init_inf_5:%[^ ]+]] = f32[] constant(inf)
+// CHECK-NEXT: [[reduce_window_1_6:%[^ ]+]] = (f32[364,128]{0,1}, f32[364,128]{0,1}) reduce-window([[reshape_1]], [[reshape_1_3]], [[init_ninf_4]], [[init_inf_5]]), window={size=1x128 pad=0_0x127_0}, to_apply=[[MaxMin_7:%[^ ]+]]
+// CHECK-NEXT: [[get_tuple_element_4_8:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=0
+// CHECK-NEXT: [[get_tuple_element_5_9:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=1
+// CHECK-NEXT: [[get_tuple_element_10:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=0
+// CHECK-NEXT: [[slice_11:%[^ ]+]] = f32[364,1]{0,1} slice([[get_tuple_element_10]]), slice={[0:364], [127:128]}
+// CHECK-NEXT: [[reshape_2_12:%[^ ]+]] = f32[364]{0} reshape([[slice_11]])
+// CHECK-NEXT: [[get_tuple_element_1_13:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=1
+// CHECK-NEXT: [[slice_1_14:%[^ ]+]] = f32[364,1]{0,1} slice([[get_tuple_element_1_13]]), slice={[0:364], [127:128]}
+// CHECK-NEXT: [[reshape_3_15:%[^ ]+]] = f32[364]{0} reshape([[slice_1_14]])
+// CHECK-NEXT: [[reduce_window_2_16:%[^ ]+]] = (f32[365]{0}, f32[365]{0}) reduce-window([[reshape_2_12]], [[reshape_3_15]], [[init_ninf_4]], [[init_inf_5]]), window={size=364 pad=364_0}, to_apply=[[MaxMin_7]]
+// CHECK-NEXT: [[get_tuple_element_2_17:%[^ ]+]] = f32[365]{0} get-tuple-element([[reduce_window_2_16]]), index=0
+// CHECK-NEXT: [[slice_2_18:%[^ ]+]] = f32[364]{0} slice([[get_tuple_element_2_17]]), slice={[0:364]}
+// CHECK-NEXT: [[broadcast_19:%[^ ]+]] = f32[364,128]{0,1} broadcast([[slice_2_18]]), dimensions={0}
+// CHECK-NEXT: [[get_tuple_element_3_20:%[^ ]+]] = f32[365]{0} get-tuple-element([[reduce_window_2_16]]), index=1
+// CHECK-NEXT: [[slice_3_21:%[^ ]+]] = f32[364]{0} slice([[get_tuple_element_3_20]]), slice={[0:364]}
+// CHECK-NEXT: [[broadcast_1_22:%[^ ]+]] = f32[364,128]{0,1} broadcast([[slice_3_21]]), dimensions={0}
+// CHECK-NEXT: [[map_23:%[^ ]+]] = f32[364,128]{0,1} map([[get_tuple_element_4_8]], [[get_tuple_element_5_9]], [[broadcast_19]], [[broadcast_1_22]]), dimensions={0,1}, to_apply=[[MaxMin_7]].clone
+// CHECK-NEXT: [[reshape_4_24:%[^ ]+]] = f32[46592]{0} reshape([[map_23]])
+// CHECK-NEXT: [[map_1_25:%[^ ]+]] = f32[364,128]{0,1} map([[get_tuple_element_4_8]], [[get_tuple_element_5_9]], [[broadcast_19]], [[broadcast_1_22]]), dimensions={0,1}, to_apply=[[MaxMin_7]].clone.1
+// CHECK-NEXT: [[reshape_5_26:%[^ ]+]] = f32[46592]{0} reshape([[map_1_25]])
+// CHECK-NEXT: ROOT [[tuple_27:%[^ ]+]] = (f32[46592]{0}, f32[46592]{0}) tuple([[reshape_4_24]], [[reshape_5_26]])
+// CHECK-NEXT: }
+ )");
+}
+
+} // namespace
+} // namespace xla