diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -350,9 +350,35 @@
 Attribute mlir::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
                                      OpBuilder &builder, Location loc) {
   switch (kind) {
+  // max/min identities sit at the opposite end of the element type's range
+  // (-inf/+inf for floats, signed min/max for ints, all-ones for minu); maxu
+  // shares the zero identity with addf/addi below.
+  case AtomicRMWKind::maxf:
+    return builder.getFloatAttr(
+        resultType,
+        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+                        /*Negative=*/true));
   case AtomicRMWKind::addf:
   case AtomicRMWKind::addi:
+  case AtomicRMWKind::maxu:
     return builder.getZeroAttr(resultType);
+  case AtomicRMWKind::maxs:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::minf:
+    return builder.getFloatAttr(
+        resultType,
+        APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
+                        /*Negative=*/false));
+  case AtomicRMWKind::mins:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
+  case AtomicRMWKind::minu:
+    return builder.getIntegerAttr(
+        resultType,
+        APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
   case AtomicRMWKind::muli:
     return builder.getIntegerAttr(resultType, 1);
   case AtomicRMWKind::mulf:
@@ -385,6 +411,33 @@
     return builder.create<MulFOp>(loc, lhs, rhs);
   case AtomicRMWKind::muli:
     return builder.create<MulIOp>(loc, lhs, rhs);
+  // max/min are built as cmp + select. The float forms use ordered
+  // predicates, which are false if either operand is NaN; the select then
+  // yields `rhs`, so a NaN in `rhs` propagates but a NaN in `lhs` is dropped.
+  case AtomicRMWKind::maxf:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::minf:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::maxs:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::mins:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::maxu:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::ugt, lhs, rhs), lhs,
+        rhs);
+  case AtomicRMWKind::minu:
+    return builder.create<SelectOp>(
+        loc, builder.create<CmpIOp>(loc, CmpIPredicate::ult, lhs, rhs), lhs,
+        rhs);
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -357,6 +357,20 @@
     return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
                                                builder.getStringAttr("mul"),
                                                vector, ValueRange{});
+  // NOTE(review): minu/maxu are deliberately left unmapped. The LLVM lowering
+  // of vector.reduction appears to pick signed vs. unsigned min/max from the
+  // element type (signless integers -> signed), so "min"/"max" would compute
+  // a signed result for them. Confirm before adding unsigned support here.
+  case AtomicRMWKind::minf:
+  case AtomicRMWKind::mins:
+    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
+                                               builder.getStringAttr("min"),
+                                               vector, ValueRange{});
+  case AtomicRMWKind::maxf:
+  case AtomicRMWKind::maxs:
+    return builder.create<vector::ReductionOp>(vector.getLoc(), scalarType,
+                                               builder.getStringAttr("max"),
+                                               vector, ValueRange{});
   // TODO: Add remaining reduction operations.
   default:
     (void)emitOptionalError(loc, "Reduction operation type not supported");