diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -344,33 +344,6 @@ impossible to implement switching logic based on the comparison kind and made attribute validity checks (one out of ten possible kinds) more complex. -### 'select' operation to implement min/max - -Although `min` and `max` operations are likely to occur as a result of -transforming affine loops in ML functions, we did not make them first-class -operations. Instead, we provide the `select` operation that can be combined with -`cmpi` to implement the minimum and maximum computation. Although they now -require two operations, they are likely to be emitted automatically during the -transformation inside MLIR. On the other hand, there are multiple benefits of -introducing `select`: standalone min/max would concern themselves with the -signedness of the comparison, already taken into account by `cmpi`; `select` can -support floats transparently if used after a float-comparison operation; the -lower-level targets provide `select`-like instructions making the translation -trivial. - -This operation could have been implemented with additional control flow: `%r = -select %cond, %t, %f` is equivalent to - -```mlir -^bb0: - cond_br %cond, ^bb1(%t), ^bb1(%f) -^bb1(%r): -``` - -However, this control flow granularity is not available in the ML functions -where min/max, and thus `select`, are likely to appear. In addition, simpler -control flow may be beneficial for optimization in general. 
-
 ### Regions
 
 #### Attributes of type 'Block'
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1247,6 +1247,152 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// MaxFOp
+//===----------------------------------------------------------------------===//
+
+def MaxFOp : FloatBinaryOp<"maxf"> {
+  let summary = "floating-point maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
+    If one of the arguments is NaN, then the result is also NaN.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point maximum.
+    %a = maxf %b, %c : f64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+def MaxSIOp : IntBinaryOp<"maxsi"> {
+  let summary = "signed integer maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the larger of the two arguments, compared as signed integers.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer maximum.
+    %a = maxsi %b, %c : i64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+def MaxUIOp : IntBinaryOp<"maxui"> {
+  let summary = "unsigned integer maximum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the larger of the two arguments, compared as unsigned integers.
+
+    Example:
+
+    ```mlir
+    // Scalar unsigned integer maximum.
+    %a = maxui %b, %c : i64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinFOp
+//===----------------------------------------------------------------------===//
+
+def MinFOp : FloatBinaryOp<"minf"> {
+  let summary = "floating-point minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
+    If one of the arguments is NaN, then the result is also NaN.
+
+    Example:
+
+    ```mlir
+    // Scalar floating-point minimum.
+    %a = minf %b, %c : f64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+def MinSIOp : IntBinaryOp<"minsi"> {
+  let summary = "signed integer minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the smaller of the two arguments, compared as signed integers.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer minimum.
+    %a = minsi %b, %c : i64
+    ```
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+def MinUIOp : IntBinaryOp<"minui"> {
+  let summary = "unsigned integer minimum operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type
+    ```
+
+    Returns the smaller of the two arguments, compared as unsigned integers.
+
+    Example:
+
+    ```mlir
+    // Scalar unsigned integer minimum.
+    %a = minui %b, %c : i64
+    ```
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -215,6 +215,63 @@
   }
 };
 
+/// Returns the element type of `type` when it is a shaped type (such as a
+/// vector), and `type` itself otherwise.
+static Type getElementTypeOrSelf(Type type) {
+  if (auto st = type.dyn_cast<ShapedType>())
+    return st.getElementType();
+  return type;
+}
+
+/// Expands a floating-point min/max op into `cmpf` + `select` using `pred`,
+/// then overrides the result with NaN when the operands compare unordered
+/// (`cmpf uno`), so that a NaN argument always produces a NaN result.
+template <typename OpTy, CmpFPredicate pred>
+struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+
+    Location loc = op.getLoc();
+    Value cmp = rewriter.create<CmpFOp>(loc, pred, lhs, rhs);
+    Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+
+    auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
+    Value isNaN = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, lhs, rhs);
+
+    // Build a quiet NaN of the element type, splatted for vector operands.
+    Value nan = rewriter.create<ConstantFloatOp>(
+        loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
+    if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
+      nan = rewriter.create<SplatOp>(loc, vectorType, nan);
+
+    rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
+    return success();
+  }
+};
+
+/// Expands an integer min/max op into `cmpi` + `select`; signedness is
+/// determined solely by the comparison predicate `pred`.
+template <typename OpTy, CmpIPredicate pred>
+struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const final {
+    Value lhs = op.lhs();
+    Value rhs = op.rhs();
+
+    Location loc = op.getLoc();
+    Value cmp = rewriter.create<CmpIOp>(loc, pred, lhs, rhs);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
+    return success();
+  }
+};
+
 struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
   void runOnFunction() override {
     MLIRContext &ctx = getContext();
@@ -232,8 +289,18 @@
     target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
       return !op.shape().getType().cast<MemRefType>().hasStaticShape();
     });
