diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -620,6 +620,93 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// MaxFOp +//===----------------------------------------------------------------------===// + +def Arith_MaxFOp : Arith_FloatBinaryOp<"maxf"> { + let summary = "floating-point maximum operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `arith.maxf` ssa-use `,` ssa-use `:` type + ``` + + Returns the maximum of the two arguments, treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. + + Example: + + ```mlir + // Scalar floating-point maximum. + %a = arith.maxf %b, %c : f64 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// MaxSIOp +//===----------------------------------------------------------------------===// + +def Arith_MaxSIOp : Arith_IntBinaryOp<"maxsi"> { + let summary = "signed integer maximum operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// MaxUIOp +//===----------------------------------------------------------------------===// + +def Arith_MaxUIOp : Arith_IntBinaryOp<"maxui"> { + let summary = "unsigned integer maximum operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// MinFOp +//===----------------------------------------------------------------------===// + +def Arith_MinFOp : Arith_FloatBinaryOp<"minf"> { + let summary = "floating-point minimum operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `arith.minf` ssa-use `,` ssa-use `:` type + ``` + + Returns the minimum of the two arguments, 
treating -0.0 as less than +0.0. + If one of the arguments is NaN, then the result is also NaN. + + Example: + + ```mlir + // Scalar floating-point minimum. + %a = arith.minf %b, %c : f64 + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// MinSIOp +//===----------------------------------------------------------------------===// + +def Arith_MinSIOp : Arith_IntBinaryOp<"minsi"> { + let summary = "signed integer minimum operation"; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// MinUIOp +//===----------------------------------------------------------------------===// + +def Arith_MinUIOp : Arith_IntBinaryOp<"minui"> { + let summary = "unsigned integer minimum operation"; + let hasFolder = 1; +} + + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -665,156 +665,6 @@ let hasFolder = 1; } -//===----------------------------------------------------------------------===// -// MaxFOp -//===----------------------------------------------------------------------===// - -def MaxFOp : FloatBinaryOp<"maxf"> { - let summary = "floating-point maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type - ``` - - Returns the maximum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point maximum. 
- %a = maxf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MaxSIOp -//===----------------------------------------------------------------------===// - -def MaxSIOp : IntBinaryOp<"maxsi"> { - let summary = "signed integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as signed integers. - - Example: - - ```mlir - // Scalar signed integer maximum. - %a = maxsi %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// MaxUIOp -//===----------------------------------------------------------------------===// - -def MaxUIOp : IntBinaryOp<"maxui"> { - let summary = "unsigned integer maximum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type - ``` - - Returns the larger of %a and %b comparing the values as unsigned integers. - - Example: - - ```mlir - // Scalar unsigned integer maximum. - %a = maxui %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// MinFOp -//===----------------------------------------------------------------------===// - -def MinFOp : FloatBinaryOp<"minf"> { - let summary = "floating-point minimum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type - ``` - - Returns the minimum of the two arguments, treating -0.0 as less than +0.0. - If one of the arguments is NaN, then the result is also NaN. - - Example: - - ```mlir - // Scalar floating-point minimum. 
- %a = minf %b, %c : f64 - ``` - }]; -} - -//===----------------------------------------------------------------------===// -// MinSIOp -//===----------------------------------------------------------------------===// - -def MinSIOp : IntBinaryOp<"minsi"> { - let summary = "signed integer minimum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type - ``` - - Returns the smaller of %a and %b comparing the values as signed integers. - - Example: - - ```mlir - // Scalar signed integer minimum. - %a = minsi %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// MinUIOp -//===----------------------------------------------------------------------===// - -def MinUIOp : IntBinaryOp<"minui"> { - let summary = "unsigned integer minimum operation"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type - ``` - - Returns the smaller of %a and %b comparing the values as unsigned integers. - - Example: - - ```mlir - // Scalar unsigned integer minimum. 
- %a = minui %b, %c : i64 - ``` - }]; - let hasFolder = 1; -} - //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -58,12 +58,12 @@ .Case([](arith::MulFOp) { return AtomicRMWKind::mulf; }) .Case([](arith::AddIOp) { return AtomicRMWKind::addi; }) .Case([](arith::MulIOp) { return AtomicRMWKind::muli; }) - .Case([](MinFOp) { return AtomicRMWKind::minf; }) - .Case([](MaxFOp) { return AtomicRMWKind::maxf; }) - .Case([](MinSIOp) { return AtomicRMWKind::mins; }) - .Case([](MaxSIOp) { return AtomicRMWKind::maxs; }) - .Case([](MinUIOp) { return AtomicRMWKind::minu; }) - .Case([](MaxUIOp) { return AtomicRMWKind::maxu; }) + .Case([](arith::MinFOp) { return AtomicRMWKind::minf; }) + .Case([](arith::MaxFOp) { return AtomicRMWKind::maxf; }) + .Case([](arith::MinSIOp) { return AtomicRMWKind::mins; }) + .Case([](arith::MaxSIOp) { return AtomicRMWKind::maxs; }) + .Case([](arith::MinUIOp) { return AtomicRMWKind::minu; }) + .Case([](arith::MaxUIOp) { return AtomicRMWKind::maxu; }) .Default([](Operation *) -> Optional { // TODO: AtomicRMW supports other kinds of reductions this is // currently not detecting, add those when the need arises. 
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -230,12 +230,12 @@ patterns.add< // Unary and binary patterns - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, - spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, + spirv::UnaryAndBinaryOpPattern, ReturnOpPattern, SelectOpPattern, SplatPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, context); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -126,9 +126,9 @@ return gpu::MMAElementwiseOp::ADDF; if (isa(op)) return gpu::MMAElementwiseOp::MULF; - if (isa(op)) + if (isa(op)) return gpu::MMAElementwiseOp::MAXF; - if (isa(op)) + if (isa(op)) return gpu::MMAElementwiseOp::MINF; if (isa(op)) return gpu::MMAElementwiseOp::DIVF; diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -561,6 +561,106 @@ operands, [](APFloat a, APFloat b) { return a - b; }); } +//===----------------------------------------------------------------------===// +// MaxSIOp +//===----------------------------------------------------------------------===// + +OpFoldResult MaxSIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // maxsi(x,x) -> x + if (getLhs() == getRhs()) + 
return getRhs(); + + APInt intValue; + // maxsi(x,MAX_INT) -> MAX_INT + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && + intValue.isMaxSignedValue()) + return getRhs(); + + // maxsi(x, MIN_INT) -> x + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && + intValue.isMinSignedValue()) + return getLhs(); + + return constFoldBinaryOp( + operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); }); +} + +//===----------------------------------------------------------------------===// +// MaxUIOp +//===----------------------------------------------------------------------===// + +OpFoldResult MaxUIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // maxui(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + APInt intValue; + // maxui(x,MAX_INT) -> MAX_INT + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) + return getRhs(); + + // maxui(x, MIN_INT) -> x + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) + return getLhs(); + + return constFoldBinaryOp( + operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); }); +} + +//===----------------------------------------------------------------------===// +// MinSIOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinSIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // minsi(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + APInt intValue; + // minsi(x,MIN_INT) -> MIN_INT + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && + intValue.isMinSignedValue()) + return getRhs(); + + // minsi(x, MAX_INT) -> x + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && + intValue.isMaxSignedValue()) + return getLhs(); + + return constFoldBinaryOp( + operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); }); +} + 
+//===----------------------------------------------------------------------===// +// MinUIOp +//===----------------------------------------------------------------------===// + +OpFoldResult MinUIOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "binary operation takes two operands"); + + // minui(x,x) -> x + if (getLhs() == getRhs()) + return getRhs(); + + APInt intValue; + // minui(x,MIN_INT) -> MIN_INT + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) + return getRhs(); + + // minui(x, MAX_INT) -> x + if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) + return getLhs(); + + return constFoldBinaryOp( + operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); }); +} + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -8,6 +8,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/IR/TypeUtilities.h" using namespace mlir; @@ -147,6 +148,50 @@ } }; +template +struct MaxMinFOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const final { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Location loc = op.getLoc(); + Value cmp = rewriter.create(loc, pred, lhs, rhs); + Value select = rewriter.create(loc, cmp, lhs, rhs); + + auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); + Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, + lhs, rhs); + + Value nan = rewriter.create( + loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType); + if 
(VectorType vectorType = lhs.getType().dyn_cast()) + nan = rewriter.create(loc, vectorType, nan); + + rewriter.replaceOpWithNewOp(op, isNaN, nan, select); + return success(); + } +}; + +template +struct MaxMinIOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const final { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Location loc = op.getLoc(); + Value cmp = rewriter.create(loc, pred, lhs, rhs); + rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); + return success(); + } +}; + struct ArithmeticExpandOpsPass : public ArithmeticExpandOpsBase { void runOnFunction() override { @@ -156,9 +201,19 @@ arith::populateArithmeticExpandOpsPatterns(patterns); target.addLegalDialect(); - target.addIllegalOp(); - + // clang-format off + target.addIllegalOp< + arith::CeilDivSIOp, + arith::CeilDivUIOp, + arith::FloorDivSIOp, + arith::MaxFOp, + arith::MaxSIOp, + arith::MaxUIOp, + arith::MinFOp, + arith::MinSIOp, + arith::MinUIOp + >(); + // clang-format on if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); @@ -169,9 +224,19 @@ void mlir::arith::populateArithmeticExpandOpsPatterns( RewritePatternSet &patterns) { - patterns - .add( - patterns.getContext()); + // clang-format off + patterns.add< + CeilDivSIOpConverter, + CeilDivUIOpConverter, + FloorDivSIOpConverter, + MaxMinFOpConverter, + MaxMinFOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter + >(patterns.getContext()); + // clang-format on } std::unique_ptr mlir::arith::createArithmeticExpandOpsPass() { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -283,36 +283,36 @@ Value applyfn__max(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if 
(isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } Value applyfn__max_unsigned(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } Value applyfn__min(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } Value applyfn__min_unsigned(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); + return builder.create(lhs.getLoc(), lhs, rhs); llvm_unreachable("unsupported non numeric type"); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -129,10 +129,12 @@ .Case( [&](auto op) { return vector::CombiningKind::ADD; }) .Case([&](auto op) { return vector::CombiningKind::AND; }) - .Case([&](auto op) { return vector::CombiningKind::MAXSI; }) - .Case([&](auto op) { return vector::CombiningKind::MAXF; }) - .Case([&](auto op) { return vector::CombiningKind::MINSI; }) - 
.Case([&](auto op) { return vector::CombiningKind::MINF; }) + .Case( + [&](auto op) { return vector::CombiningKind::MAXSI; }) + .Case([&](auto op) { return vector::CombiningKind::MAXF; }) + .Case( + [&](auto op) { return vector::CombiningKind::MINSI; }) + .Case([&](auto op) { return vector::CombiningKind::MINF; }) .Case( [&](auto op) { return vector::CombiningKind::MUL; }) .Case([&](auto op) { return vector::CombiningKind::OR; }) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -251,17 +251,17 @@ case AtomicRMWKind::muli: return builder.create(loc, lhs, rhs); case AtomicRMWKind::maxf: - return builder.create(loc, lhs, rhs); + return builder.create(loc, lhs, rhs); case AtomicRMWKind::minf: - return builder.create(loc, lhs, rhs); + return builder.create(loc, lhs, rhs); case AtomicRMWKind::maxs: - return builder.create(loc, lhs, rhs); + return builder.create(loc, lhs, rhs); case AtomicRMWKind::mins: - return builder.create(loc, lhs, rhs); + return builder.create(loc, lhs, rhs); case AtomicRMWKind::maxu: - return builder.create(loc, lhs, rhs); + return builder.create(loc, lhs, rhs); case AtomicRMWKind::minu: - return builder.create(loc, lhs, rhs); + return builder.create(loc, lhs, rhs); // TODO: Add remaining reduction operations. 
default: (void)emitOptionalError(loc, "Reduction operation type not supported"); @@ -921,106 +921,6 @@ return value.isa(); } -//===----------------------------------------------------------------------===// -// MaxSIOp -//===----------------------------------------------------------------------===// - -OpFoldResult MaxSIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - // maxsi(x,x) -> x - if (getLhs() == getRhs()) - return getRhs(); - - APInt intValue; - // maxsi(x,MAX_INT) -> MAX_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMaxSignedValue()) - return getRhs(); - - // maxsi(x, MIN_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMinSignedValue()) - return getLhs(); - - return constFoldBinaryOp( - operands, [](APInt a, APInt b) { return llvm::APIntOps::smax(a, b); }); -} - -//===----------------------------------------------------------------------===// -// MaxUIOp -//===----------------------------------------------------------------------===// - -OpFoldResult MaxUIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - // maxui(x,x) -> x - if (getLhs() == getRhs()) - return getRhs(); - - APInt intValue; - // maxui(x,MAX_INT) -> MAX_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) - return getRhs(); - - // maxui(x, MIN_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) - return getLhs(); - - return constFoldBinaryOp( - operands, [](APInt a, APInt b) { return llvm::APIntOps::umax(a, b); }); -} - -//===----------------------------------------------------------------------===// -// MinSIOp -//===----------------------------------------------------------------------===// - -OpFoldResult MinSIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - // minsi(x,x) -> x - if 
(getLhs() == getRhs()) - return getRhs(); - - APInt intValue; - // minsi(x,MIN_INT) -> MIN_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMinSignedValue()) - return getRhs(); - - // minsi(x, MAX_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && - intValue.isMaxSignedValue()) - return getLhs(); - - return constFoldBinaryOp( - operands, [](APInt a, APInt b) { return llvm::APIntOps::smin(a, b); }); -} - -//===----------------------------------------------------------------------===// -// MinUIOp -//===----------------------------------------------------------------------===// - -OpFoldResult MinUIOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary operation takes two operands"); - - // minui(x,x) -> x - if (getLhs() == getRhs()) - return getRhs(); - - APInt intValue; - // minui(x,MIN_INT) -> MIN_INT - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue()) - return getRhs(); - - // minui(x, MAX_INT) -> x - if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue()) - return getLhs(); - - return constFoldBinaryOp( - operands, [](APInt a, APInt b) { return llvm::APIntOps::umin(a, b); }); -} - //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -119,64 +119,16 @@ } }; -template -struct MaxMinFOpConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const final { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - Location loc = op.getLoc(); - Value cmp = rewriter.create(loc, pred, lhs, rhs); - 
Value select = rewriter.create(loc, cmp, lhs, rhs); - - auto floatType = getElementTypeOrSelf(lhs.getType()).cast(); - Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, - lhs, rhs); - - Value nan = rewriter.create( - loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType); - if (VectorType vectorType = lhs.getType().dyn_cast()) - nan = rewriter.create(loc, vectorType, nan); - - rewriter.replaceOpWithNewOp(op, isNaN, nan, select); - return success(); - } -}; - -template -struct MaxMinIOpConverter : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const final { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - Location loc = op.getLoc(); - Value cmp = rewriter.create(loc, pred, lhs, rhs); - rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); - return success(); - } -}; - struct StdExpandOpsPass : public StdExpandOpsBase { void runOnFunction() override { MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); populateStdExpandOpsPatterns(patterns); - arith::populateArithmeticExpandOpsPatterns(patterns); - ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); target.addDynamicallyLegalOp([](AtomicRMWOp op) { return op.getKind() != AtomicRMWKind::maxf && op.getKind() != AtomicRMWKind::minf; @@ -184,16 +136,6 @@ target.addDynamicallyLegalOp([](memref::ReshapeOp op) { return !op.shape().getType().cast().hasStaticShape(); }); - // clang-format off - target.addIllegalOp< - MaxFOp, - MaxSIOp, - MaxUIOp, - MinFOp, - MinSIOp, - MinUIOp - >(); - // clang-format on if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); @@ -203,18 +145,8 @@ } // namespace void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) { - // clang-format off - patterns.add< - AtomicRMWOpConverter, - MaxMinFOpConverter, - MaxMinFOpConverter, - MaxMinIOpConverter, - 
MaxMinIOpConverter, - MaxMinIOpConverter, - MaxMinIOpConverter, - MemRefReshapeOpConverter - >(patterns.getContext()); - // clang-format on + patterns.add( + patterns.getContext()); } std::unique_ptr mlir::createStdExpandOpsPass() { diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp @@ -255,22 +255,22 @@ result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::MINUI: - result = rewriter.create(loc, operand, result); + result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::MINSI: - result = rewriter.create(loc, operand, result); + result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::MINF: - result = rewriter.create(loc, operand, result); + result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::MAXUI: - result = rewriter.create(loc, operand, result); + result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::MAXSI: - result = rewriter.create(loc, operand, result); + result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::MAXF: - result = rewriter.create(loc, operand, result); + result = rewriter.create(loc, operand, result); break; case vector::CombiningKind::AND: result = rewriter.create(loc, operand, result); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -862,16 +862,16 @@ combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::MINUI: - combinedResult = rewriter.create(loc, mul, acc); + combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::MINSI: - combinedResult = 
rewriter.create(loc, mul, acc); + combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::MAXUI: - combinedResult = rewriter.create(loc, mul, acc); + combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::MAXSI: - combinedResult = rewriter.create(loc, mul, acc); + combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::AND: combinedResult = rewriter.create(loc, mul, acc); @@ -910,10 +910,10 @@ combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::MINF: - combinedResult = rewriter.create(loc, mul, acc); + combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::MAXF: - combinedResult = rewriter.create(loc, mul, acc); + combinedResult = rewriter.create(loc, mul, acc); break; case CombiningKind::ADD: // Already handled this special case above. case CombiningKind::AND: // Only valid for integer types. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -334,30 +334,30 @@ def _eval_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs, rhs).result + return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxSIOp(lhs, rhs).result + return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs, rhs).result + return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxUIOp(lhs, rhs).result + return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if 
_is_floating_point_type(lhs.type): - return std.MinFOp(lhs, rhs).result + return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinSIOp(lhs, rhs).result + return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs, rhs).result + return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinUIOp(lhs, rhs).result + return arith.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -25,13 +25,13 @@ // CHECK: spv.UMod %{{.*}}, %{{.*}}: i32 %5 = arith.remui %lhs, %rhs: i32 // CHECK: spv.GLSL.SMax %{{.*}}, %{{.*}}: i32 - %6 = maxsi %lhs, %rhs : i32 + %6 = arith.maxsi %lhs, %rhs : i32 // CHECK: spv.GLSL.UMax %{{.*}}, %{{.*}}: i32 - %7 = maxui %lhs, %rhs : i32 + %7 = arith.maxui %lhs, %rhs : i32 // CHECK: spv.GLSL.SMin %{{.*}}, %{{.*}}: i32 - %8 = minsi %lhs, %rhs : i32 + %8 = arith.minsi %lhs, %rhs : i32 // CHECK: spv.GLSL.UMin %{{.*}}, %{{.*}}: i32 - %9 = minui %lhs, %rhs : i32 + %9 = arith.minui %lhs, %rhs : i32 return } @@ -76,9 +76,9 @@ // CHECK: spv.FRem %{{.*}}, %{{.*}}: f32 %4 = arith.remf %lhs, %rhs: f32 // CHECK: spv.GLSL.FMax %{{.*}}, %{{.*}}: f32 - %5 = maxf %lhs, %rhs: f32 + %5 = arith.maxf %lhs, %rhs: f32 // CHECK: spv.GLSL.FMin %{{.*}}, %{{.*}}: f32 - %6 = minf %lhs, %rhs: f32 + %6 = arith.minf %lhs, %rhs: f32 return } diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir --- 
a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir @@ -34,7 +34,7 @@ affine.for %i = 0 to 256 { %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { %ld = affine.load %in[%i, %j] : memref<256x512xf32> - %min = minf %red_iter, %ld : f32 + %min = arith.minf %red_iter, %ld : f32 affine.yield %min : f32 } affine.store %final_red, %out[%i] : memref<256xf32> @@ -47,7 +47,7 @@ // CHECK: %[[vmax:.*]] = arith.constant dense<0x7F800000> : vector<128xf32> // CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xf32>) { // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> -// CHECK: %[[min:.*]] = minf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[min:.*]] = arith.minf %[[red_iter]], %[[ld]] : vector<128xf32> // CHECK: affine.yield %[[min]] : vector<128xf32> // CHECK: } // CHECK: %[[final_min:.*]] = vector.reduction "minf", %[[vred:.*]] : vector<128xf32> into f32 @@ -61,7 +61,7 @@ affine.for %i = 0 to 256 { %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) { %ld = affine.load %in[%i, %j] : memref<256x512xf32> - %max = maxf %red_iter, %ld : f32 + %max = arith.maxf %red_iter, %ld : f32 affine.yield %max : f32 } affine.store %final_red, %out[%i] : memref<256xf32> @@ -74,7 +74,7 @@ // CHECK: %[[vmin:.*]] = arith.constant dense<0xFF800000> : vector<128xf32> // CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xf32>) { // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xf32>, vector<128xf32> -// CHECK: %[[max:.*]] = maxf %[[red_iter]], %[[ld]] : vector<128xf32> +// CHECK: %[[max:.*]] = arith.maxf %[[red_iter]], %[[ld]] : vector<128xf32> // CHECK: affine.yield %[[max]] : vector<128xf32> // CHECK: } // CHECK: %[[final_max:.*]] = vector.reduction "maxf", 
%[[vred:.*]] : vector<128xf32> into f32 @@ -88,7 +88,7 @@ affine.for %i = 0 to 256 { %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) { %ld = affine.load %in[%i, %j] : memref<256x512xi32> - %min = minsi %red_iter, %ld : i32 + %min = arith.minsi %red_iter, %ld : i32 affine.yield %min : i32 } affine.store %final_red, %out[%i] : memref<256xi32> @@ -101,7 +101,7 @@ // CHECK: %[[vmax:.*]] = arith.constant dense<2147483647> : vector<128xi32> // CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xi32>) { // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32> -// CHECK: %[[min:.*]] = minsi %[[red_iter]], %[[ld]] : vector<128xi32> +// CHECK: %[[min:.*]] = arith.minsi %[[red_iter]], %[[ld]] : vector<128xi32> // CHECK: affine.yield %[[min]] : vector<128xi32> // CHECK: } // CHECK: %[[final_min:.*]] = vector.reduction "minsi", %[[vred:.*]] : vector<128xi32> into i32 @@ -115,7 +115,7 @@ affine.for %i = 0 to 256 { %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) { %ld = affine.load %in[%i, %j] : memref<256x512xi32> - %max = maxsi %red_iter, %ld : i32 + %max = arith.maxsi %red_iter, %ld : i32 affine.yield %max : i32 } affine.store %final_red, %out[%i] : memref<256xi32> @@ -128,7 +128,7 @@ // CHECK: %[[vmin:.*]] = arith.constant dense<-2147483648> : vector<128xi32> // CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xi32>) { // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32> -// CHECK: %[[max:.*]] = maxsi %[[red_iter]], %[[ld]] : vector<128xi32> +// CHECK: %[[max:.*]] = arith.maxsi %[[red_iter]], %[[ld]] : vector<128xi32> // CHECK: affine.yield %[[max]] : vector<128xi32> // CHECK: } // CHECK: %[[final_max:.*]] = vector.reduction "maxsi", %[[vred:.*]] : vector<128xi32> into i32 @@ -142,7 +142,7 @@ affine.for %i = 0 to 256 { 
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) { %ld = affine.load %in[%i, %j] : memref<256x512xi32> - %min = minui %red_iter, %ld : i32 + %min = arith.minui %red_iter, %ld : i32 affine.yield %min : i32 } affine.store %final_red, %out[%i] : memref<256xi32> @@ -155,7 +155,7 @@ // CHECK: %[[vmax:.*]] = arith.constant dense<-1> : vector<128xi32> // CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmax]]) -> (vector<128xi32>) { // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32> -// CHECK: %[[min:.*]] = minui %[[red_iter]], %[[ld]] : vector<128xi32> +// CHECK: %[[min:.*]] = arith.minui %[[red_iter]], %[[ld]] : vector<128xi32> // CHECK: affine.yield %[[min]] : vector<128xi32> // CHECK: } // CHECK: %[[final_min:.*]] = vector.reduction "minui", %[[vred:.*]] : vector<128xi32> into i32 @@ -169,7 +169,7 @@ affine.for %i = 0 to 256 { %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) { %ld = affine.load %in[%i, %j] : memref<256x512xi32> - %max = maxui %red_iter, %ld : i32 + %max = arith.maxui %red_iter, %ld : i32 affine.yield %max : i32 } affine.store %final_red, %out[%i] : memref<256xi32> @@ -182,7 +182,7 @@ // CHECK: %[[vmin:.*]] = arith.constant dense<0> : vector<128xi32> // CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vmin]]) -> (vector<128xi32>) { // CHECK: %[[ld:.*]] = vector.transfer_read %{{.*}} : memref<256x512xi32>, vector<128xi32> -// CHECK: %[[max:.*]] = maxui %[[red_iter]], %[[ld]] : vector<128xi32> +// CHECK: %[[max:.*]] = arith.maxui %[[red_iter]], %[[ld]] : vector<128xi32> // CHECK: affine.yield %[[max]] : vector<128xi32> // CHECK: } // CHECK: %[[final_max:.*]] = vector.reduction "maxui", %[[vred:.*]] : vector<128xi32> into i32 diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir --- 
a/mlir/test/Dialect/Arithmetic/canonicalize.mlir +++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir @@ -386,3 +386,75 @@ %res = arith.bitcast %bf : bf16 to i16 return %res : i16 } + +// ----- + +// CHECK-LABEL: test_maxsi +// CHECK: %[[C0:.+]] = arith.constant 42 +// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127 +// CHECK: %[[X:.+]] = arith.maxsi %arg0, %[[C0]] +// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] +func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) { + %maxIntCst = arith.constant 127 : i8 + %minIntCst = arith.constant -128 : i8 + %c0 = arith.constant 42 : i8 + %0 = arith.maxsi %arg0, %arg0 : i8 + %1 = arith.maxsi %arg0, %maxIntCst : i8 + %2 = arith.maxsi %arg0, %minIntCst : i8 + %3 = arith.maxsi %arg0, %c0 : i8 + return %0, %1, %2, %3: i8, i8, i8, i8 +} + +// ----- + +// CHECK-LABEL: test_maxui +// CHECK: %[[C0:.+]] = arith.constant 42 +// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1 +// CHECK: %[[X:.+]] = arith.maxui %arg0, %[[C0]] +// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] +func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) { + %maxIntCst = arith.constant 255 : i8 + %minIntCst = arith.constant 0 : i8 + %c0 = arith.constant 42 : i8 + %0 = arith.maxui %arg0, %arg0 : i8 + %1 = arith.maxui %arg0, %maxIntCst : i8 + %2 = arith.maxui %arg0, %minIntCst : i8 + %3 = arith.maxui %arg0, %c0 : i8 + return %0, %1, %2, %3: i8, i8, i8, i8 +} + +// ----- + +// CHECK-LABEL: test_minsi +// CHECK: %[[C0:.+]] = arith.constant 42 +// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128 +// CHECK: %[[X:.+]] = arith.minsi %arg0, %[[C0]] +// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] +func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) { + %maxIntCst = arith.constant 127 : i8 + %minIntCst = arith.constant -128 : i8 + %c0 = arith.constant 42 : i8 + %0 = arith.minsi %arg0, %arg0 : i8 + %1 = arith.minsi %arg0, %maxIntCst : i8 + %2 = arith.minsi %arg0, %minIntCst : i8 + %3 = arith.minsi %arg0, %c0 : i8 + return %0, %1, %2, %3: i8, i8, i8, i8 +} + +// 
----- + +// CHECK-LABEL: test_minui +// CHECK: %[[C0:.+]] = arith.constant 42 +// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0 +// CHECK: %[[X:.+]] = arith.minui %arg0, %[[C0]] +// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] +func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) { + %maxIntCst = arith.constant 255 : i8 + %minIntCst = arith.constant 0 : i8 + %c0 = arith.constant 42 : i8 + %0 = arith.minui %arg0, %arg0 : i8 + %1 = arith.minui %arg0, %maxIntCst : i8 + %2 = arith.minui %arg0, %minIntCst : i8 + %3 = arith.minui %arg0, %c0 : i8 + return %0, %1, %2, %3: i8, i8, i8, i8 +} diff --git a/mlir/test/Dialect/Arithmetic/expand-ops.mlir b/mlir/test/Dialect/Arithmetic/expand-ops.mlir --- a/mlir/test/Dialect/Arithmetic/expand-ops.mlir +++ b/mlir/test/Dialect/Arithmetic/expand-ops.mlir @@ -145,3 +145,92 @@ // CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index // CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : index } + +// ----- + +// CHECK-LABEL: func @maxf +func @maxf(%a: f32, %b: f32) -> f32 { + %result = arith.maxf %a, %b : f32 + return %result : f32 +} +// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 +// CHECK-NEXT: return %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @maxf_vector +func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> { + %result = arith.maxf %a, %b : vector<4xf16> + return %result : vector<4xf16> +} +// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] +// CHECK-NEXT: 
%[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16> +// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16 +// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16> +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]] +// CHECK-NEXT: return %[[RESULT]] : vector<4xf16> + +// ----- + +// CHECK-LABEL: func @minf +func @minf(%a: f32, %b: f32) -> f32 { + %result = arith.minf %a, %b : f32 + return %result : f32 +} +// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 +// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 +// CHECK-NEXT: return %[[RESULT]] : f32 + + +// ----- + +// CHECK-LABEL: func @maxsi +func @maxsi(%a: i32, %b: i32) -> i32 { + %result = arith.maxsi %a, %b : i32 + return %result : i32 +} +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32 + +// ----- + +// CHECK-LABEL: func @minsi +func @minsi(%a: i32, %b: i32) -> i32 { + %result = arith.minsi %a, %b : i32 + return %result : i32 +} +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32 + + +// ----- + +// CHECK-LABEL: func @maxui +func @maxui(%a: i32, %b: i32) -> i32 { + %result = arith.maxui %a, %b : i32 + return %result : i32 +} +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32 + + +// ----- + +// CHECK-LABEL: func @minui +func @minui(%a: i32, %b: i32) -> i32 { + %result = arith.minui %a, %b : i32 + return %result : i32 +} +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] 
: i32 diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir --- a/mlir/test/Dialect/Arithmetic/ops.mlir +++ b/mlir/test/Dialect/Arithmetic/ops.mlir @@ -704,3 +704,25 @@ return } + +// CHECK-LABEL: func @maximum +func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>, + %f1: f32, %f2: f32, + %i1: i32, %i2: i32) { + %max_vector = arith.maxf %v1, %v2 : vector<4xf32> + %max_float = arith.maxf %f1, %f2 : f32 + %max_signed = arith.maxsi %i1, %i2 : i32 + %max_unsigned = arith.maxui %i1, %i2 : i32 + return +} + +// CHECK-LABEL: func @minimum +func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, + %f1: f32, %f2: f32, + %i1: i32, %i2: i32) { + %min_vector = arith.minf %v1, %v2 : vector<4xf32> + %min_float = arith.minf %f1, %f2 : f32 + %min_signed = arith.minsi %i1, %i2 : i32 + %min_unsigned = arith.minui %i1, %i2 : i32 + return +} diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -940,7 +940,7 @@ iterator_types = ["parallel", "reduction"] } ins(%5 : tensor) outs(%7 : tensor) { ^bb0(%arg2: f32, %arg3: f32): // no predecessors - %9 = maxf %arg2, %arg3 : f32 + %9 = arith.maxf %arg2, %arg3 : f32 linalg.yield %9 : f32 } -> tensor return %8 : tensor diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -111,7 +111,7 @@ // CHECK-LABEL: @generalize_pooling_nhwc_max_f32 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32) -// CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT_ARG]], %[[IN_ARG]] : f32 +// CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT_ARG]], %[[IN_ARG]] : f32 // CHECK-NEXT: linalg.yield %[[MAX]] : f32 // 
CHECK-NEXT: -> tensor<1x2x4x1xf32> @@ -125,7 +125,7 @@ // CHECK-LABEL: @generalize_pooling_nhwc_max_i32 // Verify signed integer maximum. -// CHECK: = maxsi +// CHECK: = arith.maxsi // ----- @@ -137,7 +137,7 @@ // CHECK-LABEL: @generalize_pooling_nhwc_max_unsigned_i32 // Verify unsigned integer maximum. -// CHECK: = maxui +// CHECK: = arith.maxui // ----- @@ -149,7 +149,7 @@ // CHECK-LABEL: @generalize_pooling_nhwc_min_f32 // CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32) -// CHECK-NEXT: %[[MIN:.+]] = minf %[[OUT_ARG]], %[[IN_ARG]] : f32 +// CHECK-NEXT: %[[MIN:.+]] = arith.minf %[[OUT_ARG]], %[[IN_ARG]] : f32 // CHECK-NEXT: linalg.yield %[[MIN]] : f32 // CHECK-NEXT: -> tensor<1x2x4x1xf32> @@ -163,7 +163,7 @@ // CHECK-LABEL: @generalize_pooling_nhwc_min_i32 // Verify signed integer minimum. -// CHECK: = minsi +// CHECK: = arith.minsi // ----- @@ -175,7 +175,7 @@ // CHECK-LABEL: @generalize_pooling_nhwc_min_unsigned_i32 // Verify unsigned integer minimum. 
-// CHECK: = minui +// CHECK: = arith.minui // ----- diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -819,7 +819,7 @@ iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) { ^bb0(%in0: f32, %out0: f32): // no predecessors - %max = maxf %in0, %out0 : f32 + %max = arith.maxf %in0, %out0 : f32 linalg.yield %max : f32 } -> tensor<4xf32> return %red : tensor<4xf32> @@ -834,7 +834,7 @@ // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: %[[R:.+]] = vector.multi_reduction #vector.kind, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> - // CHECK: minf %[[R]], %[[CMAXF]] : vector<4xf32> + // CHECK: arith.minf %[[R]], %[[CMAXF]] : vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %maxf32 = arith.constant 3.40282e+38 : f32 %init = linalg.init_tensor [4] : tensor<4xf32> @@ -844,7 +844,7 @@ iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<4x4xf32>) outs(%fill : tensor<4xf32>) { ^bb0(%in0: f32, %out0: f32): // no predecessors - %min = minf %out0, %in0 : f32 + %min = arith.minf %out0, %in0 : f32 linalg.yield %min : f32 } -> tensor<4xf32> return %red : tensor<4xf32> diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -60,68 +60,3 @@ %res = select %arg0, %false, %true : i1 return %res : i1 } - -// CHECK-LABEL: test_maxsi -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MAX_INT_CST:.+]] = arith.constant 127 -// CHECK: %[[X:.+]] = maxsi %arg0, %[[C0]] -// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] -func @test_maxsi(%arg0 : i8) -> (i8, i8, i8, i8) { - %maxIntCst = 
arith.constant 127 : i8 - %minIntCst = arith.constant -128 : i8 - %c0 = arith.constant 42 : i8 - %0 = maxsi %arg0, %arg0 : i8 - %1 = maxsi %arg0, %maxIntCst : i8 - %2 = maxsi %arg0, %minIntCst : i8 - %3 = maxsi %arg0, %c0 : i8 - return %0, %1, %2, %3: i8, i8, i8, i8 -} - -// CHECK-LABEL: test_maxui -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MAX_INT_CST:.+]] = arith.constant -1 -// CHECK: %[[X:.+]] = maxui %arg0, %[[C0]] -// CHECK: return %arg0, %[[MAX_INT_CST]], %arg0, %[[X]] -func @test_maxui(%arg0 : i8) -> (i8, i8, i8, i8) { - %maxIntCst = arith.constant 255 : i8 - %minIntCst = arith.constant 0 : i8 - %c0 = arith.constant 42 : i8 - %0 = maxui %arg0, %arg0 : i8 - %1 = maxui %arg0, %maxIntCst : i8 - %2 = maxui %arg0, %minIntCst : i8 - %3 = maxui %arg0, %c0 : i8 - return %0, %1, %2, %3: i8, i8, i8, i8 -} - - -// CHECK-LABEL: test_minsi -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MIN_INT_CST:.+]] = arith.constant -128 -// CHECK: %[[X:.+]] = minsi %arg0, %[[C0]] -// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] -func @test_minsi(%arg0 : i8) -> (i8, i8, i8, i8) { - %maxIntCst = arith.constant 127 : i8 - %minIntCst = arith.constant -128 : i8 - %c0 = arith.constant 42 : i8 - %0 = minsi %arg0, %arg0 : i8 - %1 = minsi %arg0, %maxIntCst : i8 - %2 = minsi %arg0, %minIntCst : i8 - %3 = minsi %arg0, %c0 : i8 - return %0, %1, %2, %3: i8, i8, i8, i8 -} - -// CHECK-LABEL: test_minui -// CHECK: %[[C0:.+]] = arith.constant 42 -// CHECK: %[[MIN_INT_CST:.+]] = arith.constant 0 -// CHECK: %[[X:.+]] = minui %arg0, %[[C0]] -// CHECK: return %arg0, %arg0, %[[MIN_INT_CST]], %[[X]] -func @test_minui(%arg0 : i8) -> (i8, i8, i8, i8) { - %maxIntCst = arith.constant 255 : i8 - %minIntCst = arith.constant 0 : i8 - %c0 = arith.constant 42 : i8 - %0 = minui %arg0, %arg0 : i8 - %1 = minui %arg0, %maxIntCst : i8 - %2 = minui %arg0, %minIntCst : i8 - %3 = minui %arg0, %c0 : i8 - return %0, %1, %2, %3: i8, i8, i8, i8 -} diff --git 
a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/Standard/expand-ops.mlir @@ -52,92 +52,3 @@ // CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8], // CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]] // CHECK-SAME: : memref<*xf32> to memref - -// ----- - -// CHECK-LABEL: func @maxf -func @maxf(%a: f32, %b: f32) -> f32 { - %result = maxf(%a, %b): (f32, f32) -> f32 - return %result : f32 -} -// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 -// CHECK-NEXT: return %[[RESULT]] : f32 - -// ----- - -// CHECK-LABEL: func @maxf_vector -func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> { - %result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16> - return %result : vector<4xf16> -} -// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16> -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7E00 : f16 -// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16> -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]] -// CHECK-NEXT: return %[[RESULT]] : vector<4xf16> - -// ----- - -// CHECK-LABEL: func @minf -func @minf(%a: f32, %b: f32) -> f32 { - %result = minf(%a, %b): (f32, f32) -> f32 - return %result : f32 -} -// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32) -// 
CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[RHS]] : f32 -// CHECK-NEXT: %[[NAN:.*]] = arith.constant 0x7FC00000 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32 -// CHECK-NEXT: return %[[RESULT]] : f32 - - -// ----- - -// CHECK-LABEL: func @maxsi -func @maxsi(%a: i32, %b: i32) -> i32 { - %result = maxsi(%a, %b): (i32, i32) -> i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32 - -// ----- - -// CHECK-LABEL: func @minsi -func @minsi(%a: i32, %b: i32) -> i32 { - %result = minsi(%a, %b): (i32, i32) -> i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32 - - -// ----- - -// CHECK-LABEL: func @maxui -func @maxui(%a: i32, %b: i32) -> i32 { - %result = maxui(%a, %b): (i32, i32) -> i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32 - - -// ----- - -// CHECK-LABEL: func @minui -func @minui(%a: i32, %b: i32) -> i32 { - %result = minui(%a, %b): (i32, i32) -> i32 - return %result : i32 -} -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) -// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -62,27 +62,3 @@ %result = constant [0.1 : f64, -1.0 : f64] : complex return %result : complex } - -// CHECK-LABEL: func @maximum -func @maximum(%v1: vector<4xf32>, %v2: vector<4xf32>, - %f1: f32, %f2: f32, - %i1: i32, %i2: i32) { - %max_vector = maxf(%v1, %v2) - : (vector<4xf32>, vector<4xf32>) -> 
vector<4xf32> - %max_float = maxf(%f1, %f2) : (f32, f32) -> f32 - %max_signed = maxsi(%i1, %i2) : (i32, i32) -> i32 - %max_unsigned = maxui(%i1, %i2) : (i32, i32) -> i32 - return -} - -// CHECK-LABEL: func @minimum -func @minimum(%v1: vector<4xf32>, %v2: vector<4xf32>, - %f1: f32, %f2: f32, - %i1: i32, %i2: i32) { - %min_vector = minf(%v1, %v2) - : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> - %min_float = minf(%f1, %f2) : (f32, f32) -> f32 - %min_signed = minsi(%i1, %i2) : (i32, i32) -> i32 - %min_unsigned = minui(%i1, %i2) : (i32, i32) -> i32 - return -} diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -27,11 +27,11 @@ // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[RV01:.+]] = minf %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = arith.minf %[[V1]], %[[V0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> -// CHECK: %[[RV012:.+]] = minf %[[V2]], %[[RV01]] : vector<2xf32> +// CHECK: %[[RV012:.+]] = arith.minf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> -// CHECK: %[[RESULT_VEC:.+]] = minf %[[V3]], %[[RV012]] : vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = arith.minf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { @@ -44,11 +44,11 @@ // CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> // CHECK: %[[V0:.+]] = 
vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> // CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> -// CHECK: %[[RV01:.+]] = maxf %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = arith.maxf %[[V1]], %[[V0]] : vector<2xf32> // CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> -// CHECK: %[[RV012:.+]] = maxf %[[V2]], %[[RV01]] : vector<2xf32> +// CHECK: %[[RV012:.+]] = arith.maxf %[[V2]], %[[RV01]] : vector<2xf32> // CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> -// CHECK: %[[RESULT_VEC:.+]] = maxf %[[V3]], %[[RV012]] : vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = arith.maxf %[[V3]], %[[RV012]] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] : vector<2xf32> func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> { diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-linalg-async-parallel-for.mlir @@ -6,6 +6,7 @@ // RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ +// RUN: -arith-expand \ // RUN: -std-expand \ // RUN: -convert-vector-to-llvm \ // RUN: -convert-memref-to-llvm \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/microbench-scf-async-parallel-for.mlir @@ -6,6 +6,7 @@ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \ +// RUN: -arith-expand \ // RUN: -std-expand \ // RUN: -convert-vector-to-llvm \ // RUN: -convert-memref-to-llvm \ @@ -26,6 
+27,7 @@ // RUN: -convert-async-to-llvm \ // RUN: -convert-linalg-to-loops \ // RUN: -convert-scf-to-std \ +// RUN: -arith-expand \ // RUN: -std-expand \ // RUN: -convert-vector-to-llvm \ // RUN: -convert-memref-to-llvm \ diff --git a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir --- a/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir +++ b/mlir/test/Integration/Dialect/Async/CPU/test-async-parallel-for-1d.mlir @@ -5,6 +5,7 @@ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-memref-to-llvm \ +// RUN: -arith-expand \ // RUN: -std-expand \ // RUN: -convert-std-to-llvm \ // RUN: -reconcile-unrealized-casts \ @@ -20,6 +21,7 @@ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-memref-to-llvm \ +// RUN: -arith-expand \ // RUN: -std-expand \ // RUN: -convert-std-to-llvm \ // RUN: -reconcile-unrealized-casts \ @@ -31,13 +33,14 @@ // RUN: mlir-opt %s -async-parallel-for="async-dispatch=false \ // RUN: num-workers=20 \ -// RUN: min-task-size=1" \ +// RUN: min-task-size=1" \ // RUN: -async-to-async-runtime \ // RUN: -async-runtime-ref-counting \ // RUN: -async-runtime-ref-counting-opt \ // RUN: -convert-async-to-llvm \ // RUN: -convert-scf-to-std \ // RUN: -convert-memref-to-llvm \ +// RUN: -arith-expand \ // RUN: -std-expand \ // RUN: -convert-std-to-llvm \ // RUN: -reconcile-unrealized-casts \ diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir --- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir +++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir @@ -1,4 +1,7 @@ -// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm 
-reconcile-unrealized-casts | \ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std \ +// RUN: -std-expand -arith-expand -convert-vector-to-llvm \ +// RUN: -convert-memref-to-llvm -convert-std-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_structured_generic.py @@ -295,7 +295,7 @@ # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: i32) # CHECK-NEXT: %[[IN_CAST:.+]] = arith.fptosi %[[IN:.+]] : f32 to i32 - # CHECK-NEXT: %[[MAX:.+]] = maxsi %[[OUT]], %[[IN_CAST:.+]] : i32 + # CHECK-NEXT: %[[MAX:.+]] = arith.maxsi %[[OUT]], %[[IN_CAST:.+]] : i32 # CHECK-NEXT: linalg.yield %[[MAX]] : i32 # CHECK-NEXT: -> tensor<2x4xi32> @builtin.FuncOp.from_py_func( @@ -307,7 +307,7 @@ # CHECK-LABEL: @test_f32i32_max_unsigned_pooling # CHECK: = arith.fptoui - # CHECK: = maxui + # CHECK: = arith.maxui @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), RankedTensorType.get((2, 4), i32)) @@ -320,7 +320,7 @@ # CHECK-SAME: indexing_maps = [#[[$CONV_MAP_I]], #[[$POOL_MAP_K]], #[[$CONV_MAP_O]]] # CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "parallel"] # CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[SHAPE:.+]]: f32, %[[OUT:.+]]: f32) - # CHECK-NEXT: %[[MAX:.+]] = maxf %[[OUT]], %[[IN:.+]] : f32 + # CHECK-NEXT: %[[MAX:.+]] = arith.maxf %[[OUT]], %[[IN:.+]] : f32 # CHECK-NEXT: linalg.yield %[[MAX]] : f32 # CHECK-NEXT: -> 
tensor<2x4xf32> @builtin.FuncOp.from_py_func( @@ -332,7 +332,7 @@ # CHECK-LABEL: @test_f32i32_min_pooling # CHECK: = arith.fptosi - # CHECK: = minsi + # CHECK: = arith.minsi @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), RankedTensorType.get((2, 4), i32)) @@ -342,7 +342,7 @@ # CHECK-LABEL: @test_f32i32_min_unsigned_pooling # CHECK: = arith.fptoui - # CHECK: = minui + # CHECK: = arith.minui @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), RankedTensorType.get((2, 4), i32)) @@ -351,7 +351,7 @@ input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) # CHECK-LABEL: @test_f32f32_min_pooling - # CHECK: = minf + # CHECK: = arith.minf @builtin.FuncOp.from_py_func( RankedTensorType.get((4, 16), f32), RankedTensorType.get((2, 2), f32), RankedTensorType.get((2, 4), f32)) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -128,7 +128,7 @@ boilerplate) pm = PassManager.parse( "builtin.func(convert-linalg-to-loops, lower-affine, " + - "convert-scf-to-std, std-expand), convert-vector-to-llvm," + + "convert-scf-to-std, arith-expand, std-expand), convert-vector-to-llvm," + "convert-memref-to-llvm, convert-std-to-llvm," + "reconcile-unrealized-casts") pm.run(mod)