Index: mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
===================================================================
--- mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2718,6 +2718,62 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
+  let summary = "signed floor integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow (minimum
+    value divided by -1) is TBD; do NOT assume any specific behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer division.
+    %a = floordivi_signed %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
+  let summary = "signed ceil integer division operation";
+  let description = [{
+    Syntax:
+
+    ```
+    operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type
+    ```
+
+    Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`.
+
+    Note: the semantics of division by zero or signed division overflow (minimum
+    value divided by -1) is TBD; do NOT assume any specific behavior.
+
+    Example:
+
+    ```mlir
+    // Scalar signed integer division.
+    %a = ceildivi_signed %b, %c : i64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SignedRemIOp
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
===================================================================
--- mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -35,6 +35,17 @@
 /// Creates an instance of std bufferization pass.
 std::unique_ptr<Pass> createStdBufferizePass();
 
+/// Creates an instance of the StdToStdLowering pass that legalizes the Std
+/// dialect so that it can be converted to LLVM. For example,
+/// `std.ceildivi_signed` gets expanded into a number of std operations,
+/// which can be lowered to LLVM.
+std::unique_ptr<Pass> createStdToStdLowering();
+
+/// Collects the patterns that expand `std.ceildivi_signed` and
+/// `std.floordivi_signed` into simpler std operations.
+void populateStdToStdRewritePatterns(MLIRContext *context,
+                                     OwningRewritePatternList &patterns);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
===================================================================
--- mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -22,4 +22,9 @@
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def StdToStdLowering : FunctionPass<"std-to-std-lowering"> {
+  let summary = "Legalize std dialect to be convertible to LLVM.";
+  let constructor = "mlir::createStdToStdLowering()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
Index: mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
===================================================================
--- /dev/null
+++ mlir/integration_test/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -std-to-std-lowering -convert-vector-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func @transfer_read_1d(%A : memref<40xi32>, %base1: index) {
+  %i42 = constant -42: i32
+  %f = vector.transfer_read %A[%base1], %i42
+      {permutation_map = affine_map<(d0) -> (d0)>} :
+    memref<40xi32>, vector<40xi32>
+  vector.print %f: vector<40xi32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c20 = constant 20: i32
+  %c10 = constant 10: i32
+  %cmin10 = constant -10: i32
+  %A = alloc() : memref<40xi32>
+
+  // print numerator
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    store %num, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with ceil(*, 10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = ceildivi_signed %num, %c10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with floor(*, 10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = floordivi_signed %num, %c10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with ceil(*, -10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = ceildivi_signed %num, %cmin10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  // test with floor(*, -10)
+  affine.for %i = 0 to 40 {
+    %ii = index_cast %i: index to i32
+    %num = subi %ii, %c20 : i32
+    %val = floordivi_signed %num, %cmin10 : i32
+    store %val, %A[%i] : memref<40xi32>
+  }
+  call @transfer_read_1d(%A, %c0) : (memref<40xi32>, index) -> ()
+
+  dealloc %A : memref<40xi32>
+  return
+}
+
+// Expected output rows, in print order:
+// num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
+// ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
+
+// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
+// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
+// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
Index: mlir/lib/Dialect/StandardOps/IR/Ops.cpp
===================================================================
--- mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2737,6 +2737,82 @@
   return overflowOrDiv0 ? Attribute() : result;
 }
 
+//===----------------------------------------------------------------------===//
+// SignedFloorDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedFloorDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    // Truncating division; the only overflowing case is the minimum value
+    // divided by -1, which sets `overflowOrDiv0`.
+    APInt quotient = a.sdiv_ov(b, overflowOrDiv0);
+    // Round towards negative infinity: when the division is inexact and the
+    // exact quotient is negative (operands of opposite signs), the truncated
+    // quotient is one too large. The decrement cannot overflow.
+    if (a.srem(b) != 0 && a.isNegative() != b.isNegative())
+      --quotient;
+    return quotient;
+  });
+
+  // Fold out floor division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
+//===----------------------------------------------------------------------===//
+// SignedCeilDivIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SignedCeilDivIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 2 && "binary operation takes two operands");
+
+  // Don't fold if it would overflow or if it requires a division by zero.
+  bool overflowOrDiv0 = false;
+  auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
+    if (overflowOrDiv0 || !b) {
+      overflowOrDiv0 = true;
+      return a;
+    }
+    // Truncating division; the only overflowing case is the minimum value
+    // divided by -1, which sets `overflowOrDiv0`.
+    APInt quotient = a.sdiv_ov(b, overflowOrDiv0);
+    // Round towards positive infinity: when the division is inexact and the
+    // exact quotient is positive (operands of the same sign), the truncated
+    // quotient is one too small. The increment cannot overflow.
+    if (a.srem(b) != 0 && a.isNegative() == b.isNegative())
+      ++quotient;
+    return quotient;
+  });
+
+  // Fold out ceil division by one. Assumes all tensors of all ones are
+  // splats.
+  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
+    if (rhs.getValue() == 1)
+      return lhs();
+  } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
+    if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
+      return lhs();
+  }
+
+  return overflowOrDiv0 ? Attribute() : result;
+}
+
 //===----------------------------------------------------------------------===//
 // SignedRemIOp
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   ExpandAtomic.cpp
   ExpandTanh.cpp
   FuncConversions.cpp
