Index: mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
===================================================================
--- mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2823,6 +2823,63 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
+  let summary = "signed floor integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow (minimum
+    value divided by -1) is TBD; do NOT assume any specific behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer floor division.
+    %a = floordivi_signed %b, %c : i64
+
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
+  let summary = "signed ceil integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow (minimum
+    value divided by -1) is TBD; do NOT assume any specific behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer ceil division.
+    %a = ceildivi_signed %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SignedRemIOp
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
===================================================================
--- mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -38,6 +38,16 @@
 /// Creates an instance of std bufferization pass.
 std::unique_ptr<Pass> createStdBufferizePass();
 
+/// Creates an instance of the StdExpandDivs pass that legalizes Std
+/// dialect Divs to be convertible to LLVM. For example,
+/// `std.ceildivi_signed` gets transformed to a number of std operations,
+/// which can be lowered to LLVM.
+std::unique_ptr<Pass> createStdExpandDivsPass();
+
+/// Collects a set of patterns to rewrite ops within the Std dialect.
+void populateStdToStdRewritePatterns(MLIRContext *context,
+                                     OwningRewritePatternList &patterns);
+
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
Index: mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -22,6 +22,11 @@
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def StdExpandDivs : FunctionPass<"std-expand-divs"> {
+  let summary = "Legalize div std dialect operations to be convertible to LLVM.";
+  let constructor = "mlir::createStdExpandDivsPass()";
+}
+
 def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
   let summary = "Bufferize func/call/return ops";
   let description = [{
Index: mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
===================================================================
--- /dev/null
+++ mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-expand-divs -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @transfer_read_1d(%A : memref<40xi32>, %base1: index) {
+  %i42 = constant -42: i32
+  %f = vector.transfer_read %A[%base1], %i42
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<40xi32>, vector<40xi32>
+  vector.print %f: vector<40xi32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c20 = constant 20: i32
+  %c10 = constant 10: i32
+  %cmin10 = constant -10: i32
+  %A = alloc() : memref<40xi32>
+
+  // print numerator
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %ii30 = subi %ii, %c20 : i32
+    store %ii30, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with ceil(*, 10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %ii30 = subi %ii, %c20 : i32
+    %val = ceildivi_signed %ii30, %c10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with floor(*, 10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %ii30 = subi %ii, %c20 : i32
+    %val = floordivi_signed %ii30, %c10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+
+  // test with ceil(*, -10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %ii30 = subi %ii, %c20 : i32
+    %val = ceildivi_signed %ii30, %cmin10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with floor(*, -10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %ii30 = subi %ii, %c20 : i32
+    %val = floordivi_signed %ii30, %cmin10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  return
+}
+
+// List below is aligned for easy manual check
+// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
+// ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+
+// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
\ No newline at end of file
Index: mlir/lib/Dialect/StandardOps/IR/Ops.cpp
===================================================================
--- mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2887,6 +2887,90 @@
   return overflowOrDiv0 ? Attribute() : result;
 }
 
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+/// Returns `a / b` rounded towards positive infinity if `roundUp` is set, and
+/// towards negative infinity otherwise; `b` must be nonzero. Sets `overflow`
+/// iff the underlying truncating division overflows (minimum value divided
+/// by -1), in which case the returned value must not be used.
+static APInt signedRoundedDiv(const APInt &a, const APInt &b, bool roundUp,
+                              bool &overflow) {
+  // Truncating division, i.e. what `divi_signed` computes.
+  APInt quotient = a.sdiv_ov(b, overflow);
+  // The remainder has the sign of `a`. When it is nonzero, truncation has
+  // rounded a positive exact quotient down and a negative one up. A nonzero
+  // remainder also implies |b| >= 2, so the +/-1 adjustments cannot overflow.
+  APInt remainder = a.srem(b);
+  if (!remainder)
+    return quotient;
+  bool exactQuotientIsPositive = remainder.isNegative() == b.isNegative();
+  if (roundUp && exactQuotientIsPositive)
+    return quotient + 1;
+  if (!roundUp && !exactQuotientIsPositive)
+    return quotient - 1;
+  return quotient;
+}
+
+OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    // APInt's *_ov helpers overwrite the flag, so use a single checked op.
+    return signedRoundedDiv(a, b, /*roundUp=*/false, overflowOrDiv0);
+  });
+
+  // Fold out floor division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    // APInt's *_ov helpers overwrite the flag, so use a single checked op.
+    return signedRoundedDiv(a, b, /*roundUp=*/true, overflowOrDiv0);
+  });
+
+  // Fold out ceil division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
 //===----------------------------------------------------------------------===//
 // SignedRemIOp
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   ExpandTanh.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
+  StdExpandDivs.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
Index: mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/StandardOps/Transforms/StdExpandDivs.cpp
@@ -0,0 +1,141 @@
+//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Std transformations that expand Div operations to help
+// their lowering to LLVM. Currently implemented transformations are Ceil and
+// Floor for Signed Integers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Expands SignedCeilDivIOp (n, m) into
+///   1) q = n / m and r = n % m, both rounding towards zero
+///   2) return (r != 0 && (r < 0) == (m < 0)) ? q + 1 : q
+/// The remainder has the sign of n, so `(r < 0) == (m < 0)` holds exactly when
+/// the exact quotient is positive, i.e. when truncation rounded it down.
+/// Besides n / m and n % m themselves (which only overflow for min / -1), no
+/// emitted op can overflow, so e.g. (INT_MAX, -1) and (INT_MIN, 10) are exact.
+struct SignedCeilDivIOpConverter : public OpRewritePattern<SignedCeilDivIOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(SignedCeilDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value plusOne =
+        rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(type, 1));
+    Value zero =
+        rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(type, 0));
+    // Compute the quotient and remainder, both rounding towards zero.
+    Value quotient = rewriter.create<SignedDivIOp>(loc, a, b);
+    Value remainder = rewriter.create<SignedRemIOp>(loc, a, b);
+    // Round up iff the remainder is nonzero and has the same sign as b.
+    Value remNonZero =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, remainder, zero);
+    Value remNeg =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value sameSign =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, remNeg, bNeg);
+    Value roundUp = rewriter.create<AndOp>(loc, remNonZero, sameSign);
+    // No overflow: a nonzero remainder implies |b| >= 2, which bounds |q|.
+    Value quotientPlusOne = rewriter.create<AddIOp>(loc, quotient, plusOne);
+    Value res =
+        rewriter.create<SelectOp>(loc, roundUp, quotientPlusOne, quotient);
+    // Perform substitution and return success.
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
+/// Expands SignedFloorDivIOp (n, m) into
+///   1) q = n / m and r = n % m, both rounding towards zero
+///   2) return (r != 0 && (r < 0) != (m < 0)) ? q - 1 : q
+/// The remainder has the sign of n, so `(r < 0) != (m < 0)` holds exactly when
+/// the exact quotient is negative, i.e. when truncation rounded it up.
+/// Besides n / m and n % m themselves (which only overflow for min / -1), no
+/// emitted op can overflow, so e.g. (INT_MIN + 1, -1) is computed exactly.
+struct SignedFloorDivIOpConverter : public OpRewritePattern<SignedFloorDivIOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(SignedFloorDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value plusOne =
+        rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(type, 1));
+    Value zero =
+        rewriter.create<ConstantOp>(loc, rewriter.getIntegerAttr(type, 0));
+    // Compute the quotient and remainder, both rounding towards zero.
+    Value quotient = rewriter.create<SignedDivIOp>(loc, a, b);
+    Value remainder = rewriter.create<SignedRemIOp>(loc, a, b);
+    // Round down iff the remainder is nonzero and its sign differs from b's.
+    Value remNonZero =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, remainder, zero);
+    Value remNeg =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, remainder, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value signsDiffer =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, remNeg, bNeg);
+    Value roundDown = rewriter.create<AndOp>(loc, remNonZero, signsDiffer);
+    // No overflow: a nonzero remainder implies |b| >= 2, which bounds |q|.
+    Value quotientMinusOne = rewriter.create<SubIOp>(loc, quotient, plusOne);
+    Value res =
+        rewriter.create<SelectOp>(loc, roundDown, quotientMinusOne, quotient);
+    // Perform substitution and return success.
+    rewriter.replaceOp(op, {res});
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+struct StdExpandDivs : public StdExpandDivsBase<StdExpandDivs> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void StdExpandDivs::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  OwningRewritePatternList patterns;
+  populateStdToStdRewritePatterns(&ctx, patterns);
+
+  // Everything in Std stays legal except the two ops expanded above.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<StandardOpsDialect>();
+  target.addIllegalOp<SignedCeilDivIOp>();
+  target.addIllegalOp<SignedFloorDivIOp>();
+  if (failed(
+          applyPartialConversion(getFunction(), target, std::move(patterns))))
+    signalPassFailure();
+}
+
+void mlir::populateStdToStdRewritePatterns(MLIRContext *context,
+                                           OwningRewritePatternList &patterns) {
+  patterns.insert<SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
+      context);
+}
+
+std::unique_ptr<Pass> mlir::createStdExpandDivsPass() {
+  return std::make_unique<StdExpandDivs>();
+}
Index: mlir/test/Dialect/Standard/std-expand-divs.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Standard/std-expand-divs.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -std-expand-divs %s -split-input-file | FileCheck %s
+
+// Test floor divide with signed integer
+// CHECK-LABEL: func @floordivi
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = floordivi_signed %arg0, %arg1 : i32
+  return %res : i32
+// CHECK: [[ONE:%.+]] = constant 1 : i32
+// CHECK: [[ZERO:%.+]] = constant 0 : i32
+// CHECK: [[QUOT:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK: [[REM:%.+]] = remi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK: [[REMNZ:%.+]] = cmpi "ne", [[REM]], [[ZERO]] : i32
+// CHECK: [[REMNEG:%.+]] = cmpi "slt", [[REM]], [[ZERO]] : i32
+// CHECK: [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[DIFF:%.+]] = cmpi "ne", [[REMNEG]], [[MNEG]] : i1
+// CHECK: [[COND:%.+]] = and [[REMNZ]], [[DIFF]] : i1
+// CHECK: [[QM1:%.+]] = subi [[QUOT]], [[ONE]] : i32
+// CHECK: [[RES:%.+]] = select [[COND]], [[QM1]], [[QUOT]] : i32
+// CHECK: return [[RES]] : i32
+}
+
+// -----
+
+// Test ceil divide with signed integer
+// CHECK-LABEL: func @ceildivi
+// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = ceildivi_signed %arg0, %arg1 : i32
+  return %res : i32
+
+// CHECK: [[ONE:%.+]] = constant 1 : i32
+// CHECK: [[ZERO:%.+]] = constant 0 : i32
+// CHECK: [[QUOT:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK: [[REM:%.+]] = remi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK: [[REMNZ:%.+]] = cmpi "ne", [[REM]], [[ZERO]] : i32
+// CHECK: [[REMNEG:%.+]] = cmpi "slt", [[REM]], [[ZERO]] : i32
+// CHECK: [[MNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK: [[SAME:%.+]] = cmpi "eq", [[REMNEG]], [[MNEG]] : i1
+// CHECK: [[COND:%.+]] = and [[REMNZ]], [[SAME]] : i1
+// CHECK: [[QP1:%.+]] = addi [[QUOT]], [[ONE]] : i32
+// CHECK: [[RES:%.+]] = select [[COND]], [[QP1]], [[QUOT]] : i32
+// CHECK: return [[RES]] : i32
+}
Index: mlir/test/IR/core-ops.mlir
===================================================================
--- mlir/test/IR/core-ops.mlir
+++ mlir/test/IR/core-ops.mlir
@@ -569,6 +569,30 @@
   // CHECK: %{{[0-9]+}} = floorf %arg0 : tensor<4x4x?xf32>
   %166 = floorf %t : tensor<4x4x?xf32>
 
