diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -125,6 +125,10 @@ TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc); +/// Return the identity numeric value associated with the given op, or +/// std::nullopt if no neutral element is known. Emits no diagnostics. +std::optional getNeutralElement(Operation *op); + /// Returns the identity value associated with an AtomicRMWKind op. Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -95,10 +95,6 @@ std::optional> getReassociationMapForFoldingUnitDims(ArrayRef mixedSizes); -/// Return the identity numeric value associated to the give op. Return -/// std::nullopt if there is no known neutral element. -std::optional getNeutralElement(Operation *op); - //===----------------------------------------------------------------------===// // Fusion / Tiling utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::arith; @@ -2377,6 +2378,38 @@ return nullptr; } +/// Return the identity numeric value associated with the given op. +std::optional mlir::arith::getNeutralElement(Operation *op) { + std::optional maybeKind = + llvm::TypeSwitch>(op) + // Floating-point operations.
+ .Case([](arith::AddFOp op) { return AtomicRMWKind::addf; }) + .Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; }) + .Case([](arith::MaxFOp op) { return AtomicRMWKind::maxf; }) + .Case([](arith::MinFOp op) { return AtomicRMWKind::minf; }) + // Integer operations (xor reuses ori: both have identity 0). + .Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) + .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) + .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; }) + .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; }) + .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; }) + .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; }) + .Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; }) + .Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; }) + .Case([](arith::MulIOp op) { return AtomicRMWKind::muli; }) + .Default([](Operation *op) { return std::nullopt; }); + // No known neutral element: stay silent and let the caller decide how to + // report it (e.g. notifyMatchFailure in a rewrite pattern is recoverable). + if (!maybeKind) + return std::nullopt; + + // Builder only used as helper for attribute creation. + OpBuilder b(op->getContext()); + Type resultType = op->getResult(0).getType(); + + return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc()); +} + /// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -66,7 +66,7 @@ return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); Operation *reductionOp = combinerOps[0]; - std::optional identity = getNeutralElement(reductionOp); + std::optional identity = arith::getNeutralElement(reductionOp); if (!identity.has_value()) return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); @@ -274,7 +274,8 @@ SmallVector neutralElements; for (Operation *reductionOp : combinerOps) { - std::optional neutralElement = getNeutralElement(reductionOp); + std::optional neutralElement = + arith::getNeutralElement(reductionOp); if (!neutralElement.has_value()) return b.notifyMatchFailure(op, "cannot find neutral element."); neutralElements.push_back(*neutralElement); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -271,7 +271,7 @@ return op->emitOpError("Failed to anaysis the reduction operation."); Operation *reductionOp = combinerOps[0]; - std::optional identity = getNeutralElement(reductionOp); + std::optional identity = arith::getNeutralElement(reductionOp); if (!identity.has_value()) return op->emitOpError( "Failed to get an identity value for the reduction operation."); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -905,37 +905,5 @@ return reassociation; } -/// Return the identity numeric value
associated to the give op. -std::optional getNeutralElement(Operation *op) { - // Builder only used as helper for attribute creation. - OpBuilder b(op->getContext()); - Type resultType = op->getResult(0).getType(); - if (auto floatType = dyn_cast(resultType)) { - const llvm::fltSemantics &semantic = floatType.getFloatSemantics(); - if (isa(op)) - return b.getFloatAttr(resultType, llvm::APFloat::getZero(semantic)); - if (isa(op)) - return b.getFloatAttr(resultType, llvm::APFloat(semantic, 1)); - if (isa(op)) - return b.getFloatAttr(resultType, - llvm::APFloat::getInf(semantic, /*Negative=*/true)); - if (isa(op)) - return b.getFloatAttr( - resultType, llvm::APFloat::getInf(semantic, /*Negative=*/false)); - return std::nullopt; - } - if (isa(op)) - return b.getIntegerAttr(resultType, 0); - if (isa(op)) - return b.getIntegerAttr(resultType, -1); - if (isa(op)) - return b.getIntegerAttr(resultType, std::numeric_limits::min()); - if (isa(op)) - return b.getIntegerAttr(resultType, std::numeric_limits::max()); - if (isa(op)) - return b.getIntegerAttr(resultType, 1); - return std::nullopt; -} - } // namespace linalg } // namespace mlir