+  StdToStdLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms
Index: mlir/lib/Dialect/StandardOps/Transforms/StdToStdLowering.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/StandardOps/Transforms/StdToStdLowering.cpp
@@ -0,0 +1,150 @@
+//===- StdToStdLowering.cpp - Expand signed ceil/floor divisions ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the StdToStdLowering pass, which expands
+// `ceildivi_signed` and `floordivi_signed` into simpler std operations that
+// can be lowered to LLVM.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include "mlir/Transforms/Utils.h"
+
+using namespace mlir;
+
+/// Creates a constant of `type` holding `value`. For vector and tensor types
+/// the constant is a splat of `value`.
+static Value createConst(Location loc, Type type, int64_t value,
+                         PatternRewriter &rewriter) {
+  Attribute attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedType = type.dyn_cast<ShapedType>())
+    return rewriter.create<ConstantOp>(
+        loc, DenseElementsAttr::get(shapedType, attr));
+  return rewriter.create<ConstantOp>(loc, attr);
+}
+
+namespace {
+
+/// Expands SignedCeilDivIOp (a, b) into
+///   z = a / b (rounded towards zero)
+///   result = (z * b != a && (a < 0) == (b < 0)) ? z + 1 : z
+/// Only one division is emitted and no intermediate value overflows unless
+/// the result itself does (minimum value divided by -1). In particular, the
+/// sign test does not compute `a * b`, which may overflow.
+struct SignedCeilDivIOpConverter : public OpRewritePattern<SignedCeilDivIOp> {
+  using OpRewritePattern<SignedCeilDivIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SignedCeilDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value zero = createConst(loc, type, 0, rewriter);
+    Value one = createConst(loc, type, 1, rewriter);
+    // Truncating division; `quotient * b` cannot overflow as its magnitude
+    // never exceeds that of `a`.
+    Value quotient = rewriter.create<SignedDivIOp>(loc, a, b);
+    Value product = rewriter.create<MulIOp>(loc, quotient, b);
+    Value inexact = rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, a, product);
+    // An inexact division implies a != 0, so `a < 0` gives the sign of `a`.
+    Value aNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, a, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value sameSign =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, aNeg, bNeg);
+    // Inexact with a positive exact quotient: round the quotient up.
+    Value roundUp = rewriter.create<AndOp>(loc, inexact, sameSign);
+    Value quotientPlusOne = rewriter.create<AddIOp>(loc, quotient, one);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, roundUp, quotientPlusOne,
+                                          quotient);
+    return success();
+  }
+};
+
+/// Expands SignedFloorDivIOp (a, b) into
+///   z = a / b (rounded towards zero)
+///   result = (z * b != a && (a < 0) != (b < 0)) ? z - 1 : z
+/// Only one division is emitted and no intermediate value overflows unless
+/// the result itself does (minimum value divided by -1). In particular, the
+/// sign test does not compute `a * b`, which may overflow.
+struct SignedFloorDivIOpConverter : public OpRewritePattern<SignedFloorDivIOp> {
+  using OpRewritePattern<SignedFloorDivIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SignedFloorDivIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.lhs();
+    Value b = op.rhs();
+    Value zero = createConst(loc, type, 0, rewriter);
+    Value one = createConst(loc, type, 1, rewriter);
+    // Truncating division; `quotient * b` cannot overflow as its magnitude
+    // never exceeds that of `a`.
+    Value quotient = rewriter.create<SignedDivIOp>(loc, a, b);
+    Value product = rewriter.create<MulIOp>(loc, quotient, b);
+    Value inexact = rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, a, product);
+    // An inexact division implies a != 0, so `a < 0` gives the sign of `a`.
+    Value aNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, a, zero);
+    Value bNeg = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, b, zero);
+    Value differentSign =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, aNeg, bNeg);
+    // Inexact with a negative exact quotient: round the quotient down.
+    Value roundDown = rewriter.create<AndOp>(loc, inexact, differentSign);
+    Value quotientMinusOne = rewriter.create<SubIOp>(loc, quotient, one);
+    rewriter.replaceOpWithNewOp<SelectOp>(op, roundDown, quotientMinusOne,
+                                          quotient);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace {
+struct StdToStdLowering : public StdToStdLoweringBase<StdToStdLowering> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void StdToStdLowering::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  OwningRewritePatternList patterns;
+  populateStdToStdRewritePatterns(&ctx, patterns);
+
+  // All of std stays legal except the two ops expanded by the patterns.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<StandardOpsDialect>();
+  target.addIllegalOp<SignedCeilDivIOp, SignedFloorDivIOp>();
+  if (failed(applyPartialConversion(getFunction(), target,
+                                    std::move(patterns))))
+    signalPassFailure();
+}
+
+void mlir::populateStdToStdRewritePatterns(MLIRContext *context,
+                                           OwningRewritePatternList &patterns) {
+  patterns.insert<SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
+      context);
+}
+
+std::unique_ptr<Pass> mlir::createStdToStdLowering() {
+  return std::make_unique<StdToStdLowering>();
+}
Index: mlir/test/Dialect/Standard/std-to-std.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Standard/std-to-std.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt -std-to-std-lowering %s -split-input-file | FileCheck %s
+
+// Test floor divide with signed integer
+// CHECK-LABEL:   func @floordivi
+// CHECK-SAME:    ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = floordivi_signed %arg0, %arg1 : i32
+  return %res : i32
+// CHECK:           [[ZERO:%.+]] = constant 0 : i32
+// CHECK:           [[ONE:%.+]] = constant 1 : i32
+// CHECK:           [[DIV:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK:           [[MUL:%.+]] = muli [[DIV]], [[ARG1]] : i32
+// CHECK:           [[INEXACT:%.+]] = cmpi "ne", [[ARG0]], [[MUL]] : i32
+// CHECK:           [[ANEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32
+// CHECK:           [[BNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[DIFF:%.+]] = cmpi "ne", [[ANEG]], [[BNEG]] : i1
+// CHECK:           [[COND:%.+]] = and [[INEXACT]], [[DIFF]] : i1
+// CHECK:           [[MINUS1:%.+]] = subi [[DIV]], [[ONE]] : i32
+// CHECK:           [[RES:%.+]] = select [[COND]], [[MINUS1]], [[DIV]] : i32
+// CHECK:           return [[RES]] : i32
+// CHECK:         }
+}
+
+// -----
+
+// Test ceil divide with signed integer
+// CHECK-LABEL:   func @ceildivi
+// CHECK-SAME:    ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
+func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
+  %res = ceildivi_signed %arg0, %arg1 : i32
+  return %res : i32
+// CHECK:           [[ZERO:%.+]] = constant 0 : i32
+// CHECK:           [[ONE:%.+]] = constant 1 : i32
+// CHECK:           [[DIV:%.+]] = divi_signed [[ARG0]], [[ARG1]] : i32
+// CHECK:           [[MUL:%.+]] = muli [[DIV]], [[ARG1]] : i32
+// CHECK:           [[INEXACT:%.+]] = cmpi "ne", [[ARG0]], [[MUL]] : i32
+// CHECK:           [[ANEG:%.+]] = cmpi "slt", [[ARG0]], [[ZERO]] : i32
+// CHECK:           [[BNEG:%.+]] = cmpi "slt", [[ARG1]], [[ZERO]] : i32
+// CHECK:           [[SAME:%.+]] = cmpi "eq", [[ANEG]], [[BNEG]] : i1
+// CHECK:           [[COND:%.+]] = and [[INEXACT]], [[SAME]] : i1
+// CHECK:           [[PLUS1:%.+]] = addi [[DIV]], [[ONE]] : i32
+// CHECK:           [[RES:%.+]] = select [[COND]], [[PLUS1]], [[DIV]] : i32
+// CHECK:           return [[RES]] : i32
+// CHECK:         }
+}
+
+// -----
+
+// Test signed ceil and floor divide on vectors of integers
+// CHECK-LABEL:   func @divi_vector
+func @divi_vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  %ceil = ceildivi_signed %arg0, %arg1 : vector<4xi32>
+  %floor = floordivi_signed %arg0, %arg1 : vector<4xi32>
+  return %ceil, %floor : vector<4xi32>, vector<4xi32>
+// CHECK:           constant dense<0> : vector<4xi32>
+// CHECK:           constant dense<1> : vector<4xi32>
+// CHECK:           divi_signed {{.*}} : vector<4xi32>
+// CHECK:           select {{.*}} : vector<4xi32>
+// CHECK-NOT:       ceildivi_signed
+// CHECK-NOT:       floordivi_signed
+// CHECK:           return
+}
Index: mlir/test/IR/core-ops.mlir
===================================================================
--- mlir/test/IR/core-ops.mlir
+++ mlir/test/IR/core-ops.mlir
@@ -569,6 +569,30 @@
   // CHECK: %{{[0-9]+}} = floorf %arg0 : tensor<4x4x?xf32>
   %166 = floorf %t : tensor<4x4x?xf32>
 
