diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -14,22 +14,46 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; +/// Create a float constant. +static Value createFloatConst(Location loc, Type type, double value, + OpBuilder &b) { + auto attr = b.getFloatAttr(getElementTypeOrSelf(type), value); + if (auto shapedTy = dyn_cast(type)) { + return b.create(loc, + DenseElementsAttr::get(shapedTy, attr)); + } + + return b.create(loc, attr); +} + +/// Create an integer constant. +static Value createIntConst(Location loc, Type type, int64_t value, + OpBuilder &b) { + auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value); + if (auto shapedTy = dyn_cast(type)) { + return b.create(loc, + DenseElementsAttr::get(shapedTy, attr)); + } + + return b.create(loc, attr); +} + /// Expands tanh op into /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { auto floatType = op.getOperand().getType(); Location loc = op.getLoc(); - auto floatOne = rewriter.getFloatAttr(floatType, 1.0); - auto floatTwo = rewriter.getFloatAttr(floatType, 2.0); - Value one = rewriter.create(loc, floatOne); - Value two = rewriter.create(loc, floatTwo); + Value one = createFloatConst(loc, floatType, 1.0, rewriter); + Value two = createFloatConst(loc, floatType, 2.0, rewriter); Value doubledX = rewriter.create(loc, op.getOperand(), two); // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} @@ -46,8 +70,7 @@ Value 
negativeRes = rewriter.create(loc, dividend, divisor); // tanh(x) = x >= 0 ? positiveRes : negativeRes - auto floatZero = rewriter.getFloatAttr(floatType, 0.0); - Value zero = rewriter.create(loc, floatZero); + Value zero = createFloatConst(loc, floatType, 0.0, rewriter); Value cmpRes = rewriter.create(loc, arith::CmpFPredicate::OGE, op.getOperand(), zero); rewriter.replaceOpWithNewOp(op, cmpRes, positiveRes, @@ -55,6 +78,7 @@ return success(); } +// Converts math.tan to math.sin, math.cos, and arith.divf. static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value operand = op.getOperand(); @@ -66,52 +90,47 @@ return success(); } +// Converts math.ctlz to scf and arith operations. This is done +// by performing a binary search on the bits. static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter) { auto operand = op.getOperand(); - auto elementTy = operand.getType(); - auto resultTy = op.getType(); + auto operandTy = operand.getType(); + auto eTy = getElementTypeOrSelf(operandTy); Location loc = op.getLoc(); - int bitWidth = elementTy.getIntOrFloatBitWidth(); - auto zero = - rewriter.create(loc, IntegerAttr::get(elementTy, 0)); - auto leadingZeros = rewriter.create( - loc, IntegerAttr::get(elementTy, bitWidth)); - - SmallVector operands = {operand, leadingZeros, zero}; - SmallVector types = {elementTy, elementTy, elementTy}; - SmallVector locations = {loc, loc, loc}; - - auto whileOp = rewriter.create( - loc, types, operands, - [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) { - // The conditional block of the while loop. 
- Value input = args[0]; - Value zero = args[2]; - - Value inputNotZero = beforeBuilder.create( - loc, arith::CmpIPredicate::ne, input, zero); - beforeBuilder.create(loc, inputNotZero, args); - }, - [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) { - // The body of the while loop: shift right until reaching a value of 0. - Value input = args[0]; - Value leadingZeros = args[1]; - - auto one = afterBuilder.create( - loc, IntegerAttr::get(elementTy, 1)); - auto shifted = - afterBuilder.create(loc, resultTy, input, one); - auto leadingZerosMinusOne = afterBuilder.create( - loc, resultTy, leadingZeros, one); - - afterBuilder.create( - loc, ValueRange({shifted, leadingZerosMinusOne, args[2]})); - }); - - rewriter.setInsertionPointAfter(whileOp); - rewriter.replaceOp(op, whileOp->getResult(1)); + int32_t bitwidth = eTy.getIntOrFloatBitWidth(); + if (bitwidth > 64) + return failure(); + + uint64_t allbits = -1; + if (bitwidth < 64) { + allbits = allbits >> (64 - bitwidth); + } + + Value x = operand; + Value count = createIntConst(loc, operandTy, 0, rewriter); + for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) { + auto half = bw / 2; + auto bits = createIntConst(loc, operandTy, half, rewriter); + auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter); + + Value pred = + rewriter.create(loc, arith::CmpIPredicate::ule, x, mask); + Value add = rewriter.create(loc, count, bits); + Value shift = rewriter.create(loc, x, bits); + + x = rewriter.create(loc, pred, shift, x); + count = rewriter.create(loc, pred, add, count); + } + + Value zero = createIntConst(loc, operandTy, 0, rewriter); + Value pred = rewriter.create(loc, arith::CmpIPredicate::eq, + operand, zero); + + Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter); + Value sel = rewriter.create(loc, pred, bwval, count); + rewriter.replaceOp(op, sel); return success(); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir --- 
a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -24,6 +24,16 @@ // ----- + +// CHECK-LABEL: func @vector_tanh +func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> { + // CHECK-NOT: math.tanh + %res = math.tanh %arg : vector<4xf32> + return %res : vector<4xf32> +} + +// ----- + // CHECK-LABEL: func @tan func.func @tan(%arg: f32) -> f32 { %res = math.tan %arg : f32 @@ -33,23 +43,79 @@ // CHECK-SAME: %[[ARG0:.+]]: f32 // CHECK: %[[SIN:.+]] = math.sin %[[ARG0]] // CHECK: %[[COS:.+]] = math.cos %[[ARG0]] -// CEHCK: %[[DIV:.+]] = arith.div %[[SIN]] %[[COS]] +// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]] + + +// ----- + +// CHECK-LABEL: func @vector_tan +func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> { + %res = math.tan %arg : vector<4xf32> + return %res : vector<4xf32> +} + +// CHECK-NOT: math.tan // ----- -// CHECK-LABEL: func @ctlz func.func @ctlz(%arg: i32) -> i32 { - // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 - // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32 - // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32 - // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]]) - // CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]] - // CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]] - // CHECK: %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]] - // CHECK: %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]] - // CHECK: scf.yield %[[SHR]], %[[SUB]], %[[A3]] %res = math.ctlz %arg : i32 - - // CHECK: return %[[WHILE]]#1 return %res : i32 } + +// CHECK-LABEL: @ctlz +// CHECK-SAME: %[[ARG0:.+]]: i32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 +// CHECK-DAG: %[[C16:.+]] = arith.constant 16 +// CHECK-DAG: %[[C65535:.+]] = arith.constant 65535 +// CHECK-DAG: %[[C8:.+]] = arith.constant 8 +// CHECK-DAG: %[[C16777215:.+]] = arith.constant 16777215 +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 +// CHECK-DAG: %[[C268435455:.+]] = arith.constant 268435455 +// 
CHECK-DAG: %[[C2:.+]] = arith.constant 2 +// CHECK-DAG: %[[C1073741823:.+]] = arith.constant 1073741823 +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 +// CHECK-DAG: %[[C2147483647:.+]] = arith.constant 2147483647 +// CHECK-DAG: %[[C32:.+]] = arith.constant 32 + +// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[ARG0]], %[[C65535]] +// CHECK: %[[SHL:.+]] = arith.shli %[[ARG0]], %[[C16]] +// CHECK: %[[SELX0:.+]] = arith.select %[[PRED]], %[[SHL]], %[[ARG0]] +// CHECK: %[[SELY0:.+]] = arith.select %[[PRED]], %[[C16]], %[[C0]] + +// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX0]], %[[C16777215]] +// CHECK: %[[ADD:.+]] = arith.addi %[[SELY0]], %[[C8]] +// CHECK: %[[SHL:.+]] = arith.shli %[[SELX0]], %[[C8]] +// CHECK: %[[SELX1:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX0]] +// CHECK: %[[SELY1:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY0]] + +// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX1]], %[[C268435455]] : i32 +// CHECK: %[[ADD:.+]] = arith.addi %[[SELY1]], %[[C4]] +// CHECK: %[[SHL:.+]] = arith.shli %[[SELX1]], %[[C4]] +// CHECK: %[[SELX2:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX1]] +// CHECK: %[[SELY2:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY1]] + + +// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX2]], %[[C1073741823]] : i32 +// CHECK: %[[ADD:.+]] = arith.addi %[[SELY2]], %[[C2]] +// CHECK: %[[SHL:.+]] = arith.shli %[[SELX2]], %[[C2]] +// CHECK: %[[SELX3:.+]] = arith.select %[[PRED]], %[[SHL]], %[[SELX2]] +// CHECK: %[[SELY3:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY2]] + +// CHECK: %[[PRED:.+]] = arith.cmpi ule, %[[SELX3]], %[[C2147483647]] : i32 +// CHECK: %[[ADD:.+]] = arith.addi %[[SELY3]], %[[C1]] +// CHECK: %[[SELY4:.+]] = arith.select %[[PRED]], %[[ADD]], %[[SELY3]] + +// CHECK: %[[PRED:.+]] = arith.cmpi eq, %[[ARG0]], %[[C0]] : i32 +// CHECK: %[[SEL:.+]] = arith.select %[[PRED]], %[[C32]], %[[SELY4]] : i32 +// CHECK: return %[[SEL]] + +// ----- + +func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> { + %res = math.ctlz %arg 
: vector<4xi32> + return %res : vector<4xi32> +} + +// CHECK-LABEL: @ctlz_vector +// CHECK-NOT: math.ctlz diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -26,7 +27,8 @@ void runOnOperation() override; StringRef getArgument() const final { return "test-expand-math"; } void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); + registry + .insert(); } StringRef getDescription() const final { return "Test expanding math"; } };