-    target.addIllegalOp<SignedCeilDivIOp>();
-    target.addIllegalOp<SignedFloorDivIOp>();
+    // clang-format off
+    target.addIllegalOp<
+      MaxFOp,
+      MaxSIOp,
+      MaxUIOp,
+      MinFOp,
+      MinSIOp,
+      MinUIOp,
+      SignedCeilDivIOp,
+      SignedFloorDivIOp
+    >();
+    // clang-format on
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
@@ -243,9 +310,20 @@
 } // namespace
 
 void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
-  patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter,
-               SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
-      patterns.getContext());
+  // clang-format off
+  patterns.add<
+    AtomicRMWOpConverter,
+    MaxMinFOpConverter<MaxFOp, CmpFPredicate::OGT>,
+    MaxMinFOpConverter<MinFOp, CmpFPredicate::OLT>,
+    MaxMinIOpConverter<MaxSIOp, CmpIPredicate::sgt>,
+    MaxMinIOpConverter<MaxUIOp, CmpIPredicate::ugt>,
+    MaxMinIOpConverter<MinSIOp, CmpIPredicate::slt>,
+    MaxMinIOpConverter<MinUIOp, CmpIPredicate::ult>,
+    MemRefReshapeOpConverter,
+    SignedCeilDivIOpConverter,
+    SignedFloorDivIOpConverter
+  >(patterns.getContext());
+  // clang-format on
 }
 
 std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir
--- a/mlir/test/Dialect/Standard/expand-ops.mlir
+++ b/mlir/test/Dialect/Standard/expand-ops.mlir
@@ -109,3 +109,86 @@
 // CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
 // CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
 // CHECK-SAME: : memref<*xf32> to memref
+
+// -----
+
+// CHECK-LABEL: func @maxf
+func @maxf(%a: f32, %b: f32) -> f32 {
+  %result = maxf(%a, %b): (f32, f32) -> f32
+  return %result : f32
+}
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
+// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: func @maxf_vector
+func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
+  %result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
+  return %result : vector<4xf16>
+}
+// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
+// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
+// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
+// CHECK-NEXT: %[[NAN:.*]] = constant 0x7E00 : f16
+// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
+// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
+// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
+
+// -----
+
+// CHECK-LABEL: func @minf
+func @minf(%a: f32, %b: f32) -> f32 {
+  %result = minf(%a, %b): (f32, f32) -> f32
+  return %result : f32
+}
+// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
+
+// -----
+
+// CHECK-LABEL: func @maxsi
+func @maxsi(%a: i32, %b: i32) -> i32 {
+  %result = maxsi(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS]], %[[RHS]] : i32
+
+// -----
+
+// CHECK-LABEL: func @minsi
+func @minsi(%a: i32, %b: i32) -> i32 {
+  %result = minsi(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS]], %[[RHS]] : i32
+
+
+// -----
+
+// CHECK-LABEL: func @maxui
+func @maxui(%a: i32, %b: i32) -> i32 {
+  %result = maxui(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS]], %[[RHS]] : i32
+
+
+// -----
+
+// CHECK-LABEL: func @minui
+func @minui(%a: i32, %b: i32) -> i32 {
+  %result = minui(%a, %b): (i32, i32) -> i32
+  return %result : i32
+}
+// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
+// CHECK-NEXT: %[[CMP:.*]] = cmpi ult, %[[LHS]], %[[RHS]] : i32