From b0c214c5a7ce533d152650b13830f63f56b7e868 Mon Sep 17 00:00:00 2001
From: George Karpenkov
Date: Fri, 10 May 2024 06:35:26 -0700
Subject: [XLA] [NFC] Upstream ReduceWindowRewriter pass

PiperOrigin-RevId: 632478633
---
 third_party/xla/xla/service/BUILD                  |  40 ++
 .../xla/xla/service/reduce_window_rewriter.cc      | 545 +++++++++++++++++++++
 .../xla/xla/service/reduce_window_rewriter.h       |  72 +++
 .../xla/xla/service/reduce_window_rewriter_test.cc | 184 +++++++
 4 files changed, 841 insertions(+)
 create mode 100644 third_party/xla/xla/service/reduce_window_rewriter.cc
 create mode 100644 third_party/xla/xla/service/reduce_window_rewriter.h
 create mode 100644 third_party/xla/xla/service/reduce_window_rewriter_test.cc

diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD
index 029895a4867..fe3f249ec44 100644
--- a/third_party/xla/xla/service/BUILD
+++ b/third_party/xla/xla/service/BUILD
@@ -7510,6 +7510,46 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "reduce_window_rewriter",
+    srcs = ["reduce_window_rewriter.cc"],
+    hdrs = ["reduce_window_rewriter.h"],
+    deps = [
+        ":hlo_pass",
+        "//xla:shape_util",
+        "//xla:status",
+        "//xla:status_macros",
+        "//xla:statusor",
+        "//xla:util",
+        "//xla:window_util",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+        "@local_tsl//tsl/platform:errors",
+        "@local_tsl//tsl/platform:statusor",
+    ],
+)
+
+xla_cc_test(
+    name = "reduce_window_rewriter_test",
+    srcs = ["reduce_window_rewriter_test.cc"],
+    deps = [
+        ":reduce_window_rewriter",
+        "//xla:test",
+        "//xla:xla_data_proto_cc",
+        "//xla/tests:hlo_test_base",
+        "@com_google_absl//absl/strings",
+        "@local_tsl//tsl/platform:test_main",
+    ],
+)
+
 cc_library(
     name = "stochastic_convert_decomposer",
     srcs = ["stochastic_convert_decomposer.cc"],
diff --git a/third_party/xla/xla/service/reduce_window_rewriter.cc b/third_party/xla/xla/service/reduce_window_rewriter.cc
new file mode 100644
index 00000000000..39b201e108c
--- /dev/null
+++ b/third_party/xla/xla/service/reduce_window_rewriter.cc
@@ -0,0 +1,545 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/reduce_window_rewriter.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/log/check.h"
+#include "absl/log/log.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status.h"
+#include "xla/status_macros.h"
+#include "xla/util.h"
+#include "xla/window_util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+
+static size_t FlattenShapeIndex(const ShapeIndex& shape_index) {
+  if (shape_index.empty()) {
+    return 0;
+  }
+  CHECK_EQ(shape_index.size(), 1);
+  return shape_index.back();
+}
+
+static Shape ShapeAtIndex(const Shape& shape, const ShapeIndex& shape_index) {
+  if (shape_index.empty()) {
+    return shape;
+  }
+  CHECK_EQ(shape_index.size(), 1);
+  return ShapeUtil::GetTupleElementShape(shape, shape_index.back());
+}
+
+static HloInstruction* GetAtIndex(HloInstruction* hlo,
+                                  const ShapeIndex& shape_index) {
+  if (shape_index.empty()) {
+    return hlo;
+  }
+  CHECK_EQ(shape_index.size(), 1);
+  return hlo->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+      ShapeAtIndex(hlo->shape(), shape_index), hlo, shape_index.back()));
+}
+
+// Transform reduce-win(x) ->
+//   if rank(x) == 1:
+//     then: reshape_r2_r1(reduce-win(reshape_r1_r2(x)))
+//     else: no change
+Status ReduceWindowRewriter::ReplaceReduceWindowWithReshape(
+    HloReduceWindowInstruction* reduce_window) {
+  VLOG(2) << "Converting R1 reduce window: " << reduce_window->ToString();
+
+  std::vector<Shape> r2_output_shapes;
+  ShapeUtil::ForEachSubshape(
+      reduce_window->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (!ShapeUtil::IsLeafIndex(reduce_window->shape(), shape_index)) {
+          return;
+        }
+        Shape r2_output_shape = subshape;
+        ShapeUtil::AppendMajorDimension(1, &r2_output_shape);
+        UpdateLayout(&r2_output_shape);
+        r2_output_shapes.push_back(r2_output_shape);
+
+        VLOG(2) << "ReduceWindowRewriter: Converting R2 result to R1: "
+                << ShapeUtil::HumanStringWithLayout(r2_output_shape);
+      });
+
+  Window r2_window = reduce_window->window();
+  WindowDimension* dim = r2_window.add_dimensions();
+  dim->set_size(1);
+  dim->set_stride(1);
+  dim->set_base_dilation(1);
+  dim->set_window_dilation(1);
+
+  std::vector<HloInstruction*> r2_operands;
+  for (HloInstruction* operand : reduce_window->inputs()) {
+    Shape r2_input_shape = operand->shape();
+    ShapeUtil::AppendMajorDimension(1, &r2_input_shape);
+    UpdateLayout(&r2_input_shape);
+
+    VLOG(2) << "ReduceWindowRewriter: Converting R1 operand to R2: "
+            << ShapeUtil::HumanStringWithLayout(r2_input_shape);
+    HloInstruction* r2_operand = operand->parent()->AddInstruction(
+        HloInstruction::CreateReshape(r2_input_shape, operand));
+    VLOG(2) << "R2 new operand: " << r2_operand->ToString();
+    r2_operands.push_back(r2_operand);
+  }
+  HloInstruction* new_reduce_window = reduce_window->parent()->AddInstruction(
+      HloInstruction::CreateReduceWindow(
+          reduce_window->shape().IsTuple()
+              ? ShapeUtil::MakeTupleShape(r2_output_shapes)
+              : r2_output_shapes[0],
+          r2_operands, reduce_window->init_values(), r2_window,
+          reduce_window->to_apply()));
+
+  VLOG(2) << "R2 resulting reduce window: " << new_reduce_window->ToString();
+
+  std::vector<HloInstruction*> final_reshapes;
+  ShapeUtil::ForEachSubshape(
+      reduce_window->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (!ShapeUtil::IsLeafIndex(reduce_window->shape(), shape_index)) {
+          return;
+        }
+        HloInstruction* final_reshape =
+            new_reduce_window->parent()->AddInstruction(
+                HloInstruction::CreateReshape(
+                    subshape, GetAtIndex(new_reduce_window, shape_index)));
+        final_reshapes.push_back(final_reshape);
+      });
+  HloInstruction* result;
+  if (reduce_window->shape().IsTuple()) {
+    result = new_reduce_window->parent()->AddInstruction(
+        HloInstruction::CreateTuple(final_reshapes));
+  } else {
+    CHECK_EQ(final_reshapes.size(), 1);
+    result = final_reshapes[0];
+  }
+  TF_RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(result));
+  TF_RETURN_IF_ERROR(
+      new_reduce_window->parent()->RemoveInstruction(reduce_window));
+
+  return OkStatus();
+}
+
+absl::StatusOr<bool> ReduceWindowRewriter::TryOptimizeCumSumOrProd(
+    HloReduceWindowInstruction* reduce_window) {
+  const Shape& operand_shape = reduce_window->inputs().front()->shape();
+
+  // Try to find the scan axis. We expect all window dimensions to be trivial,
+  // except for one.
+  int64_t rank = operand_shape.rank();
+  const Window& window = reduce_window->window();
+  int64_t scan_dim_num = -1;
+  for (int i = 0; i < rank; ++i) {
+    const WindowDimension& window_dim = window.dimensions(i);
+    if (window_util::IsTrivialWindowDimension(window_dim)) {
+      continue;
+    }
+    if (scan_dim_num != -1) {
+      // At least two non-trivial dimensions exist, so, no cigar.
+      return false;
+    }
+    scan_dim_num = i;
+  }
+
+  if (scan_dim_num == -1) {
+    return false;
+  }
+
+  const int64_t scan_length = operand_shape.dimensions(scan_dim_num);
+  absl::Span<HloInstruction* const> init_values = reduce_window->init_values();
+  const WindowDimension& scan_window_dim = window.dimensions(scan_dim_num);
+
+  bool forward_scan = (scan_window_dim.padding_low() == scan_length - 1 ||
+                       scan_window_dim.padding_low() == scan_length) &&
+                      scan_window_dim.padding_high() == 0;
+  bool reverse_scan = (scan_window_dim.padding_high() == scan_length - 1 ||
+                       scan_window_dim.padding_high() == scan_length) &&
+                      scan_window_dim.padding_low() == 0;
+  // We accept two values for low padding: the input length for exclusive scan,
+  // and scan_length - 1 for inclusive scan.
+  if (scan_window_dim.stride() != 1 || scan_window_dim.size() != scan_length ||
+      (!forward_scan && !reverse_scan) || scan_window_dim.window_reversal() ||
+      scan_window_dim.base_dilation() != 1 ||
+      scan_window_dim.window_dilation() != 1) {
+    return false;
+  }
+  bool is_exclusive = forward_scan
+                          ? (scan_window_dim.padding_low() == scan_length)
+                          : (scan_window_dim.padding_high() == scan_length);
+
+  if (scan_length <= base_length_) {
+    return false;
+  }
+
+  if (reduce_window->to_apply()->root_instruction()->shape().IsTuple() &&
+      reduce_window->to_apply()->root_instruction()->opcode() !=
+          HloOpcode::kTuple) {
+    return false;
+  }
+
+  VLOG(2) << "Rewriting Scan: " << reduce_window->ToString();
+  HloComputation* parent = reduce_window->parent();
+  std::vector<HloInstruction*> sources(reduce_window->inputs().begin(),
+                                       reduce_window->inputs().end());
+
+  // Since we need to tile this dimension, it's convenient to have it
+  // logically last.
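+  // For example, with rank == 3 and scan_dim_num == 0, the permutation below
+  // is {2, 1, 0}: the scan dimension and the last dimension trade places,
+  // while the middle "batch" dimension stays put.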
+  std::vector<int64_t> permutation(rank);
+  absl::c_iota(permutation, 0);
+  permutation[scan_dim_num] = rank - 1;
+  permutation[rank - 1] = scan_dim_num;
+  if (scan_dim_num != rank - 1) {
+    for (size_t i = 0; i < sources.size(); ++i) {
+      sources[i] = parent->AddInstruction(HloInstruction::CreateTranspose(
+          ShapeUtil::PermuteDimensions(permutation, sources[i]->shape()),
+          sources[i], permutation));
+    }
+  }
+
+  // We don't actually need to match the computation - this transformation
+  // will work for a commutative/associative reducer, which is what we assume
+  // for ReduceWindow anyway.
+
+  // Break the scan into an "inner" and an "outer" scan - this is basically a
+  // tree reduction:
+  // (The explanation below assumes an R1 scan for simplicity. For Rk scan,
+  // all shapes have k-1 "batch" dimensions that need to be preserved.)
+  //
+  // 1) If necessary, pad input from {N} to {K}, where K is a multiple of 128.
+  // 2) Reshape from {K} to {K / 128, 128}.
+  // 3) Scan each 128 dimension.
+  // 4) Slice out the last column.
+  // 5) Exclusive scan across the last column.
+  // 6) Broadcast it back into {K / 128, 128}.
+  // 7) Add up the results of (3) and (6).
+  // 8) Reshape back into {K}.
+  // 9) Slice off the padding.
+  //
+  // For example, consider a cumulative sum over an R1 of length 9, with a
+  // base case of 3 instead of 128. Let the input be:
+  // [0 1 2 3 4 5 6 7 8]
+  //
+  // We need no padding, so we go directly to (2):
+  // [0 1 2
+  //  3 4 5
+  //  6 7 8]
+  //
+  // The result of the scan in (3) is:
+  // [0  1  3
+  //  3  7 12
+  //  6 13 21]
+  //
+  // Slicing out the last column we get (4):
+  // [ 3
+  //  12
+  //  21]
+  //
+  // And after scanning and broadcasting (5 and 6):
+  // [ 0  0  0
+  //   3  3  3
+  //  15 15 15]
+  //
+  // Finally, we add up the two scans (3) and (6), getting (7):
+  // [ 0  1  3
+  //   6 10 15
+  //  21 28 36]
+  //
+  // And reshape back into [0 1 3 6 10 15 21 28 36].
+  //
+  // For reverse scans, we perform the same as forward scans, except: we
+  // perform a reverse scan at (3), slice out the first column at (4), and
+  // perform an exclusive reverse scan of the first column at (5).
+
+  // Pad.
+  const int64_t padded_length = RoundUpTo(scan_length, base_length_);
+  if (scan_length != padded_length) {
+    for (size_t i = 0; i < sources.size(); ++i) {
+      auto* source = sources[i];
+      Shape padded_shape = source->shape();
+      padded_shape.set_dimensions(rank - 1, padded_length);
+
+      UpdateLayout(&padded_shape);
+      auto padding_config = MakeNoPaddingConfig(rank);
+      padding_config.mutable_dimensions(rank - 1)->set_edge_padding_high(
+          padded_length - scan_length);
+
+      sources[i] = parent->AddInstruction(HloInstruction::CreatePad(
+          padded_shape, source, init_values[i], padding_config));
+    }
+  }
+
+  // Reshape to R(k+1).
+  const int64_t num_columns = padded_length / base_length_;
+  std::vector<HloInstruction*> tiled_sources;
+  std::vector<Shape> tiled_shapes;
+  for (size_t i = 0; i < sources.size(); ++i) {
+    auto* source = sources[i];
+    Shape tiled_shape = source->shape();
+    tiled_shape.set_dimensions(rank - 1, num_columns);
+
+    UpdateLayout(&tiled_shape);
+    ShapeUtil::AppendMajorDimension(base_length_, &tiled_shape);
+    tiled_shapes.push_back(tiled_shape);
+    tiled_sources.push_back(parent->AddInstruction(
+        HloInstruction::CreateReshape(tiled_shape, source)));
+  }
+
+  // Outer scan.
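+  // The outer window is trivial in every dimension except the new major one,
+  // where size == base_length_ and the low (high, for a reverse scan) padding
+  // is base_length_ - 1, i.e. an inclusive scan of each base_length_ tile
+  // (step (3) above).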
+  Window outer_window =
+      window_util::MakeWindow(std::vector<int64_t>(rank + 1, 1));
+  outer_window.mutable_dimensions(rank)->set_size(base_length_);
+  if (forward_scan) {
+    outer_window.mutable_dimensions(rank)->set_padding_low(base_length_ - 1);
+  } else {
+    outer_window.mutable_dimensions(rank)->set_padding_high(base_length_ - 1);
+  }
+  auto outer_reduce_window =
+      parent->AddInstruction(HloInstruction::CreateReduceWindow(
+          reduce_window->shape().IsTuple()
+              ? ShapeUtil::MakeTupleShape(tiled_shapes)
+              : tiled_shapes[0],
+          tiled_sources, init_values, outer_window,
+          reduce_window->to_apply()));
+
+  // Slice out the last (first if reverse scan) column.
+  std::vector<Shape> column_shapes;
+  std::vector<HloInstruction*> last_cols;
+  ShapeUtil::ForEachSubshape(
+      outer_reduce_window->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (!ShapeUtil::IsLeafIndex(outer_reduce_window->shape(),
+                                    shape_index)) {
+          return;
+        }
+        Shape column_shape = subshape;
+        column_shape.set_dimensions(rank, 1);
+
+        UpdateLayout(&column_shape);
+        std::vector<int64_t> col_slice_starts(rank + 1, 0);
+        std::vector<int64_t> col_slice_limits(
+            SpanToVector(subshape.dimensions()));
+        if (forward_scan) {
+          col_slice_starts[rank] = base_length_ - 1;
+        } else {
+          col_slice_limits[rank] = 1;
+        }
+        auto last_col = parent->AddInstruction(HloInstruction::CreateSlice(
+            column_shape, GetAtIndex(outer_reduce_window, shape_index),
+            col_slice_starts, col_slice_limits,
+            std::vector<int64_t>(rank + 1, 1)));
+        column_shape.DeleteDimension(rank);
+        last_col = parent->AddInstruction(
+            HloInstruction::CreateReshape(column_shape, last_col));
+        last_cols.push_back(last_col);
+
+        column_shape.set_dimensions(rank - 1, num_columns + 1);
+        UpdateLayout(&column_shape);
+        column_shapes.push_back(column_shape);
+      });
+
+  // Inner scan.
+  Window inner_window =
+      window_util::MakeWindow(std::vector<int64_t>(rank, 1));
+  inner_window.mutable_dimensions(rank - 1)->set_size(num_columns);
+  if (forward_scan) {
+    inner_window.mutable_dimensions(rank - 1)->set_padding_low(num_columns);
+  } else {
+    inner_window.mutable_dimensions(rank - 1)->set_padding_high(num_columns);
+  }
+  auto inner_reduce_window =
+      parent->AddInstruction(HloInstruction::CreateReduceWindow(
+          reduce_window->shape().IsTuple()
+              ? ShapeUtil::MakeTupleShape(column_shapes)
+              : column_shapes[0],
+          last_cols, init_values, inner_window, reduce_window->to_apply()));
+  std::vector<int64_t> exclusive_slice_starts(rank, 0);
+  std::vector<int64_t> exclusive_slice_limits =
+      SpanToVector(column_shapes[0].dimensions());
+  if (forward_scan) {
+    exclusive_slice_limits[rank - 1] = num_columns;
+  } else {
+    exclusive_slice_starts[rank - 1] = 1;
+    exclusive_slice_limits[rank - 1] = num_columns + 1;
+  }
+  std::vector<HloInstruction*> inner_scan_components;
+  ShapeUtil::ForEachSubshape(
+      inner_reduce_window->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (!ShapeUtil::IsLeafIndex(inner_reduce_window->shape(),
+                                    shape_index)) {
+          return;
+        }
+        size_t idx = FlattenShapeIndex(shape_index);
+        auto last_col = last_cols[idx];
+        auto* inner_slice = parent->AddInstruction(HloInstruction::CreateSlice(
+            last_col->shape(), GetAtIndex(inner_reduce_window, shape_index),
+            exclusive_slice_starts, exclusive_slice_limits,
+            std::vector<int64_t>(rank, 1)));
+
+        std::vector<int64_t> rank_iota(rank);
+        absl::c_iota(rank_iota, 0);
+        auto* inner_scan_component =
+            parent->AddInstruction(HloInstruction::CreateBroadcast(
+                tiled_shapes[idx], inner_slice, rank_iota));
+        inner_scan_components.push_back(inner_scan_component);
+      });
+
+  // Combine inner and outer scans.
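+  // map_operands holds the per-tile scans (step (3)) followed by the
+  // broadcast inner scans (step (6)); the map below combines them elementwise
+  // with the reducer (step (7)).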
+  std::vector<HloInstruction*> map_operands;
+  ShapeUtil::ForEachSubshape(
+      outer_reduce_window->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (!ShapeUtil::IsLeafIndex(outer_reduce_window->shape(),
+                                    shape_index)) {
+          return;
+        }
+        map_operands.push_back(GetAtIndex(outer_reduce_window, shape_index));
+      });
+  map_operands.insert(map_operands.end(), inner_scan_components.begin(),
+                      inner_scan_components.end());
+
+  // Reshape back to Rk and slice out the padding.
+  std::vector<HloInstruction*> scans;
+  auto status = ShapeUtil::ForEachSubshapeWithStatus(
+      outer_reduce_window->shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) -> Status {
+        if (!ShapeUtil::IsLeafIndex(outer_reduce_window->shape(),
+                                    shape_index)) {
+          return OkStatus();
+        }
+        size_t idx = FlattenShapeIndex(shape_index);
+        auto source = sources[idx];
+        HloComputation* map_computation;
+        auto reduce_function_root =
+            reduce_window->to_apply()->root_instruction();
+        if (reduce_function_root->shape().IsTuple()) {
+          TF_RET_CHECK(reduce_function_root->opcode() == HloOpcode::kTuple);
+          // This corresponds to step 7: combining the inner scan with the
+          // outer scan using a map function.
+          auto* map_computation_root = reduce_function_root->operand(idx);
+          absl::flat_hash_map<const HloInstruction*,
+                              std::unique_ptr<HloInstruction>>
+              replacements;
+          replacements[reduce_function_root] = nullptr;
+          map_computation = parent->parent()->AddEmbeddedComputation(
+              reduce_window->to_apply()->CloneWithReplacements(
+                  &replacements,
+                  /*extra_parameters=*/{}, nullptr, "clone",
+                  map_computation_root));
+        } else {
+          map_computation = reduce_window->to_apply();
+        }
+        auto scan = parent->AddInstruction(HloInstruction::CreateMap(
+            ShapeAtIndex(outer_reduce_window->shape(), shape_index),
+            map_operands, map_computation));
+        scan = parent->AddInstruction(
+            HloInstruction::CreateReshape(source->shape(), scan));
+
+        // If necessary, transpose back to the original order.
+        if (scan_dim_num != rank - 1) {
+          scan = parent->AddInstruction(HloInstruction::CreateTranspose(
+              ShapeUtil::PermuteDimensions(permutation, source->shape()),
+              scan, permutation));
+        }
+
+        // Remove the padding up to a multiple of base_length_ (step (9)).
+        if (padded_length != scan_length) {
+          scan = parent->AddInstruction(HloInstruction::CreateSlice(
+              operand_shape, scan, std::vector<int64_t>(rank, 0),
+              operand_shape.dimensions(), std::vector<int64_t>(rank, 1)));
+        }
+
+        if (is_exclusive) {
+          auto padding_config = MakeNoPaddingConfig(rank);
+          if (forward_scan) {
+            padding_config.mutable_dimensions(scan_dim_num)
+                ->set_edge_padding_low(1);
+          } else {
+            padding_config.mutable_dimensions(scan_dim_num)
+                ->set_edge_padding_high(1);
+          }
+          scan = parent->AddInstruction(HloInstruction::CreatePad(
+              ShapeAtIndex(reduce_window->shape(), shape_index), scan,
+              init_values[idx], padding_config));
+        }
+        scans.push_back(scan);
+        return OkStatus();
+      });
+  TF_RETURN_IF_ERROR(status);
+
+  HloInstruction* scan;
+  if (reduce_window->shape().IsTuple()) {
+    scan = parent->AddInstruction(HloInstruction::CreateTuple(scans));
+  } else {
+    CHECK_EQ(scans.size(), 1);
+    scan = scans[0];
+  }
+  TF_RETURN_IF_ERROR(reduce_window->ReplaceAllUsesWith(scan));
+  TF_RETURN_IF_ERROR(parent->RemoveInstruction(reduce_window));
+
+  return true;
+}
+
+absl::StatusOr<bool> ReduceWindowRewriter::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (const auto& computation : module->computations(execution_threads)) {
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      HloReduceWindowInstruction* reduce_window =
+          DynCast<HloReduceWindowInstruction>(instruction);
+      if (!reduce_window) {
+        continue;
+      }
+      TF_ASSIGN_OR_RETURN(bool made_change,
+                          TryOptimizeCumSumOrProd(reduce_window));
+      if (made_change) {
+        changed = true;
+        continue;
+      }
+
+      if (reduce_window->inputs().front()->shape().rank() != 1) {
+        continue;
+      }
+      TF_RETURN_IF_ERROR(ReplaceReduceWindowWithReshape(reduce_window));
+
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+}  // namespace xla
diff --git a/third_party/xla/xla/service/reduce_window_rewriter.h b/third_party/xla/xla/service/reduce_window_rewriter.h
new file mode 100644
index 00000000000..6fed774b27a
--- /dev/null
+++ b/third_party/xla/xla/service/reduce_window_rewriter.h
@@ -0,0 +1,72 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_
+#define XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/service/hlo_pass_interface.h"
+#include "xla/status.h"
+#include "xla/statusor.h"
+
+namespace xla {
+
+// Rewrite ReduceWindow to be more performant in cases where it is written in
+// a quadratic way:
+//
+// 1) Work around unimplemented cases in the implementation of ReduceWindow.
+//
+// This rewrites all R1 ReduceWindow nodes. We reshape the operand to an
+// R2, perform the operation, and reshape back to R1.
+// The reshapes correspond to a bitcast if the tensor length is less than or
+// equal to a passed parameter. The motivation for this is to avoid the use
+// of overly large reductions and the complexities and restrictions therein.
+//
+// 2) Rewrite ReduceWindow ops that represent a CumSum/CumProd into a
+// tree-reduction (see details in the implementation).
+// Note that this may itself generate R1 ReduceWindow ops, which means this
+// pass needs to be run to a fixed point.
+class ReduceWindowRewriter : public HloModulePass {
+ public:
+  // `base_length` is the size of a reduce-window we are comfortable
+  // executing.
+  explicit ReduceWindowRewriter(int64_t base_length)
+      : base_length_(base_length) {}
+
+  absl::string_view name() const override { return "reduce-window-rewriter"; }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads)
+      override;
+
+ private:
+  Status ReplaceReduceWindowWithReshape(
+      HloReduceWindowInstruction* reduce_window);
+
+  absl::StatusOr<bool> TryOptimizeCumSumOrProd(
+      HloReduceWindowInstruction* reduce_window);
+
+  int64_t base_length_;
+};
+
+}  // namespace xla
+
+#endif  // XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_
diff --git a/third_party/xla/xla/service/reduce_window_rewriter_test.cc b/third_party/xla/xla/service/reduce_window_rewriter_test.cc
new file mode 100644
index 00000000000..b40314f6e4d
--- /dev/null
+++ b/third_party/xla/xla/service/reduce_window_rewriter_test.cc
@@ -0,0 +1,184 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/reduce_window_rewriter.h"
+
+#include <optional>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "xla/test.h"
+#include "xla/tests/hlo_test_base.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class ReduceWindowRewriterTest : public HloTestBase {
+ public:
+  void CheckReduceWindowRewrite(absl::string_view hlo,
+                                std::optional<absl::string_view> expected) {
+    RunAndFilecheckHloRewrite(hlo, ReduceWindowRewriter{128}, expected);
+  }
+};
+
+TEST_F(ReduceWindowRewriterTest, EliminateR1) {
+  const char* hlo = R"(
+%binary_add {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %a, f32[] %b)
+}
+
+ENTRY %EliminateR1 (input: f32[10]) -> f32[10] {
+  %input = f32[10]{0} parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce-window = f32[10]{0} reduce-window(f32[10]{0} %input, f32[] %constant), window={size=5 pad=2_2}, to_apply=%binary_add
+}
+)";
+
+  CheckReduceWindowRewrite(hlo, R"(
+// CHECK: [[reduce_window_1_0:%[^ ]+]] = f32[10,1]{0,1} reduce-window([[reshape_1:%[^ ]+]], [[constant_2:%[^ ]+]]), window={size=5x1 pad=2_2x0_0}, to_apply=[[binary_add_3:%[^ ]+]]
+// CHECK-NEXT: ROOT [[reshape_1_4:%[^ ]+]] = f32[10]{0} reshape([[reduce_window_1_0]])
+)");
+}
+
+TEST_F(ReduceWindowRewriterTest, EliminateR1Variadic) {
+  const char* hlo = R"(
+HloModule reduce-window
+
+add_float {
+  lhs.0 = f32[] parameter(0)
+  lhs.1 = f32[] parameter(1)
+  rhs.0 = f32[] parameter(2)
+  rhs.1 = f32[] parameter(3)
+  sum.0 = f32[] add(lhs.0, rhs.0)
+  sum.1 = f32[] add(lhs.1, rhs.1)
+  ROOT root = (f32[], f32[]) tuple(sum.0, sum.1)
+}
+
+ENTRY entry (arg: f32[10]) -> (f32[10], f32[10]) {
+  arg = f32[10]{0} parameter(0)
+  constant = f32[] constant(0)
+  ROOT reduce-window = (f32[10]{0}, f32[10]{0}) reduce-window(f32[10]{0} %arg, f32[10]{0} %arg, f32[] %constant, f32[] %constant), window={size=5 pad=2_2}, to_apply=%add_float
+})";
+
+  CheckReduceWindowRewrite(hlo, R"(
+// CHECK: ENTRY %entry (arg: f32[10]) -> (f32[10], f32[10]) {
+// CHECK-NEXT: [[arg_0:%[^ ]+]] = f32[10]{0} parameter(0)
+// CHECK-NEXT: [[reshape_1:%[^ ]+]] = f32[10,1]{0,1} reshape([[arg_0]])
+// CHECK-NEXT: [[reshape_1_2:%[^ ]+]] = f32[10,1]{0,1} reshape([[arg_0]])
+// CHECK-NEXT: [[constant_3:%[^ ]+]] = f32[] constant(0)
+// CHECK-NEXT: [[reduce_window_1_4:%[^ ]+]] = (f32[10,1]{0,1}, f32[10,1]{0,1}) reduce-window([[reshape_1]], [[reshape_1_2]], [[constant_3]], [[constant_3]]), window={size=5x1 pad=2_2x0_0}, to_apply=[[add_float_5:%[^ ]+]]
+// CHECK-NEXT: [[get_tuple_element_6:%[^ ]+]] = f32[10,1]{0,1} get-tuple-element([[reduce_window_1_4]]), index=0
+// CHECK-NEXT: [[reshape_2_7:%[^ ]+]] = f32[10]{0} reshape([[get_tuple_element_6]])
+// CHECK-NEXT: [[get_tuple_element_1_8:%[^ ]+]] = f32[10,1]{0,1} get-tuple-element([[reduce_window_1_4]]), index=1
+// CHECK-NEXT: [[reshape_3_9:%[^ ]+]] = f32[10]{0} reshape([[get_tuple_element_1_8]])
+// CHECK-NEXT: ROOT [[tuple_10:%[^ ]+]] = (f32[10]{0}, f32[10]{0}) tuple([[reshape_2_7]], [[reshape_3_9]])
+// CHECK-NEXT:}
+)");
+}
+
+TEST_F(ReduceWindowRewriterTest, OptimizeR1InclusiveScan) {
+  const char* hlo = R"(
+HloModule reduce-window
+
+add_float {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT mul = f32[] multiply(lhs, rhs)
+}
+
+ENTRY entry (arg: f32[46592]) -> f32[46592] {
+  arg = f32[46592]{0} parameter(0)
+  constant = f32[] constant(0)
+  ROOT reduce-window = f32[46592]{0} reduce-window(f32[46592]{0} %arg, f32[] %constant), window={size=46592 pad=46591_0}, to_apply=%add_float
+})";
+
+  CheckReduceWindowRewrite(hlo, R"(
+// CHECK: ENTRY %entry (arg: f32[46592]) -> f32[46592] {
+// CHECK-NEXT: [[arg_0:%[^ ]+]] = f32[46592]{0} parameter(0)
+// CHECK-NEXT: [[reshape_1:%[^ ]+]] = f32[364,128]{0,1} reshape([[arg_0]])
+// CHECK-NEXT: [[constant_2:%[^ ]+]] = f32[] constant(0)
+// CHECK-NEXT: [[reduce_window_1_3:%[^ ]+]] = f32[364,128]{0,1} reduce-window([[reshape_1]], [[constant_2]]), window={size=1x128 pad=0_0x127_0}, to_apply=[[add_float_4:%[^ ]+]]
+// CHECK-NEXT: [[slice_5:%[^ ]+]] = f32[364,1]{0,1} slice([[reduce_window_1_3]]), slice={[0:364], [127:128]}
+// CHECK-NEXT: [[reshape_1_6:%[^ ]+]] = f32[364]{0} reshape([[slice_5]])
+// CHECK-NEXT: [[reduce_window_2_7:%[^ ]+]] = f32[365]{0} reduce-window([[reshape_1_6]], [[constant_2]]), window={size=364 pad=364_0}, to_apply=[[add_float_4]]
+// CHECK-NEXT: [[slice_1_8:%[^ ]+]] = f32[364]{0} slice([[reduce_window_2_7]]), slice={[0:364]}
+// CHECK-NEXT: [[broadcast_9:%[^ ]+]] = f32[364,128]{0,1} broadcast([[slice_1_8]]), dimensions={0}
+// CHECK-NEXT: [[map_10:%[^ ]+]] = f32[364,128]{0,1} map([[reduce_window_1_3]], [[broadcast_9]]), dimensions={0,1}, to_apply=[[add_float_4]]
+// CHECK-NEXT: ROOT [[reshape_2_11:%[^ ]+]] = f32[46592]{0} reshape([[map_10]])
+// CHECK-NEXT:}
+)");
+}
+
+TEST_F(ReduceWindowRewriterTest, OptimizeR1InclusiveScanVariadic) {
+  const std::string hlo_string = R"(
+HloModule reduce-window
+
+MaxMin {
+  l.max = f32[] parameter(0)
+  l.min = f32[] parameter(1)
+  r.max = f32[] parameter(2)
+  r.min = f32[] parameter(3)
+  max = f32[] maximum(l.max, r.max)
+  min = f32[] minimum(l.min, r.min)
+  ROOT root = (f32[], f32[]) tuple(max, min)
+}
+
+ENTRY entry (arg_0: f32[46592], arg_1: f32[46592]) -> (f32[46592], f32[46592]) {
+  arg.0 = f32[46592]{0} parameter(0)
+  arg.1 = f32[46592]{0} parameter(1)
+  init_ninf = f32[] constant(-inf)
+  init_inf = f32[] constant(inf)
+  ROOT reduce-window = (f32[46592]{0}, f32[46592]{0}) reduce-window(f32[46592]{0} %arg.0, f32[46592]{0} %arg.1, f32[] %init_ninf, f32[] %init_inf), window={size=46592 pad=46591_0}, to_apply=%MaxMin
+}
+)";
+
+  CheckReduceWindowRewrite(hlo_string, R"(
+// CHECK: ENTRY %entry (arg.0: f32[46592], arg.1: f32[46592]) -> (f32[46592], f32[46592]) {
+// CHECK-NEXT: [[arg_0_0:%[^ ]+]] = f32[46592]{0} parameter(0)
+// CHECK-NEXT: [[reshape_1:%[^ ]+]] = f32[364,128]{0,1} reshape([[arg_0_0]])
+// CHECK-NEXT: [[arg_1_2:%[^ ]+]] = f32[46592]{0} parameter(1)
+// CHECK-NEXT: [[reshape_1_3:%[^ ]+]] = f32[364,128]{0,1} reshape([[arg_1_2]])
+// CHECK-NEXT: [[init_ninf_4:%[^ ]+]] = f32[] constant(-inf)
+// CHECK-NEXT: [[init_inf_5:%[^ ]+]] = f32[] constant(inf)
+// CHECK-NEXT: [[reduce_window_1_6:%[^ ]+]] = (f32[364,128]{0,1}, f32[364,128]{0,1}) reduce-window([[reshape_1]], [[reshape_1_3]], [[init_ninf_4]], [[init_inf_5]]), window={size=1x128 pad=0_0x127_0}, to_apply=[[MaxMin_7:%[^ ]+]]
+// CHECK-NEXT: [[get_tuple_element_4_8:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=0
+// CHECK-NEXT: [[get_tuple_element_5_9:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=1
+// CHECK-NEXT: [[get_tuple_element_10:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=0
+// CHECK-NEXT: [[slice_11:%[^ ]+]] = f32[364,1]{0,1} slice([[get_tuple_element_10]]), slice={[0:364], [127:128]}
+// CHECK-NEXT: [[reshape_2_12:%[^ ]+]] = f32[364]{0} reshape([[slice_11]])
+// CHECK-NEXT: [[get_tuple_element_1_13:%[^ ]+]] = f32[364,128]{0,1} get-tuple-element([[reduce_window_1_6]]), index=1
+// CHECK-NEXT: [[slice_1_14:%[^ ]+]] = f32[364,1]{0,1} slice([[get_tuple_element_1_13]]), slice={[0:364], [127:128]}
+// CHECK-NEXT: [[reshape_3_15:%[^ ]+]] = f32[364]{0} reshape([[slice_1_14]])
+// CHECK-NEXT: [[reduce_window_2_16:%[^ ]+]] = (f32[365]{0}, f32[365]{0}) reduce-window([[reshape_2_12]], [[reshape_3_15]], [[init_ninf_4]], [[init_inf_5]]), window={size=364 pad=364_0}, to_apply=[[MaxMin_7]]
+// CHECK-NEXT: [[get_tuple_element_2_17:%[^ ]+]] = f32[365]{0} get-tuple-element([[reduce_window_2_16]]), index=0
+// CHECK-NEXT: [[slice_2_18:%[^ ]+]] = f32[364]{0} slice([[get_tuple_element_2_17]]), slice={[0:364]}
+// CHECK-NEXT: [[broadcast_19:%[^ ]+]] = f32[364,128]{0,1} broadcast([[slice_2_18]]), dimensions={0}
+// CHECK-NEXT: [[get_tuple_element_3_20:%[^ ]+]] = f32[365]{0} get-tuple-element([[reduce_window_2_16]]), index=1
+// CHECK-NEXT: [[slice_3_21:%[^ ]+]] = f32[364]{0} slice([[get_tuple_element_3_20]]), slice={[0:364]}
+// CHECK-NEXT: [[broadcast_1_22:%[^ ]+]] = f32[364,128]{0,1} broadcast([[slice_3_21]]), dimensions={0}
+// CHECK-NEXT: [[map_23:%[^ ]+]] = f32[364,128]{0,1} map([[get_tuple_element_4_8]], [[get_tuple_element_5_9]], [[broadcast_19]], [[broadcast_1_22]]), dimensions={0,1}, to_apply=[[MaxMin_7]].clone
+// CHECK-NEXT: [[reshape_4_24:%[^ ]+]] = f32[46592]{0} reshape([[map_23]])
+// CHECK-NEXT: [[map_1_25:%[^ ]+]] = f32[364,128]{0,1} map([[get_tuple_element_4_8]], [[get_tuple_element_5_9]], [[broadcast_19]], [[broadcast_1_22]]), dimensions={0,1}, to_apply=[[MaxMin_7]].clone.1
+// CHECK-NEXT: [[reshape_5_26:%[^ ]+]] = f32[46592]{0} reshape([[map_1_25]])
+// CHECK-NEXT: ROOT [[tuple_27:%[^ ]+]] = (f32[46592]{0}, f32[46592]{0}) tuple([[reshape_4_24]], [[reshape_5_26]])
+// CHECK-NEXT: }
+  )");
+}
+
+}  // namespace
+}  // namespace xla
-- 
cgit v1.2.3
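
To exercise the new pass outside of the FileCheck tests above, it plugs into a
pass pipeline like any other HloModulePass. A minimal sketch, assuming the
standard HloPassPipeline API; the wrapper function and pipeline name are
illustrative only, and base_length 128 mirrors the value used in the tests:

  #include "absl/status/statusor.h"
  #include "xla/hlo/ir/hlo_module.h"
  #include "xla/service/hlo_pass_pipeline.h"
  #include "xla/service/reduce_window_rewriter.h"

  // Rewrites eligible reduce-windows in `module`; returns true if the module
  // changed. (Illustrative wrapper, not part of the patch.)
  absl::StatusOr<bool> RunReduceWindowRewriter(xla::HloModule* module) {
    xla::HloPassPipeline pipeline("reduce-window-rewriter-pipeline");
    pipeline.AddPass<xla::ReduceWindowRewriter>(/*base_length=*/128);
    return pipeline.Run(module);
  }

Because the scan rewrite can itself emit R1 ReduceWindow ops, a production
pipeline would wrap the pass in HloPassFix to reach the fixed point the header
comment calls for.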