+  // CHECK: %{{[0-9]+}} = floordivi_signed %arg2, %arg2 : i32
+  %167 = floordivi_signed %i, %i : i32
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %arg3, %arg3 : index
+  %168 = floordivi_signed %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %cst_5, %cst_5 : vector<42xi32>
+  %169 = floordivi_signed %vci32, %vci32 : vector<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = floordivi_signed %cst_4, %cst_4 : tensor<42xi32>
+  %170 = floordivi_signed %tci32, %tci32 : tensor<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %arg2, %arg2 : i32
+  %171 = ceildivi_signed %i, %i : i32
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %arg3, %arg3 : index
+  %172 = ceildivi_signed %idx, %idx : index
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_5, %cst_5 : vector<42xi32>
+  %173 = ceildivi_signed %vci32, %vci32 : vector<42 x i32>
+
+  // CHECK: %{{[0-9]+}} = ceildivi_signed %cst_4, %cst_4 : tensor<42xi32>
+  %174 = ceildivi_signed %tci32, %tci32 : tensor<42 x i32>
+
   return
 }
 
Index: mlir/test/Transforms/canonicalize.mlir
===================================================================
--- mlir/test/Transforms/canonicalize.mlir
+++ mlir/test/Transforms/canonicalize.mlir
@@ -949,4 +949,44 @@
 
 // -----
 
