// Patch summary (review notes; free text before the first `diff --git` line
// is skipped by patch tools):
//
// * Renames the file-local helper `getKindForOp` in Vectorization.cpp to the
//   non-static `mlir::linalg::getCombinerOpKind` and declares it in
//   Transforms.h, so code outside Vectorization.cpp can map a combiner op to
//   a vector::CombiningKind.
// * Contract as visible in the hunk: returns llvm::None for a null op;
//   otherwise dispatches on the op type with llvm::TypeSwitch to one of ADD,
//   AND, MAXSI, MAXF, MINSI, MINF, MUL, OR or XOR, and returns llvm::None
//   (the `.Default` case) for any other op.
// * Updates the three existing callers: buildMultiDimReduce (asserts that a
//   kind was found before building the reduction), the reduction
//   precondition check (returns failure() when no kind is found), and the
//   check that bails out unless reduceOp's kind is ADD and the first op of
//   region 0's first block has kind MUL.
//
// NOTE(review): this copy appears to have lost its angle-bracketed template
// arguments and its line breaks (e.g. `llvm::Optional getCombinerOpKind`,
// `.Case(`, `b.create(`, `const SmallVector &`), so it will not apply or
// compile as-is. The declaration presumably reads
// `llvm::Optional<vector::CombiningKind> getCombinerOpKind(...)`; confirm
// against the upstream revision before applying.
// NOTE(review): the new header comment could also say that llvm::None is
// returned for a null or unsupported op. Transforms.h must also see a
// declaration of vector::CombiningKind; confirm its includes.
// NOTE(review): the TypeSwitch lambdas capture nothing and ignore `op`, so
// `[](auto) { ... }` would express that more directly.
// NOTE(review): the last call site dereferences
// `linalgOp->getRegion(0).front().front()`. That assumes region 0 has a
// non-empty first block, presumably guaranteed earlier by the caller; verify.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -920,6 +920,9 @@ LinalgTransformationFilter filter; }; +/// Return vector::CombiningKind for the given op. +llvm::Optional getCombinerOpKind(Operation *combinerOp); + //===----------------------------------------------------------------------===// // Transformation and lowering options exposed as auxiliary structs. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -109,25 +109,25 @@ Operation *newOp; }; -static llvm::Optional -getKindForOp(Operation *reductionOp) { - if (!reductionOp) +llvm::Optional +mlir::linalg::getCombinerOpKind(Operation *combinerOp) { + using ::mlir::vector::CombiningKind; + + if (!combinerOp) return llvm::None; - return llvm::TypeSwitch>( - reductionOp) + return llvm::TypeSwitch>( + combinerOp) .Case( - [&](auto op) { return vector::CombiningKind::ADD; }) - .Case([&](auto op) { return vector::CombiningKind::AND; }) - .Case( - [&](auto op) { return vector::CombiningKind::MAXSI; }) - .Case([&](auto op) { return vector::CombiningKind::MAXF; }) - .Case( - [&](auto op) { return vector::CombiningKind::MINSI; }) - .Case([&](auto op) { return vector::CombiningKind::MINF; }) + [&](auto op) { return CombiningKind::ADD; }) + .Case([&](auto op) { return CombiningKind::AND; }) + .Case([&](auto op) { return CombiningKind::MAXSI; }) + .Case([&](auto op) { return CombiningKind::MAXF; }) + .Case([&](auto op) { return CombiningKind::MINSI; }) + .Case([&](auto op) { return CombiningKind::MINF; }) .Case( - [&](auto op) { return 
vector::CombiningKind::MUL; }) - .Case([&](auto op) { return vector::CombiningKind::OR; }) - .Case([&](auto op) { return vector::CombiningKind::XOR; }) + [&](auto op) { return CombiningKind::MUL; }) + .Case([&](auto op) { return CombiningKind::OR; }) + .Case([&](auto op) { return CombiningKind::XOR; }) .Default([&](auto op) { return llvm::None; }); } @@ -174,7 +174,7 @@ static Value buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, Value valueToReduce, const SmallVector &reductionMask) { - auto maybeKind = getKindForOp(reduceOp); + auto maybeKind = getCombinerOpKind(reduceOp); assert(maybeKind && "Failed precondition: could not get reduction kind"); return b.create( reduceOp->getLoc(), valueToReduce, reductionMask, *maybeKind); @@ -589,7 +589,7 @@ } for (OpOperand *opOperand : op.getOutputOperands()) { Operation *reduceOp = matchLinalgReduction(opOperand); - if (!reduceOp || !getKindForOp(reduceOp)) { + if (!reduceOp || !getCombinerOpKind(reduceOp)) { LDBG("reduction precondition failed: reduction detection failed"); return failure(); } @@ -1458,10 +1458,10 @@ if (!reduceOp) return; llvm::Optional maybeKind; - maybeKind = getKindForOp(reduceOp); + maybeKind = getCombinerOpKind(reduceOp); if (!maybeKind || *maybeKind != vector::CombiningKind::ADD) return; - maybeKind = getKindForOp(&(linalgOp->getRegion(0).front().front())); + maybeKind = getCombinerOpKind(&(linalgOp->getRegion(0).front().front())); if (!maybeKind || *maybeKind != vector::CombiningKind::MUL) return;