Index: mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td =================================================================== --- mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -277,6 +277,28 @@ } //===----------------------------------------------------------------------===// +// CeilDivUIOp +//===----------------------------------------------------------------------===// + +def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> { + let summary = "unsigned ceil integer division operation"; + let description = [{ + Unsigned integer division. Rounds towards positive infinity, i.e. `7 / 2 = 4`. + + Note: the semantics of division by zero is TBD; do NOT assume any specific + behavior. + + Example: + + ```mlir + // Scalar unsigned integer ceil division. + %a = arith.ceildivui %b, %c : i64 + ``` + }]; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp =================================================================== --- mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -307,6 +307,36 @@ } //===----------------------------------------------------------------------===// +// CeilDivUIOp +//===----------------------------------------------------------------------===// + +OpFoldResult arith::CeilDivUIOp::fold(ArrayRef operands) { + bool overflowOrDiv0 = false; + auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { + if (overflowOrDiv0 || !b) { + overflowOrDiv0 = true; + return a; + } + APInt quotient = a.udiv(b); + if (!a.urem(b)) + return quotient; + APInt one(a.getBitWidth(), 1); + return quotient.uadd_ov(one, overflowOrDiv0); + }); + // Fold out ceil division by one. Assumes all tensors of all ones are + // splats.
+ if (auto rhs = operands[1].dyn_cast_or_null()) { + if (rhs.getValue() == 1) + return getLhs(); + } else if (auto rhs = operands[1].dyn_cast_or_null()) { + if (rhs.getSplatValue().getValue() == 1) + return getLhs(); + } + + return overflowOrDiv0 ? Attribute() : result; +} + +//===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// @@ -342,7 +372,7 @@ return zero.ssub_ov(div, overflowOrDiv0); }); - // Fold out floor division by one. Assumes all tensors of all ones are + // Fold out ceil division by one. Assumes all tensors of all ones are // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) Index: mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp =================================================================== --- mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -13,6 +13,30 @@ namespace { +/// Expands CeilDivUIOp (n, m) into +/// n == 0 ? 0 : ((n-1) / m) + 1 (cannot wrap, unlike (n+m-1) / m) +struct CeilDivUIOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::CeilDivUIOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value a = op.lhs(); + Value b = op.rhs(); + Value zero = rewriter.create( + loc, rewriter.getIntegerAttr(a.getType(), 0)); + Value compare = + rewriter.create(loc, arith::CmpIPredicate::eq, a, zero); + Value one = rewriter.create( + loc, rewriter.getIntegerAttr(a.getType(), 1)); + Value minusOne = rewriter.create(loc, a, one); + Value quotient = rewriter.create(loc, minusOne, b); + Value plusOne = rewriter.create(loc, quotient, one); + Value res = rewriter.create(loc, compare, zero, plusOne); + rewriter.replaceOp(op, {res}); + return success(); + } +}; + /// Expands CeilDivSIOp (n, m) into /// 1) x = (m > 0) ? -1 : 1 /// 2) (n*m>0) ?
((n+x) / m) + 1 : - (-n / m) @@ -132,7 +156,8 @@ arith::populateArithmeticExpandOpsPatterns(patterns); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) @@ -144,8 +169,9 @@ void mlir::arith::populateArithmeticExpandOpsPatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns + .add( + patterns.getContext()); } std::unique_ptr mlir::arith::createArithmeticExpandOpsPass() { Index: mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp =================================================================== --- mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -175,7 +175,8 @@ target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); target.addDynamicallyLegalOp([](AtomicRMWOp op) { return op.getKind() != AtomicRMWKind::maxf && op.getKind() != AtomicRMWKind::minf; Index: mlir/test/Dialect/Arithmetic/expand-ops.mlir =================================================================== --- mlir/test/Dialect/Arithmetic/expand-ops.mlir +++ mlir/test/Dialect/Arithmetic/expand-ops.mlir @@ -111,3 +111,37 @@ // CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1 // CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index } + +// ----- + +// Test ceil divide with unsigned integer +// CHECK-LABEL: func @ceildivui +// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 { +func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) { + %res = arith.ceildivui %arg0, %arg1 : i32 + return %res : i32 +// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32 +// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, [[ARG0]], [[ZERO]] : i32 +// CHECK: [[ONE:%.+]] = arith.constant 1 : i32 +// CHECK: [[SUB:%.+]] = arith.subi [[ARG0]], [[ONE]] : i32 +// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], [[ARG1]] : i32 +// CHECK: [[PLUSONE:%.+]] = arith.addi [[DIV]], [[ONE]] : i32 +// CHECK: [[RES:%.+]] =
select [[ISZERO]], [[ZERO]], [[PLUSONE]] : i32 +} + +// ----- + +// Test unsigned ceil divide with index +// CHECK-LABEL: func @ceildivui_index +// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index { +func @ceildivui_index(%arg0: index, %arg1: index) -> (index) { + %res = arith.ceildivui %arg0, %arg1 : index + return %res : index +// CHECK: [[ZERO:%.+]] = arith.constant 0 : index +// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, [[ARG0]], [[ZERO]] : index +// CHECK: [[ONE:%.+]] = arith.constant 1 : index +// CHECK: [[SUB:%.+]] = arith.subi [[ARG0]], [[ONE]] : index +// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], [[ARG1]] : index +// CHECK: [[PLUSONE:%.+]] = arith.addi [[DIV]], [[ONE]] : index +// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[PLUSONE]] : index +} Index: mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir =================================================================== --- mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir +++ mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir @@ -17,6 +17,7 @@ %c20 = arith.constant 20: i32 %c10 = arith.constant 10: i32 %cmin10 = arith.constant -10: i32 + %cmax_int = arith.constant 2147483647: i32 %A = memref.alloc() : memref<40xi32> // print numerator @@ -64,20 +65,39 @@ } call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> () + // test with ceildivui(num + 20, 10), i.e. numerators 0..39 + affine.for %i = 0 to 40 { + %ii = arith.index_cast %i: index to i32 + %val = arith.ceildivui %ii, %c10 : i32 + memref.store %val, %A[%i] : memref<40xi32> + } + call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> () + + // test with ceildivui(num, max_int) + affine.for %i = 0 to 40 { + %ii = arith.index_cast %i: index to i32 + %ii30 = arith.subi %ii, %c20 : i32 + %val = arith.ceildivui %ii30, %cmax_int : i32 + memref.store %val, %A[%i] : memref<40xi32> + } + call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> () + memref.dealloc %A : memref<40xi32> return } // List below is aligned for
easy manual check -// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10) +// legend: num, signed_ceil(num, 10), floor(num, 10), signed_ceil(num, -10), floor(num, -10), unsigned_ceil(num + 20, 10), unsigned_ceil(num, max_int) // ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ) // ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 ) -// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1,-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) -// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) -// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 ) +// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) +// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 ) +// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 ) // CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ) // CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 ) // CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) // CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1 ) // CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 ) +// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 ) +// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) Index: mlir/test/Transforms/canonicalize.mlir =================================================================== --- mlir/test/Transforms/canonicalize.mlir +++ mlir/test/Transforms/canonicalize.mlir @@ -1028,6 +1028,26 @@ // ----- +// CHECK-LABEL: func @arith.ceildivui_by_one +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]] +func @arith.ceildivui_by_one(%arg0: i32) -> (i32) { + %c1 = arith.constant 1 : i32 + %res = arith.ceildivui %arg0, %c1 : i32 + // CHECK: return %[[ARG]] + return %res : i32 +} + +// CHECK-LABEL: func @tensor_arith.ceildivui_by_one +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]] +func @tensor_arith.ceildivui_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> { + %c1 = arith.constant dense<1> : tensor<4x5xi32> + %res = arith.ceildivui %arg0, %c1 : tensor<4x5xi32> + // CHECK: return %[[ARG]] + return %res : tensor<4x5xi32> +} + +// ----- + // CHECK-LABEL: func @memref_cast_folding_subview func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref) { %0 = memref.cast %arg0 : memref<4x5xf32> to memref Index: mlir/test/Transforms/constant-fold.mlir =================================================================== --- mlir/test/Transforms/constant-fold.mlir +++ mlir/test/Transforms/constant-fold.mlir @@ -478,6 +478,44 @@ // ----- +// CHECK-LABEL: func @simple_arith.ceildivui +func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) { + // CHECK-DAG: [[C0:%.+]] = arith.constant 0 + %z = arith.constant 0 : i32 + // CHECK-DAG: [[C7:%.+]] = arith.constant 7 + %0 = arith.constant 7 : i32 + %1 = arith.constant 2 : i32 +
// ceil(7, 2) = 4 + // CHECK-NEXT: [[C4:%.+]] = arith.constant 4 : i32 + %2 = arith.ceildivui %0, %1 : i32 + + %3 = arith.constant -2 : i32 + + // ceil(7, -2) = 1, since -2 is 4294967294 as an unsigned i32 + // CHECK-NEXT: [[C1:%.+]] = arith.constant 1 : i32 + %4 = arith.ceildivui %0, %3 : i32 + + %5 = arith.constant -8 : i32 + + // ceil(-8, 2) = 2147483644 + // CHECK-NEXT: [[CLARGE:%.+]] = arith.constant 2147483644 : i32 + %6 = arith.ceildivui %5, %1 : i32 + + %7 = arith.constant -15 : i32 + + // ceil(-15, -2) = 1; no second constant 1 is expected + // CHECK-NOT: arith.constant 1 : i32 + %8 = arith.ceildivui %7, %3 : i32 + + // CHECK-NEXT: [[XZ:%.+]] = arith.ceildivui [[C7]], [[C0]] + %9 = arith.ceildivui %0, %z : i32 + + return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32 +} + +// ----- + // CHECK-LABEL: func @simple_arith.remsi func @simple_arith.remsi(%a : i32) -> (i32, i32, i32) { %0 = arith.constant 5 : i32