+// CHECK-LABEL: func @floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @floordivi_signed_by_one(%arg0: i32) -> (i32) {
+  %c1 = constant 1 : i32
+  %res = floordivi_signed %arg0, %c1 : i32
+  // CHECK: return %[[ARG]]
+  return %res : i32
+}
+
+// CHECK-LABEL: func @tensor_floordivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_floordivi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %c1 = constant dense<1> : tensor<4x5xi32>
+  %res = floordivi_signed %arg0, %c1 : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %res : tensor<4x5xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @ceildivi_signed_by_one(%arg0: i32) -> (i32) {
+  %c1 = constant 1 : i32
+  %res = ceildivi_signed %arg0, %c1 : i32
+  // CHECK: return %[[ARG]]
+  return %res : i32
+}
+
+// CHECK-LABEL: func @tensor_ceildivi_signed_by_one
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
+func @tensor_ceildivi_signed_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
+  %c1 = constant dense<1> : tensor<4x5xi32>
+  %res = ceildivi_signed %arg0, %c1 : tensor<4x5xi32>
+  // CHECK: return %[[ARG]]
+  return %res : tensor<4x5xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @memref_cast_folding_subview
Index: mlir/test/Transforms/constant-fold.mlir
===================================================================
--- mlir/test/Transforms/constant-fold.mlir
+++ mlir/test/Transforms/constant-fold.mlir
@@ -402,6 +402,103 @@
 
 // -----
 
