diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -13,6 +13,7 @@
 
 class RewritePatternSet;
 
+void populateExpandCtlzPattern(RewritePatternSet &patterns);
 void populateExpandTanhPattern(RewritePatternSet &patterns);
 
 void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRMathTransforms
   AlgebraicSimplification.cpp
-  ExpandTanh.cpp
+  ExpandPatterns.cpp
   PolynomialApproximation.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
rename from mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
rename to mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -53,6 +54,81 @@
   return success();
 }
 
+/// Expands `math.ctlz` on a scalar integer into an `scf.while` loop that
+/// logically shifts the operand right by one bit per iteration until it
+/// becomes zero. The count starts at the bit width and is decremented once per
+/// shift, so the loop result is the number of leading zero bits (the full bit
+/// width for a zero input). Non-scalar operands (e.g. vectors) are not
+/// matched.
+static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
+                                   PatternRewriter &rewriter) {
+  Value operand = op.getOperand();
+  Type elementTy = operand.getType();
+  Type resultTy = op.getType();
+  Location loc = op.getLoc();
+
+  // `math.ctlz` also accepts vectors, on which getIntOrFloatBitWidth() would
+  // assert; only expand scalar integers.
+  if (!elementTy.isa<IntegerType>())
+    return failure();
+
+  unsigned bitWidth = elementTy.getIntOrFloatBitWidth();
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
+  Value initialCount = rewriter.create<arith::ConstantOp>(
+      loc, IntegerAttr::get(elementTy, bitWidth));
+  // Materialized once, outside the loop, and reused by the loop body.
+  Value one =
+      rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
+
+  // Loop-carried values: (remaining bits, leading-zero count, zero).
+  SmallVector<Value> operands = {operand, initialCount, zero};
+  SmallVector<Type> types = {elementTy, elementTy, elementTy};
+  SmallVector<Location> locations = {loc, loc, loc};
+
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
+  Block *before =
+      rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
+  Block *after =
+      rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
+
+  // The conditional block: keep iterating while the remaining value is
+  // nonzero, forwarding all loop-carried values to the body.
+  {
+    rewriter.setInsertionPointToStart(before);
+    Value remaining = before->getArgument(0);
+    Value loopZero = before->getArgument(2);
+
+    Value remainingNotZero = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, remaining, loopZero);
+    rewriter.create<scf::ConditionOp>(loc, remainingNotZero,
+                                      before->getArguments());
+  }
+
+  // The loop body: shift right by one and decrement the leading-zero count.
+  {
+    rewriter.setInsertionPointToStart(after);
+    Value remaining = after->getArgument(0);
+    Value count = after->getArgument(1);
+
+    Value shifted =
+        rewriter.create<arith::ShRUIOp>(loc, resultTy, remaining, one);
+    Value countMinusOne =
+        rewriter.create<arith::SubIOp>(loc, resultTy, count, one);
+
+    rewriter.create<scf::YieldOp>(
+        loc, ValueRange({shifted, countMinusOne, after->getArgument(2)}));
+  }
+
+  // Result #1 is the leading-zero count once the remaining value hits zero.
+  rewriter.replaceOp(op, whileOp->getResult(1));
+  return success();
+}
+
+void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
+  patterns.add(convertCtlzOp);
+}
+
 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
   patterns.add(convertTanhOp);
 }
diff --git a/mlir/test/Dialect/Math/expand-tanh.mlir b/mlir/test/Dialect/Math/expand-math.mlir
rename from mlir/test/Dialect/Math/expand-tanh.mlir
rename to mlir/test/Dialect/Math/expand-math.mlir
--- a/mlir/test/Dialect/Math/expand-tanh.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-expand-tanh | FileCheck %s
+// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
 
 // CHECK-LABEL: func @tanh
 func.func @tanh(%arg: f32) -> f32 {
@@ -21,3 +21,32 @@
 // CHECK: %[[COND:.+]] = arith.cmpf oge, %arg0, %[[ZERO]] : f32
 // CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
 // CHECK: return %[[RESULT]]
+
+// -----
+
+// CHECK-LABEL: func @ctlz
+func.func @ctlz(%arg: i32) -> i32 {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+  // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
+  // CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
+  // CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
+  // CHECK: %[[SHR:.+]] = arith.shrui %[[A1]], %[[C1]]
+  // CHECK: %[[SUB:.+]] = arith.subi %[[A2]], %[[C1]]
+  // CHECK: scf.yield %[[SHR]], %[[SUB]], %[[A3]]
+  %res = math.ctlz %arg : i32
+
+  // CHECK: return %[[WHILE]]#1
+  return %res : i32
+}
+
+// -----
+
+// Vector operands are not expanded; the op is left unchanged.
+// CHECK-LABEL: func @ctlz_vector
+func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
+  // CHECK: math.ctlz
+  %res = math.ctlz %arg : vector<4xi32>
+  return %res : vector<4xi32>
+}
diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Math/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt
@@ -1,7 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMathTestPasses
   TestAlgebraicSimplification.cpp
-  TestExpandTanh.cpp
+  TestExpandMath.cpp
   TestPolynomialApproximation.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Math/TestExpandTanh.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
rename from mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
rename to mlir/test/lib/Dialect/Math/TestExpandMath.cpp
--- a/mlir/test/lib/Dialect/Math/TestExpandTanh.cpp
+++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp
@@ -1,4 +1,4 @@
-//===- TestExpandTanh.cpp - Test expand tanh op into exp form -------------===//
+//===- TestExpandMath.cpp - Test expanding math ops -----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information. @@ -6,35 +6,41 @@ // //===----------------------------------------------------------------------===// // -// This file contains test passes for expanding tanh. +// This file contains test passes for expanding math operations. // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; namespace { -struct TestExpandTanhPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandTanhPass) +struct TestExpandMathPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass) void runOnOperation() override; - StringRef getArgument() const final { return "test-expand-tanh"; } - StringRef getDescription() const final { return "Test expanding tanh"; } + StringRef getArgument() const final { return "test-expand-math"; } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + StringRef getDescription() const final { return "Test expanding math"; } }; } // namespace -void TestExpandTanhPass::runOnOperation() { +void TestExpandMathPass::runOnOperation() { RewritePatternSet patterns(&getContext()); + populateExpandCtlzPattern(patterns); populateExpandTanhPattern(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } namespace mlir { namespace test { -void registerTestExpandTanhPass() { PassRegistration(); } +void registerTestExpandMathPass() { PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -76,7 +76,7 @@ void registerTestDiagnosticsPass(); void 
registerTestDominancePass(); void registerTestDynamicPipelinePass(); -void registerTestExpandTanhPass(); +void registerTestExpandMathPass(); void registerTestComposeSubView(); void registerTestMultiBuffering(); void registerTestIRVisitorsPass(); @@ -172,7 +172,7 @@ mlir::test::registerTestDataLayoutQuery(); mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); - mlir::test::registerTestExpandTanhPass(); + mlir::test::registerTestExpandMathPass(); mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering(); mlir::test::registerTestIRVisitorsPass();