diff --git a/mlir/include/mlir/Analysis/LoopAnalysis.h b/mlir/include/mlir/Analysis/LoopAnalysis.h --- a/mlir/include/mlir/Analysis/LoopAnalysis.h +++ b/mlir/include/mlir/Analysis/LoopAnalysis.h @@ -22,6 +22,7 @@ class AffineExpr; class AffineForOp; class AffineMap; +class BlockArgument; class MemRefType; class NestedPattern; class Operation; @@ -83,6 +84,37 @@ // TODO: extend this to check for memory-based dependence violation when we have // the support. bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef shifts); + +/// Utility to match a generic reduction given a list of iteration-carried +/// arguments, `iterCarriedArgs` and the position of the potential reduction +/// argument within the list, `redPos`. If a reduction is matched, returns the +/// reduced value and the topologically-sorted list of combiner operations +/// involved in the reduction. Otherwise, returns a null value. +/// +/// The matching algorithm relies on the following invariants, which are subject +/// to change: +/// 1. The first combiner operation must be a binary operation with the +/// iteration-carried value and the reduced value as operands. +/// 2. The iteration-carried value and combiner operations must be side +/// effect-free, have single result and a single use. +/// 3. Combiner operations must be immediately nested in the region op +/// performing the reduction. +/// 4. Reduction def-use chain must end in a terminator op that yields the +/// next iteration/output values in the same order as the iteration-carried +/// values in `iterCarriedArgs`. +/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values +/// of the region op performing the reduction. +/// +/// This utility is generic enough to detect reductions involving multiple +/// combiner operations (disabled for now) across multiple dialects, including +/// Linalg, Affine and SCF. For the sake of genericity, it does not return +/// specific enum values for the combiner operations since its goal is also +/// matching reductions without pre-defined semantics in core MLIR. It's up to +/// each client to make sense out of the list of combiner operations. It's also +/// up to each client to check for additional invariants on the expected +/// reductions not covered by this generic matching. +Value matchReduction(ArrayRef iterCarriedArgs, unsigned redPos, + SmallVectorImpl &combinerOps); } // end namespace mlir #endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -595,6 +595,19 @@ return 0; }] >, + InterfaceMethod< + /*desc=*/[{ + Return the output block arguments of the region. + }], + /*retTy=*/"Block::BlockArgListType", + /*methodName=*/"getRegionOutputArgs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + Block &entryBlock = this->getOperation()->getRegion(0).front(); + return entryBlock.getArguments().take_back(this->getNumOutputs()); + }] + >, InterfaceMethod< /*desc=*/[{ Return the `opOperand` shape or an empty vector for scalars. diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -30,6 +30,7 @@ class Operation; class OperationName; class Type; +class Value; namespace detail { struct DiagnosticEngineImpl; @@ -218,6 +219,9 @@ return *this << *val; } + /// Stream in a Value. + Diagnostic &operator<<(Value val); + /// Stream in a range. template > std::enable_if_t::value, diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -22,7 +23,6 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -33,29 +33,6 @@ using llvm::dbgs; -/// Returns true if `value` (transitively) depends on iteration arguments of the -/// given `forOp`. -static bool dependsOnIterArgs(Value value, AffineForOp forOp) { - // Compute the backward slice of the value. - SetVector slice; - getBackwardSlice(value, &slice, - [&](Operation *op) { return !forOp->isAncestor(op); }); - - // Check that none of the operands of the operations in the backward slice are - // loop iteration arguments, and neither is the value itself. - auto argRange = forOp.getRegionIterArgs(); - llvm::SmallPtrSet iterArgs(argRange.begin(), argRange.end()); - if (iterArgs.contains(value)) - return true; - - for (Operation *op : slice) - for (Value operand : op->getOperands()) - if (iterArgs.contains(operand)) - return true; - - return false; -} - /// Get the value that is being reduced by `pos`-th reduction in the loop if /// such a reduction can be performed by affine parallel loops. This assumes /// floating-point operations are commutative. On success, `kind` will be the @@ -63,18 +40,19 @@ /// reduction is not supported, returns null. static Value getSupportedReduction(AffineForOp forOp, unsigned pos, AtomicRMWKind &kind) { - auto yieldOp = cast(forOp.getBody()->back()); - Value yielded = yieldOp.operands()[pos]; - Operation *definition = yielded.getDefiningOp(); - if (!definition) + SmallVector combinerOps; + Value reducedVal = + matchReduction(forOp.getRegionIterArgs(), pos, combinerOps); + if (!reducedVal) return nullptr; - if (!forOp.getRegionIterArgs()[pos].hasOneUse()) - return nullptr; - if (!yielded.hasOneUse()) + + // Expected only one combiner operation. + if (combinerOps.size() > 1) return nullptr; + Operation *combinerOp = combinerOps.back(); Optional maybeKind = - TypeSwitch>(definition) + TypeSwitch>(combinerOp) .Case([](Operation *) { return AtomicRMWKind::addf; }) .Case([](Operation *) { return AtomicRMWKind::mulf; }) .Case([](Operation *) { return AtomicRMWKind::addi; }) @@ -88,14 +66,7 @@ return nullptr; kind = *maybeKind; - if (definition->getOperand(0) == forOp.getRegionIterArgs()[pos] && - !dependsOnIterArgs(definition->getOperand(1), forOp)) - return definition->getOperand(1); - if (definition->getOperand(1) == forOp.getRegionIterArgs()[pos] && - !dependsOnIterArgs(definition->getOperand(0), forOp)) - return definition->getOperand(0); - - return nullptr; + return reducedVal; } /// Returns true if `forOp' is a parallel loop. If `parallelReductions` is diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -15,11 +15,13 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include @@ -392,3 +394,105 @@ } return true; } + +/// Returns true if `value` (transitively) depends on iteration-carried values +/// of the given `ancestorOp`. +static bool dependsOnCarriedVals(Value value, + ArrayRef iterCarriedArgs, + Operation *ancestorOp) { + // Compute the backward slice of the value. + SetVector slice; + getBackwardSlice(value, &slice, + [&](Operation *op) { return !ancestorOp->isAncestor(op); }); + + // Check that none of the operands of the operations in the backward slice are + // loop iteration arguments, and neither is the value itself. + SmallPtrSet iterCarriedValSet(iterCarriedArgs.begin(), + iterCarriedArgs.end()); + if (iterCarriedValSet.contains(value)) + return true; + + for (Operation *op : slice) + for (Value operand : op->getOperands()) + if (iterCarriedValSet.contains(operand)) + return true; + + return false; +} + +/// Utility to match a generic reduction given a list of iteration-carried +/// arguments, `iterCarriedArgs` and the position of the potential reduction +/// argument within the list, `redPos`. If a reduction is matched, returns the +/// reduced value and the topologically-sorted list of combiner operations +/// involved in the reduction. Otherwise, returns a null value. +/// +/// The matching algorithm relies on the following invariants, which are subject +/// to change: +/// 1. The first combiner operation must be a binary operation with the +/// iteration-carried value and the reduced value as operands. +/// 2. The iteration-carried value and combiner operations must be side +/// effect-free, have single result and a single use. +/// 3. Combiner operations must be immediately nested in the region op +/// performing the reduction. +/// 4. Reduction def-use chain must end in a terminator op that yields the +/// next iteration/output values in the same order as the iteration-carried +/// values in `iterCarriedArgs`. +/// 5. `iterCarriedArgs` must contain all the iteration-carried/output values +/// of the region op performing the reduction. +/// +/// This utility is generic enough to detect reductions involving multiple +/// combiner operations (disabled for now) across multiple dialects, including +/// Linalg, Affine and SCF. For the sake of genericity, it does not return +/// specific enum values for the combiner operations since its goal is also +/// matching reductions without pre-defined semantics in core MLIR. It's up to +/// each client to make sense out of the list of combiner operations. It's also +/// up to each client to check for additional invariants on the expected +/// reductions not covered by this generic matching. +Value mlir::matchReduction(ArrayRef iterCarriedArgs, + unsigned redPos, + SmallVectorImpl &combinerOps) { + assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds"); + + BlockArgument redCarriedVal = iterCarriedArgs[redPos]; + if (!redCarriedVal.hasOneUse()) + return nullptr; + + // For now, the first combiner op must be a binary op. + Operation *combinerOp = *redCarriedVal.getUsers().begin(); + if (combinerOp->getNumOperands() != 2) + return nullptr; + Value reducedVal = combinerOp->getOperand(0) == redCarriedVal + ? combinerOp->getOperand(1) + : combinerOp->getOperand(0); + + Operation *redRegionOp = + iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); + if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp)) + return nullptr; + + // Traverse the def-use chain starting from the first combiner op until a + // terminator is found. Gather all the combiner ops along the way in + // topological order. + while (!combinerOp->mightHaveTrait()) { + if (!MemoryEffectOpInterface::hasNoEffect(combinerOp) || + combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() || + combinerOp->getParentOp() != redRegionOp) + return nullptr; + + combinerOps.push_back(combinerOp); + combinerOp = *combinerOp->getUsers().begin(); + } + + // Limit matching to single combiner op until we can properly test reductions + // involving multiple combiners. + if (combinerOps.size() != 1) + return nullptr; + + // Check that the yielded value is in the same position as in + // `iterCarriedArgs`. + Operation *terminatorOp = combinerOp; + if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) + return nullptr; + + return reducedVal; +} diff --git a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToOpenMP/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRAnalysis MLIRLLVMIR MLIROpenMP MLIRSCF diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "../PassDetail.h" +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -34,10 +35,21 @@ if (block.empty() || llvm::hasSingleElement(block) || std::next(block.begin(), 2) != block.end()) return false; - return isa(block.front()) && + + if (block.getNumArguments() != 2) + return false; + + SmallVector combinerOps; + Value reducedVal = matchReduction({block.getArguments()[1]}, + /*redPos=*/0, combinerOps); + + if (!reducedVal || !reducedVal.isa() || + combinerOps.size() != 1) + return false; + + return isa(combinerOps[0]) && isa(block.back()) && - block.front().getOperands() == block.getArguments() && - block.back().getOperand(0) == block.front().getResult(0); + block.front().getOperands() == block.getArguments(); } /// Matches a block containing a select-based min/max reduction. The types of diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" @@ -110,46 +111,24 @@ return VectorType::get(st.getShape(), st.getElementType()); } -/// Given an `outputOperand` of a LinalgOp, compute the intersection of the -/// forward slice starting from `outputOperand` and the backward slice -/// starting from the corresponding linalg.yield operand. -/// This intersection is assumed to have a single binary operation that is -/// the reduction operation. Multiple reduction operations would impose an +/// Check whether `outputOperand` is a reduction with a single combiner +/// operation. Return the combiner operation of the reduction, which is assumed +/// to be a binary operation. Multiple reduction operations would impose an /// ordering between reduction dimensions and is currently unsupported in -/// Linalg. This limitation is motivated by the fact that e.g. -/// min(max(X)) != max(min(X)) +/// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != +/// max(min(X)) // TODO: use in LinalgOp verification, there is a circular dependency atm. static Operation *getSingleBinaryOpAssumedReduction(OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); - auto yieldOp = cast(linalgOp->getRegion(0).front().getTerminator()); - unsigned yieldNum = + unsigned outputPos = outputOperand->getOperandNumber() - linalgOp.getNumInputs(); - llvm::SetVector backwardSlice, forwardSlice; - BlockArgument bbArg = linalgOp->getRegion(0).front().getArgument( - outputOperand->getOperandNumber()); - Value yieldVal = yieldOp->getOperand(yieldNum); - getBackwardSlice(yieldVal, &backwardSlice, [&](Operation *op) { - return op->getParentOp() == linalgOp; - }); - backwardSlice.insert(yieldVal.getDefiningOp()); - getForwardSlice(bbArg, &forwardSlice, - [&](Operation *op) { return op->getParentOp() == linalgOp; }); - // Search for the (assumed unique) elementwiseMappable op at the intersection - // of forward and backward slices. - Operation *reductionOp = nullptr; - for (Operation *op : llvm::reverse(backwardSlice)) { - if (!forwardSlice.contains(op)) - continue; - if (OpTrait::hasElementwiseMappableTraits(op)) { - if (reductionOp) { - // Reduction detection fails: found more than 1 elementwise-mappable op. - return nullptr; - } - reductionOp = op; - } - } + SmallVector combinerOps; + if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || + combinerOps.size() != 1) + return nullptr; + // TODO: also assert no other subsequent ops break the reduction. - return reductionOp; + return combinerOps[0]; } /// If `value` of assumed VectorType has a shape different than `shape`, try to diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -131,6 +131,14 @@ return *this << os.str(); } +/// Stream in a Value. +Diagnostic &Diagnostic::operator<<(Value val) { + std::string str; + llvm::raw_string_ostream os(str); + val.print(os); + return *this << os.str(); +} + /// Outputs this diagnostic to a stream. void Diagnostic::print(raw_ostream &os) const { for (auto &arg : getArguments()) diff --git a/mlir/test/Analysis/test-match-reduction.mlir b/mlir/test/Analysis/test-match-reduction.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-match-reduction.mlir @@ -0,0 +1,114 @@ +// RUN: mlir-opt %s -test-match-reduction -verify-diagnostics -split-input-file + +// Verify that the generic reduction detection utility works on different +// dialects. + +// expected-remark@below {{Testing function}} +func @linalg_red_add(%in0t : tensor, %out0t : tensor<1xf32>) { + // expected-remark@below {{Reduction found in output #0!}} + // expected-remark@below {{Reduced Value: of type 'f32' at index: 0}} + // expected-remark@below {{Combiner Op: %1 = addf %arg2, %arg3 : f32}} + %red = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (0)>], + iterator_types = ["reduction"]} + ins(%in0t : tensor) + outs(%out0t : tensor<1xf32>) { + ^bb0(%in0: f32, %out0: f32): + %add = addf %in0, %out0 : f32 + linalg.yield %add : f32 + } -> tensor<1xf32> + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @affine_red_add(%in: memref<256x512xf32>, %out: memref<256xf32>) { + %cst = constant 0.000000e+00 : f32 + affine.for %i = 0 to 256 { + // expected-remark@below {{Reduction found in output #0!}} + // expected-remark@below {{Reduced Value: %1 = affine.load %arg0[%arg2, %arg3] : memref<256x512xf32>}} + // expected-remark@below {{Combiner Op: %2 = addf %arg4, %1 : f32}} + %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { + %ld = affine.load %in[%i, %j] : memref<256x512xf32> + %add = addf %red_iter, %ld : f32 + affine.yield %add : f32 + } + affine.store %final_red, %out[%i] : memref<256xf32> + } + return +} + +// ----- + +// TODO: Iteration-carried values with multiple uses are not supported yet. +// expected-remark@below {{Testing function}} +func @linalg_red_max(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) { + // expected-remark@below {{Reduction NOT found in output #0!}} + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%in0t : tensor<4x4xf32>) + outs(%out0t : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): + %cmp = cmpf ogt, %in0, %out0 : f32 + %sel = select %cmp, %in0, %out0 : f32 + linalg.yield %sel : f32 + } -> tensor<4xf32> + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @linalg_fused_red_add(%in0t: tensor<4x4xf32>, %out0t: tensor<4xf32>) { + // expected-remark@below {{Reduction found in output #0!}} + // expected-remark@below {{Reduced Value: %2 = subf %1, %arg2 : f32}} + // expected-remark@below {{Combiner Op: %3 = addf %2, %arg3 : f32}} + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%in0t : tensor<4x4xf32>) + outs(%out0t : tensor<4xf32>) { + ^bb0(%in0: f32, %out0: f32): + %mul = mulf %in0, %in0 : f32 + %sub = subf %mul, %in0 : f32 + %add = addf %sub, %out0 : f32 + linalg.yield %add : f32 + } -> tensor<4xf32> + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @affine_no_red_rec(%in: memref<512xf32>) { + %cst = constant 0.000000e+00 : f32 + // %rec is the value loaded in the previous iteration. + // expected-remark@below {{Reduction NOT found in output #0!}} + %final_val = affine.for %j = 0 to 512 iter_args(%rec = %cst) -> (f32) { + %ld = affine.load %in[%j] : memref<512xf32> + %add = addf %ld, %rec : f32 + affine.yield %ld : f32 + } + return +} + +// ----- + +// expected-remark@below {{Testing function}} +func @affine_output_dep(%in: memref<512xf32>) { + %cst = constant 0.000000e+00 : f32 + // Reduction %red is not supported because it depends on another + // loop-carried dependence. + // expected-remark@below {{Reduction NOT found in output #0!}} + // expected-remark@below {{Reduction NOT found in output #1!}} + %final_red, %final_dep = affine.for %j = 0 to 512 + iter_args(%red = %cst, %dep = %cst) -> (f32, f32) { + %ld = affine.load %in[%j] : memref<512xf32> + %add = addf %dep, %red : f32 + affine.yield %add, %ld : f32, f32 + } + return +} + diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -3,6 +3,7 @@ TestAliasAnalysis.cpp TestCallGraph.cpp TestLiveness.cpp + TestMatchReduction.cpp TestMemRefBoundCheck.cpp TestMemRefDependenceCheck.cpp TestMemRefStrideCalculation.cpp diff --git a/mlir/test/lib/Analysis/TestMatchReduction.cpp b/mlir/test/lib/Analysis/TestMatchReduction.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestMatchReduction.cpp @@ -0,0 +1,86 @@ +//===- TestMatchReduction.cpp - Test the match reduction utility ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a test pass for the match reduction utility. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/LoopAnalysis.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +void printReductionResult(Operation *redRegionOp, unsigned numOutput, + Value reducedValue, + ArrayRef combinerOps) { + if (reducedValue) { + redRegionOp->emitRemark("Reduction found in output #") << numOutput << "!"; + redRegionOp->emitRemark("Reduced Value: ") << reducedValue; + for (Operation *combOp : combinerOps) + redRegionOp->emitRemark("Combiner Op: ") << *combOp; + + return; + } + + redRegionOp->emitRemark("Reduction NOT found in output #") + << numOutput << "!"; +} + +struct TestMatchReductionPass + : public PassWrapper { + StringRef getArgument() const final { return "test-match-reduction"; } + StringRef getDescription() const final { + return "Test the match reduction utility."; + } + + void runOnFunction() override { + FuncOp func = getFunction(); + func->emitRemark("Testing function"); + + func.walk([](Operation *op) { + if (isa(op)) + return; + + // Limit testing to ops with only one region. + if (op->getNumRegions() != 1) + return; + + Region ®ion = op->getRegion(0); + if (!region.hasOneBlock()) + return; + + // We expect all the tested region ops to have 1 input by default. The + // remaining arguments are assumed to be outputs/reductions and there must + // be at least one. + // TODO: Extend it to support more generic cases. + Block ®ionEntry = region.front(); + auto args = regionEntry.getArguments(); + if (args.size() < 2) + return; + + auto outputs = args.drop_front(); + for (int i = 0, size = outputs.size(); i < size; ++i) { + SmallVector combinerOps; + Value reducedValue = matchReduction(outputs, i, combinerOps); + printReductionResult(op, i, reducedValue, combinerOps); + } + }); + } +}; + +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestMatchReductionPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -94,6 +94,7 @@ void registerTestLoopFusion(); void registerTestLoopMappingPass(); void registerTestLoopUnrollingPass(); +void registerTestMatchReductionPass(); void registerTestMathAlgebraicSimplificationPass(); void registerTestMathPolynomialApproximationPass(); void registerTestMemRefDependenceCheck(); @@ -183,6 +184,7 @@ mlir::test::registerTestLoopFusion(); mlir::test::registerTestLoopMappingPass(); mlir::test::registerTestLoopUnrollingPass(); + mlir::test::registerTestMatchReductionPass(); mlir::test::registerTestMathAlgebraicSimplificationPass(); mlir::test::registerTestMathPolynomialApproximationPass(); mlir::test::registerTestMemRefDependenceCheck(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4336,6 +4336,7 @@ hdrs = ["include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"], includes = ["include"], deps = [ + ":Analysis", ":ConversionPassIncGen", ":IR", ":LLVMDialect",