+  // CHECK: %{{[0-9]+}} = floordivi_signed %arg2, %arg2 : i32
+  %167 = floordivi_signed %i, %i : i32
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %arg3, %arg3 : index
+  %168 = floordivi_signed %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %cst_5, %cst_5 : vector<42xi32>
+  %169 = floordivi_signed %vci32, %vci32 : vector<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %cst_4, %cst_4 : tensor<42xi32>
+  %170 = floordivi_signed %tci32, %tci32 : tensor<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %arg2, %arg2 : i32
+  %171 = ceildivi_signed %i, %i : i32
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %arg3, %arg3 : index
+  %172 = ceildivi_signed %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_5, %cst_5 : vector<42xi32>
+  %173 = ceildivi_signed %vci32, %vci32 : vector<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_4, %cst_4 : tensor<42xi32>
+  %174 = ceildivi_signed %tci32, %tci32 : tensor<42 x i32>
+
   return
 }
 
Index: mlir/test/Transforms/canonicalize.mlir
===================================================================
--- mlir/test/Transforms/canonicalize.mlir
+++ mlir/test/Transforms/canonicalize.mlir
@@ -949,6 +949,46 @@
 
 // -----
 
+// CHECK-LABEL: func @floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @floordivi_signed_by_one(%arg0: i32) -> (i32) {
+  %c1 = constant 1 : i32
+  %res = floordivi_signed %arg0, %c1 : i32
+  // CHECK: return %[[ARG]]
+  return %res : i32
+}
+
+// CHECK-LABEL: func @tensor_floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_floordivi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %c1 = constant dense<1> : tensor<4x5xi32>
+  %res = floordivi_signed %arg0, %c1 : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %res : tensor<4x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @ceildivi_signed_by_one(%arg0: i32) -> (i32) {
+  %c1 = constant 1 : i32
+  %res = ceildivi_signed %arg0, %c1 : i32
+  // CHECK: return %[[ARG]]
+  return %res : i32
+}
+
+// CHECK-LABEL: func @tensor_ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_ceildivi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %c1 = constant dense<1> : tensor<4x5xi32>
+  %res = ceildivi_signed %arg0, %c1 : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %res : tensor<4x5xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_cast_folding_subview
 func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref) {
   %0 = memref_cast %arg0 : memref<4x5xf32> to memref
Index: mlir/test/Transforms/constant-fold.mlir
===================================================================
--- mlir/test/Transforms/constant-fold.mlir
+++ mlir/test/Transforms/constant-fold.mlir
@@ -402,6 +402,105 @@
 
 // -----
 
+// CHECK-LABEL: func @simple_floordivi_signed
+func @simple_floordivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // floor(7, 2) = 3
+  // CHECK-NEXT: [[C3:%.+]] = constant 3 : i32
+  %2 = floordivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // floor(7, -2) = -4
+  // CHECK-NEXT: [[CM4:%.+]] = constant -4 : i32
+  %4 = floordivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // floor(-9, 2) = -5
+  // CHECK-NEXT: [[CM5:%.+]] = constant -5 : i32
+  %6 = floordivi_signed %5, %1 : i32
+
+  %7 = constant -13 : i32
+
+  // floor(-13, -2) = 6
+  // CHECK-NEXT: [[C6:%.+]] = constant 6 : i32
+  %8 = floordivi_signed %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = floordivi_signed [[C7]], [[C0]]
+  %9 = floordivi_signed %0, %z : i32
+
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func @simple_ceildivi_signed
+func @simple_ceildivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // ceil(7, 2) = 4
+  // CHECK-NEXT: [[C4:%.+]] = constant 4 : i32
+  %2 = ceildivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // ceil(7, -2) = -3
+  // CHECK-NEXT: [[CM3:%.+]] = constant -3 : i32
+  %4 = ceildivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // ceil(-9, 2) = -4
+  // CHECK-NEXT: [[CM4:%.+]] = constant -4 : i32
+  %6 = ceildivi_signed %5, %1 : i32
+
+  %7 = constant -15 : i32
+
+  // ceil(-15, -2) = 8
+  // CHECK-NEXT: [[C8:%.+]] = constant 8 : i32
+  %8 = ceildivi_signed %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = ceildivi_signed [[C7]], [[C0]]
+  %9 = ceildivi_signed %0, %z : i32
+
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func @floordivi_ceildivi_signed_min
+func @floordivi_ceildivi_signed_min() -> (i32, i32, i32) {
+  %min = constant -2147483648 : i32
+  %c3 = constant 3 : i32
+  %cm2 = constant -2 : i32
+
+  // floor(-2147483648, 3) = -715827883
+  // CHECK-DAG: constant -715827883 : i32
+  %0 = floordivi_signed %min, %c3 : i32
+
+  // ceil(-2147483648, 3) = -715827882
+  // CHECK-DAG: constant -715827882 : i32
+  %1 = ceildivi_signed %min, %c3 : i32
+
+  // floor(-2147483648, -2) = 1073741824
+  // CHECK-DAG: constant 1073741824 : i32
+  %2 = floordivi_signed %min, %cm2 : i32
+
+  return %0, %1, %2 : i32, i32, i32
+}
+
+// -----
+
 // CHECK-LABEL: func @simple_remi_signed
 func @simple_remi_signed(%a : i32) -> (i32, i32, i32) {
   %0 = constant 5 : i32