diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -138,6 +138,36 @@
   let hasFolder = 1;
 }

+def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
+  let summary = "Division of sizes and indices";
+  let description = [{
+    Divides two sizes or indices. If either operand is an error it will be
+    propagated to the result. The operands can be of type `size` or `index`.
+    If at least one of the operands can hold an error, i.e. if it is of type
+    `size`, the result must be of type `size`. If error propagation is not
+    possible because both operands are of type `index` then the result may be
+    of type `size` or `index`. If both operands and result are of type
+    `index`, their runtime values could be negative. The result is rounded
+    toward negative infinity, i.e. floor(lhs / rhs), such that
+
+        div(lhs, rhs) * rhs + mod(lhs, rhs) = lhs
+
+    always holds. If any of the values is of type `size`, the behavior for
+    negative values is undefined.
+  }];
+
+  let arguments = (ins Shape_SizeOrIndexType:$lhs,
+                       Shape_SizeOrIndexType:$rhs);
+  let results = (outs Shape_SizeOrIndexType:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
+  }];
+
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
+  let hasFolder = 1;
+}
+
 def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
   let summary = "Returns whether the input shapes or extent tensors are equal";
   let description = [{
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -600,6 +600,38 @@
   return operands[0];
 }

+//===----------------------------------------------------------------------===//
+// DivOp
+//===----------------------------------------------------------------------===//
+
+/// Folds `shape.div` when both operands are constant integers. The folded
+/// value is the floored quotient floor(lhs / rhs) as an `index` attribute.
+/// A constant zero divisor is left unfolded because APInt::sdivrem asserts on
+/// division by zero.
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  if (!lhs)
+    return nullptr;
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!rhs || rhs.getValue().isNullValue())
+    return nullptr;
+
+  // APInt division rounds toward zero, whereas `shape.div` rounds toward
+  // negative infinity. The two differ exactly when the division is inexact
+  // and the operands have opposite signs. Since the truncated remainder
+  // takes the sign of lhs, that is when remainder and rhs differ in sign.
+  // Checking the sign of the truncated quotient alone is not enough:
+  // -1 / 4 truncates to 0 but must floor to -1.
+  APInt quotient, remainder;
+  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
+  if (!remainder.isNullValue() &&
+      remainder.isNegative() != rhs.getValue().isNegative())
+    quotient -= 1;
+
+  Type indexTy = IndexType::get(getContext());
+  return IntegerAttr::get(indexTy, quotient);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeEqOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -950,6 +950,97 @@

 // -----

+// Fold `div` for constant indices whose truncated quotient is zero.
+// CHECK-LABEL: @fold_div_index_neg_lhs_small
+func @fold_div_index_neg_lhs_small() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -1 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant -1 : index
+  %rhs = constant 4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Do not fold `div` by a constant zero.
+// CHECK-LABEL: @div_by_zero_no_fold
+func @div_by_zero_no_fold() -> index {
+  // CHECK: %[[RESULT:.*]] = shape.div
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant 10 : index
+  %rhs = constant 0 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant sizes.
+// CHECK-LABEL: @fold_div_size
+func @fold_div_size() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK: return %[[RESULT]] : !shape.size
+  %lhs = shape.const_size 10
+  %rhs = shape.const_size 3
+  %result = shape.div %lhs, %rhs : !shape.size, !shape.size -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
+// Fold `div` for constant indices.
+// CHECK-LABEL: @fold_div_index
+func @fold_div_index() -> index {
+  // CHECK: %[[RESULT:.*]] = constant 2 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant 10 : index
+  %rhs = constant 4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant indices and lhs is negative.
+// CHECK-LABEL: @fold_div_index_neg_lhs
+func @fold_div_index_neg_lhs() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -3 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant -10 : index
+  %rhs = constant 4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for constant indices and rhs is negative.
+// CHECK-LABEL: @fold_div_index_neg_rhs
+func @fold_div_index_neg_rhs() -> index {
+  // CHECK: %[[RESULT:.*]] = constant -3 : index
+  // CHECK: return %[[RESULT]] : index
+  %lhs = constant 10 : index
+  %rhs = constant -4 : index
+  %result = shape.div %lhs, %rhs : index, index -> index
+  return %result : index
+}
+
+// -----
+
+// Fold `div` for mixed constants.
+// CHECK-LABEL: @fold_div_mixed
+func @fold_div_mixed() -> !shape.size {
+  // CHECK: %[[RESULT:.*]] = shape.const_size 4
+  // CHECK: return %[[RESULT]] : !shape.size
+  %lhs = shape.const_size 12
+  %rhs = constant 3 : index
+  %result = shape.div %lhs, %rhs : !shape.size, index -> !shape.size
+  return %result : !shape.size
+}
+
+// -----
+
 // Fold index_cast when already on index.
 // CHECK-LABEL: @fold_index_cast_on_index
 func @fold_index_cast_on_index(%arg: index) -> index {
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -129,6 +129,15 @@
   return
 }

+func @div(%size_arg : !shape.size, %index_arg : index) {
+  %size_div = shape.div %size_arg, %size_arg
+      : !shape.size, !shape.size -> !shape.size
+  %index_div = shape.div %index_arg, %index_arg : index, index -> index
+  %mixed_div = shape.div %size_arg, %index_arg
+      : !shape.size, index -> !shape.size
+  return
+}
+
 func @add(%size_arg : !shape.size, %index_arg : index) {
   %size_sum = shape.add %size_arg, %size_arg
       : !shape.size, !shape.size -> !shape.size