diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -618,6 +618,7 @@ "arith::ArithDialect", "cf::ControlFlowDialect", "func::FuncDialect", + "scf::SCFDialect", "vector::VectorDialect", "LLVM::LLVMDialect", ]; @@ -637,10 +638,10 @@ Pass<"finalize-memref-to-llvm", "ModuleOp"> { let summary = "Finalize MemRef dialect to LLVM dialect conversion"; let description = [{ - Finalize the conversion of the operations from the MemRef + Finalize the conversion of the operations from the MemRef dialect to the LLVM dialect. - This conversion will not convert some complex MemRef - operations. Make sure to run `expand-strided-metadata` + This conversion will not convert some complex MemRef + operations. Make sure to run `expand-strided-metadata` beforehand for these. }]; let dependentDialects = ["LLVM::LLVMDialect"]; diff --git a/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt --- a/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToFuncs/CMakeLists.txt @@ -17,6 +17,7 @@ MLIRLLVMDialect MLIRMathDialect MLIRPass + MLIRSCFDialect MLIRTransforms MLIRVectorDialect MLIRVectorUtils diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" @@ -41,14 +42,14 @@ }; // Callback type for getting pre-generated FuncOp implementing -// a power operation of the given type. 
-using GetPowerFuncCallbackTy = function_ref; +// an operation of the given type. +using GetFuncCallbackTy = function_ref; // Pattern to convert scalar IPowIOp into a call of outlined // software implementation. class IPowIOpLowering : public OpRewritePattern { public: - IPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb) + IPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb) : OpRewritePattern(context), getFuncOpCallback(cb) {} /// Convert IPowI into a call to a local function implementing @@ -58,14 +59,14 @@ PatternRewriter &rewriter) const final; private: - GetPowerFuncCallbackTy getFuncOpCallback; + GetFuncCallbackTy getFuncOpCallback; }; // Pattern to convert scalar FPowIOp into a call of outlined // software implementation. class FPowIOpLowering : public OpRewritePattern { public: - FPowIOpLowering(MLIRContext *context, GetPowerFuncCallbackTy cb) + FPowIOpLowering(MLIRContext *context, GetFuncCallbackTy cb) : OpRewritePattern(context), getFuncOpCallback(cb) {} /// Convert FPowI into a call to a local function implementing @@ -75,7 +76,24 @@ PatternRewriter &rewriter) const final; private: - GetPowerFuncCallbackTy getFuncOpCallback; + GetFuncCallbackTy getFuncOpCallback; +}; + +// Pattern to convert scalar ctlz into a call of outlined software +// implementation. +class CtlzOpLowering : public OpRewritePattern { +public: + CtlzOpLowering(MLIRContext *context, GetFuncCallbackTy cb) + : OpRewritePattern(context), + getFuncOpCallback(cb) {} + + /// Convert ctlz into a call to a local function implementing + /// the count leading zeros operation. + LogicalResult matchAndRewrite(math::CountLeadingZerosOp op, + PatternRewriter &rewriter) const final; + +private: + GetFuncCallbackTy getFuncOpCallback; }; } // namespace @@ -346,7 +364,7 @@ // The outlined software implementation must have been already // generated. 
- func::FuncOp elementFunc = getFuncOpCallback(baseType); + func::FuncOp elementFunc = getFuncOpCallback(op, baseType); if (!elementFunc) return rewriter.notifyMatchFailure(op, "missing software implementation"); @@ -571,7 +589,7 @@ // The outlined software implementation must have been already // generated. - func::FuncOp elementFunc = getFuncOpCallback(funcType); + func::FuncOp elementFunc = getFuncOpCallback(op, funcType); if (!elementFunc) return rewriter.notifyMatchFailure(op, "missing software implementation"); @@ -579,6 +597,146 @@ return success(); } +/// Create function to implement the ctlz function for the given \p elementType type +/// inside \p module. The \p elementType must be IntegerType, and the created +/// function has 'IntegerType (*)(IntegerType)' function type. +/// +/// template <typename T> +/// T __mlir_math_ctlz_*(T x) { +/// uint32_t n = 0; +/// bits = sizeof(x) * 8; +/// for (int i = 1; i < bits; ++i) { +/// if (x < 0) break; +/// n++; +/// x <<= 1; +/// } +/// return n; +/// } +/// +/// Converts to (for i32): +/// +/// func.func private @__mlir_math_ctlz_i32(%arg: i32) -> i32 { +/// %c_1index = arith.constant 1 : index +/// %c_1i32 = arith.constant 1 : i32 +/// %c_32 = arith.constant 32 : index +/// %c_0 = arith.constant 0 : i32 +/// %n = arith.constant 0 : i32 +/// %for_ret:2 = scf.for %i = %c_1index to %c_32 step %c_1index +/// iter_args(%arg_iter = %arg, %n_iter = %n) -> (i32, i32) { +/// %cond = arith.cmpi slt, %arg_iter, %c_0 : i32 +/// %if_ret:2 = scf.if %cond -> (i32, i32) { +/// scf.yield %arg_iter, %n_iter : i32, i32 +/// } else { +/// %n_next = arith.addi %n_iter, %c_1i32 : i32 +/// %arg_next = arith.shli %arg_iter, %c_1i32 : i32 +/// scf.yield %arg_next, %n_next : i32, i32 +/// } +/// scf.yield %if_ret#0, %if_ret#1 : i32, i32 +/// } +/// return %for_ret#1 : i32 +/// } +static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { + if (!elementType.isa()) { + fprintf(stderr, "non-integer element type for CtlzFunc; type was: "); + elementType.dump(); + 
assert(elementType.isa() && "non-integer element type"); + } + int64_t bitWidth = elementType.getIntOrFloatBitWidth(); + + ImplicitLocOpBuilder builder = + ImplicitLocOpBuilder::atBlockEnd(module->getLoc(), module->getBody()); + + std::string funcName("__mlir_math_ctlz"); + llvm::raw_string_ostream nameOS(funcName); + nameOS << '_' << elementType; + FunctionType funcType = + FunctionType::get(builder.getContext(), {elementType}, elementType); + auto funcOp = builder.create(funcName, funcType); + + // LinkonceODR ensures that there is only one implementation of this function + // across all math.ctlz functions that are lowered in this way. + LLVM::linkage::Linkage inlineLinkage = LLVM::linkage::Linkage::LinkonceODR; + Attribute linkage = + LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage); + funcOp->setAttr("llvm.linkage", linkage); + funcOp.setPrivate(); + + // Set the insertion point to the start of the function. + Block *funcBody = funcOp.addEntryBlock(); + builder.setInsertionPointToStart(funcBody); + + Value arg = funcOp.getArgument(0); + Type indexType = builder.getIndexType(); + Value oneIndex = + builder.create(indexType, builder.getIndexAttr(1)); + Value oneValue = builder.create( + elementType, builder.getIntegerAttr(elementType, 1)); + Value bitWidthIndex = builder.create( + indexType, builder.getIndexAttr(bitWidth)); + Value zeroValue = builder.create( + elementType, builder.getIntegerAttr(elementType, 0)); + Value nValue = builder.create( + elementType, builder.getIntegerAttr(elementType, 0)); + + auto loop = builder.create( + oneIndex, bitWidthIndex, oneIndex, + // Initial values for the two loop-carried values (iter_args): the arg + // which is shifted left in each iteration, and the n value which tracks + // the count of leading zeros. 
+ ValueRange{arg, nValue}, + // Callback to build the body of the for loop + // if (arg < 0) { + // break; + // } else { + // n++; + // arg <<= 1; + // } + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + Value argIter = args[0]; + Value nIter = args[1]; + + Value argIsNonNegative = b.create( + loc, arith::CmpIPredicate::slt, argIter, zeroValue); + scf::IfOp ifOp = b.create( + loc, argIsNonNegative, + [&](OpBuilder &b, Location loc) { + // If arg is negative, break out of the loop. + b.create(loc, ValueRange{argIter, nIter}); + }, + [&](OpBuilder &b, Location loc) { + // Otherwise, increment n and shift arg left. + Value nNext = b.create(loc, nIter, oneValue); + Value argNext = b.create(loc, argIter, oneValue); + b.create(loc, ValueRange{argNext, nNext}); + }); + b.create(loc, ifOp.getResults()); + }); + + Value nOutValue = loop.getResult(1); + builder.create(nOutValue); + return funcOp; +} + +/// Convert ctlz into a call to a local function implementing the ctlz +/// operation. +LogicalResult CtlzOpLowering::matchAndRewrite(math::CountLeadingZerosOp op, + PatternRewriter &rewriter) const { + if (op.getType().template dyn_cast()) + return rewriter.notifyMatchFailure(op, "non-scalar operation"); + + Type type = getElementTypeOrSelf(op.getResult().getType()); + func::FuncOp elementFunc = getFuncOpCallback(op, type); + if (!elementFunc) { + return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) { + diag << "Missing software implementation for op " << op->getName() + << " and type " << type; + }); + } + + rewriter.replaceOpWithNewOp(op, elementFunc, op.getOperand()); + return success(); +} + namespace { struct ConvertMathToFuncsPass : public impl::ConvertMathToFuncsBase { @@ -595,13 +753,13 @@ bool isFPowIConvertible(math::FPowIOp op); // Generate outlined implementations for power operations - // and store them in powerFuncs map. - void preprocessPowOperations(); + // and store them in funcImpls map. 
+ void generateOpImplementations(); - // A map between function types deduced from power operations - // and the corresponding outlined software implementations - // of these operations. - DenseMap powerFuncs; + // A map between pairs of (operation, type) deduced from operations that this + // pass will convert, and the corresponding outlined software implementations + // of these operations for the given type. + DenseMap, func::FuncOp> funcImpls; }; } // namespace @@ -611,17 +769,28 @@ return (expTy && expTy.getWidth() >= minWidthOfFPowIExponent); } -void ConvertMathToFuncsPass::preprocessPowOperations() { +void ConvertMathToFuncsPass::generateOpImplementations() { ModuleOp module = getOperation(); module.walk([&](Operation *op) { TypeSwitch(op) + .Case([&](math::CountLeadingZerosOp op) { + Type resultType = getElementTypeOrSelf(op.getResult().getType()); + + // Generate the software implementation of this operation, + // if it has not been generated yet. + auto key = std::pair(op->getName(), resultType); + auto entry = funcImpls.try_emplace(key, func::FuncOp{}); + if (entry.second) + entry.first->second = createCtlzFunc(&module, resultType); + }) .Case([&](math::IPowIOp op) { Type resultType = getElementTypeOrSelf(op.getResult().getType()); // Generate the software implementation of this operation, // if it has not been generated yet. - auto entry = powerFuncs.try_emplace(resultType, func::FuncOp{}); + auto key = std::pair(op->getName(), resultType); + auto entry = funcImpls.try_emplace(key, func::FuncOp{}); if (entry.second) entry.first->second = createElementIPowIFunc(&module, resultType); }) @@ -635,7 +804,8 @@ // if it has not been generated yet. // FPowI implementations are mapped via the FunctionType // created from the operation's result and operands. 
- auto entry = powerFuncs.try_emplace(funcType, func::FuncOp{}); + auto key = std::pair(op->getName(), funcType); + auto entry = funcImpls.try_emplace(key, func::FuncOp{}); if (entry.second) entry.first->second = createElementFPowIFunc(&module, funcType); }); @@ -646,27 +816,31 @@ ModuleOp module = getOperation(); // Create outlined implementations for power operations. - preprocessPowOperations(); + generateOpImplementations(); RewritePatternSet patterns(&getContext()); - patterns.add, VecOpToScalarOp>( + patterns.add, VecOpToScalarOp, + VecOpToScalarOp>( patterns.getContext()); - // For the given Type Returns FuncOp stored in powerFuncs map. - auto getPowerFuncOpByType = [&](Type type) -> func::FuncOp { - auto it = powerFuncs.find(type); - if (it == powerFuncs.end()) + // For the given Type Returns FuncOp stored in funcImpls map. + auto getFuncOpByType = [&](Operation *op, Type type) -> func::FuncOp { + auto it = funcImpls.find(std::pair(op->getName(), type)); + if (it == funcImpls.end()) return {}; return it->second; }; - patterns.add(patterns.getContext(), - getPowerFuncOpByType); + patterns.add( + patterns.getContext(), getFuncOpByType); ConversionTarget target(getContext()); target.addLegalDialect(); + func::FuncDialect, scf::SCFDialect, + vector::VectorDialect>(); + target.addIllegalOp(); + target.addIllegalOp(); target.addDynamicallyLegalOp( [this](math::FPowIOp op) { return !isFPowIConvertible(op); }); if (failed(applyPartialConversion(module, target, std::move(patterns)))) diff --git a/mlir/test/Conversion/MathToFuncs/ctlz.mlir b/mlir/test/Conversion/MathToFuncs/ctlz.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/MathToFuncs/ctlz.mlir @@ -0,0 +1,76 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(convert-math-to-funcs)" | FileCheck %s + +// Check a golden-path i32 conversion + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[VAL_0:.*]]: i32 +// CHECK-SAME: ) { +// CHECK: %[[VAL_1:.*]] = call 
@__mlir_math_ctlz_i32(%[[VAL_0]]) : (i32) -> i32 +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func.func private @__mlir_math_ctlz_i32( +// CHECK-SAME: %[[ARG:.*]]: i32 +// CHECK-SAME: ) -> i32 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[C_1INDEX:.*]] = arith.constant 1 : index +// CHECK: %[[C_1I32:.*]] = arith.constant 1 : i32 +// CHECK: %[[C_32:.*]] = arith.constant 32 : index +// CHECK: %[[C_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[N:.*]] = arith.constant 0 : i32 +// CHECK: %[[FOR_RET:.*]]:2 = scf.for %[[I:.*]] = %[[C_1INDEX]] to %[[C_32]] step %[[C_1INDEX]] +// CHECK: iter_args(%[[ARG_ITER:.*]] = %[[ARG]], %[[N_ITER:.*]] = %[[N]]) -> (i32, i32) { +// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[ARG_ITER]], %[[C_0]] : i32 +// CHECK: %[[IF_RET:.*]]:2 = scf.if %[[COND]] -> (i32, i32) { +// CHECK: scf.yield %[[ARG_ITER]], %[[N_ITER]] : i32, i32 +// CHECK: } else { +// CHECK: %[[N_NEXT:.*]] = arith.addi %[[N_ITER]], %[[C_1I32]] : i32 +// CHECK: %[[ARG_NEXT:.*]] = arith.shli %[[ARG_ITER]], %[[C_1I32]] : i32 +// CHECK: scf.yield %[[ARG_NEXT]], %[[N_NEXT]] : i32, i32 +// CHECK: } +// CHECK: scf.yield %[[IF_RET]]#0, %[[IF_RET]]#1 : i32, i32 +// CHECK: } +// CHECK: return %[[FOR_RET]]#1 : i32 +// CHECK: } +func.func @main(%arg0: i32) { + %0 = math.ctlz %arg0 : i32 + func.return +} + +// ----- + +// Check that i8 input is preserved + +// CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[VAL_0:.*]]: i8 +// CHECK-SAME: ) { +// CHECK: %[[VAL_1:.*]] = call @__mlir_math_ctlz_i8(%[[VAL_0]]) : (i8) -> i8 +// CHECK: return +// CHECK: } + +// CHECK-LABEL: func.func private @__mlir_math_ctlz_i8( +// CHECK-SAME: %[[ARG:.*]]: i8 +// CHECK-SAME: ) -> i8 attributes {llvm.linkage = #llvm.linkage} { +// CHECK: %[[C_1INDEX:.*]] = arith.constant 1 : index +// CHECK: %[[C_1I8:.*]] = arith.constant 1 : i8 +// CHECK: %[[C_8:.*]] = arith.constant 8 : index +// CHECK: %[[C_0:.*]] = arith.constant 0 : i8 +// CHECK: %[[N:.*]] = arith.constant 0 : i8 +// CHECK: %[[FOR_RET:.*]]:2 
= scf.for %[[I:.*]] = %[[C_1INDEX]] to %[[C_8]] step %[[C_1INDEX]] +// CHECK: iter_args(%[[ARG_ITER:.*]] = %[[ARG]], %[[N_ITER:.*]] = %[[N]]) -> (i8, i8) { +// CHECK: %[[COND:.*]] = arith.cmpi slt, %[[ARG_ITER]], %[[C_0]] : i8 +// CHECK: %[[IF_RET:.*]]:2 = scf.if %[[COND]] -> (i8, i8) { +// CHECK: scf.yield %[[ARG_ITER]], %[[N_ITER]] : i8, i8 +// CHECK: } else { +// CHECK: %[[N_NEXT:.*]] = arith.addi %[[N_ITER]], %[[C_1I8]] : i8 +// CHECK: %[[ARG_NEXT:.*]] = arith.shli %[[ARG_ITER]], %[[C_1I8]] : i8 +// CHECK: scf.yield %[[ARG_NEXT]], %[[N_NEXT]] : i8, i8 +// CHECK: } +// CHECK: scf.yield %[[IF_RET]]#0, %[[IF_RET]]#1 : i8, i8 +// CHECK: } +// CHECK: return %[[FOR_RET]]#1 : i8 +// CHECK: } +func.func @main(%arg0: i8) { + %0 = math.ctlz %arg0 : i8 + func.return +} + diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6308,6 +6308,7 @@ ":LLVMDialect", ":MathDialect", ":Pass", + ":SCFDialect", ":Transforms", ":VectorDialect", ":VectorUtils",