+// CHECK-LABEL: func @simple_floordivi_signed
+func @simple_floordivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // floor(7, 2) = 3
+  // CHECK-NEXT: [[R0:%.+]] = constant 3 : i32
+  %2 = floordivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // floor(7, -2) = -4
+  // CHECK-NEXT: [[R1:%.+]] = constant -4 : i32
+  %4 = floordivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // floor(-9, 2) = -5
+  // CHECK-NEXT: [[R2:%.+]] = constant -5 : i32
+  %6 = floordivi_signed %5, %1 : i32
+
+  %7 = constant -13 : i32
+
+  // floor(-13, -2) = 6
+  // CHECK-NEXT: [[R3:%.+]] = constant 6 : i32
+  %8 = floordivi_signed %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = floordivi_signed [[C7]], [[C0]]
+  %9 = floordivi_signed %0, %z : i32
+
+  // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]], [[XZ]]
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func @simple_ceildivi_signed
+func @simple_ceildivi_signed() -> (i32, i32, i32, i32, i32) {
+  // CHECK-DAG: [[C0:%.+]] = constant 0
+  %z = constant 0 : i32
+  // CHECK-DAG: [[C7:%.+]] = constant 7
+  %0 = constant 7 : i32
+  %1 = constant 2 : i32
+
+  // ceil(7, 2) = 4
+  // CHECK-NEXT: [[R0:%.+]] = constant 4 : i32
+  %2 = ceildivi_signed %0, %1 : i32
+
+  %3 = constant -2 : i32
+
+  // ceil(7, -2) = -3
+  // CHECK-NEXT: [[R1:%.+]] = constant -3 : i32
+  %4 = ceildivi_signed %0, %3 : i32
+
+  %5 = constant -9 : i32
+
+  // ceil(-9, 2) = -4
+  // CHECK-NEXT: [[R2:%.+]] = constant -4 : i32
+  %6 = ceildivi_signed %5, %1 : i32
+
+  %7 = constant -15 : i32
+
+  // ceil(-15, -2) = 8
+  // CHECK-NEXT: [[R3:%.+]] = constant 8 : i32
+  %8 = ceildivi_signed %7, %3 : i32
+
+  // CHECK-NEXT: [[XZ:%.+]] = ceildivi_signed [[C7]], [[C0]]
+  %9 = ceildivi_signed %0, %z : i32
+
+  // CHECK-NEXT: return [[R0]], [[R1]], [[R2]], [[R3]], [[XZ]]
+  return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
+}
+
+// -----
+
+// CHECK-LABEL: func @floordivi_ceildivi_signed_min_int
+func @floordivi_ceildivi_signed_min_int() -> (i32, i32) {
+  %0 = constant -2147483648 : i32
+  %1 = constant -3 : i32
+
+  // floor(-2147483648, -3) = 715827882
+  // CHECK-DAG: [[FLOOR:%.+]] = constant 715827882 : i32
+  %2 = floordivi_signed %0, %1 : i32
+
+  // ceil(-2147483648, -3) = 715827883
+  // CHECK-DAG: [[CEIL:%.+]] = constant 715827883 : i32
+  %3 = ceildivi_signed %0, %1 : i32
+
+  // CHECK: return [[FLOOR]], [[CEIL]]
+  return %2, %3 : i32, i32
+}
+
+// -----
+
 // CHECK-LABEL: func @simple_remi_signed
 func @simple_remi_signed(%a : i32) -> (i32, i32, i32) {
   %0 = constant 5 : i32