diff --git a/mlir/include/mlir/Dialect/Func/CMakeLists.txt b/mlir/include/mlir/Dialect/Func/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Func/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Func/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Func/TransformOps/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS FuncTransformOps.td) +mlir_tablegen(FuncTransformOps.h.inc -gen-op-decls) +mlir_tablegen(FuncTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRFuncTransformOpsIncGen) + +add_mlir_doc(FuncTransformOps FuncTransformOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h @@ -0,0 +1,27 @@ +//===- FuncTransformOps.h - Func transformation ops ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_FUNC_TRANSFORMOPS_FUNCTRANSFORMOPS_H +#define MLIR_DIALECT_FUNC_TRANSFORMOPS_FUNCTRANSFORMOPS_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpImplementation.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h.inc" + +namespace mlir { +class DialectRegistry; + +namespace func { +void registerTransformDialectExtension(DialectRegistry &registry); +} // namespace func +} // namespace mlir + +#endif // MLIR_DIALECT_FUNC_TRANSFORMOPS_FUNCTRANSFORMOPS_H diff --git a/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td @@ -0,0 +1,29 @@ +//===- FuncTransformOps.td - Func transformation ops -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef FUNC_TRANSFORM_OPS +#define FUNC_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Dialect/Transform/IR/TransformTypes.td" +include "mlir/IR/OpBase.td" + +def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect, + "apply_conversion_patterns.func.func_to_llvm", + [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface, + ["verifyTypeConverter"]>]> { + let description = [{ + Collects patterns that convert Func dialect ops to LLVM dialect ops. + These patterns require an "LLVMTypeConverter".
+ }]; + + let assemblyFormat = "attr-dict"; +} + +#endif // FUNC_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -32,13 +32,22 @@ instead of the classic "malloc", "aligned_alloc" and "free" functions. - `use_opaque_pointers`: Generate LLVM IR using opaque pointers instead of typed pointers. + Note: the following two options are not specific to memref lowering and + would be better served by a single, general LLVM type converter; they are + provided here until such a converter exists. + - `use_bare_ptr_call_conv`: Replace FuncOp's MemRef arguments with bare + pointers to the MemRef element types. + - `data_layout`: String description (LLVM format) of the data layout that is + expected on the produced module. }]; let arguments = (ins - DefaultValuedAttr:$use_aligned_alloc, - DefaultValuedAttr:$index_bitwidth, - DefaultValuedAttr:$use_generic_functions, - DefaultValuedAttr:$use_opaque_pointers); + DefaultValuedOptionalAttr:$use_aligned_alloc, + DefaultValuedOptionalAttr:$index_bitwidth, + DefaultValuedOptionalAttr:$use_generic_functions, + DefaultValuedOptionalAttr:$use_opaque_pointers, + DefaultValuedOptionalAttr<BoolAttr, "false">:$use_bare_ptr_call_conv, + OptionalAttr<StrAttr>:$data_layout); let assemblyFormat = "attr-dict"; } diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" #include "mlir/Dialect/IRDL/IR/IRDL.h" @@ -138,6 +139,7 @@ //
Register all dialect extensions. affine::registerTransformDialectExtension(registry); bufferization::registerTransformDialectExtension(registry); + func::registerTransformDialectExtension(registry); gpu::registerTransformDialectExtension(registry); linalg::registerTransformDialectExtension(registry); memref::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Dialect/Func/CMakeLists.txt b/mlir/lib/Dialect/Func/CMakeLists.txt --- a/mlir/lib/Dialect/Func/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Extensions) add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(TransformOps) diff --git a/mlir/lib/Dialect/Func/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Func/TransformOps/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Func/TransformOps/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRFuncTransformOps + FuncTransformOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/TransformOps + + DEPENDS + MLIRFuncTransformOpsIncGen + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRFuncToLLVM + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRTransformDialect +) diff --git a/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp @@ -0,0 +1,66 @@ +//===- FuncTransformOps.cpp - Implementation of Func transform ops -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Apply...ConversionPatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyFuncToLLVMConversionPatternsOp::populatePatterns( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + populateFuncToLLVMConversionPatterns( + static_cast<LLVMTypeConverter &>(typeConverter), patterns); +} + +LogicalResult +transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter( + transform::TypeConverterBuilderOpInterface builder) { + if (builder.getTypeConverterType() != "LLVMTypeConverter") + return emitOpError("expected LLVMTypeConverter"); + return success(); +} + +//===----------------------------------------------------------------------===// +// Transform op registration +//===----------------------------------------------------------------------===// + +namespace { +class FuncTransformDialectExtension + : public transform::TransformDialectExtension< + FuncTransformDialectExtension> { +public: + using Base::Base; + + void init() { + declareGeneratedDialect<LLVM::LLVMDialect>(); + + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" + >(); + } +}; +} // namespace + +#define GET_OP_CLASSES +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc" + +void
mlir::func::registerTransformDialectExtension(DialectRegistry &registry) { + registry.addExtensions<FuncTransformDialectExtension>(); +} diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -44,6 +44,13 @@ if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(getIndexBitwidth()); + // TODO: the following two options are not specific to + // memref_to_llvm_type_converter; they belong in a single, general + // to_llvm_type_converter. + if (std::optional<llvm::StringRef> dataLayoutStr = getDataLayout()) + options.dataLayout = llvm::DataLayout(*dataLayoutStr); + options.useBarePtrCallConv = getUseBarePtrCallConv(); + return std::make_unique<LLVMTypeConverter>(getContext(), options); } diff --git a/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir b/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir --- a/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir +++ b/mlir/test/Conversion/FuncToLLVM/convert-data-layout.mlir @@ -1,6 +1,27 @@ // RUN: mlir-opt -convert-func-to-llvm='use-opaque-pointers=1' %s | FileCheck %s + // RUN-32: mlir-opt -convert-func-to-llvm='data-layout=p:32:32:32 use-opaque-pointers=1' %s | FileCheck %s +// RUN-32: mlir-opt -test-transform-dialect-interpreter %s | FileCheck %s --check-prefix=CHECK-32 + // CHECK: module attributes {llvm.data_layout = ""} // CHECK-32: module attributes {llvm.data_layout ="p:32:32:32"} module {} + +// TODO: right now this cannot apply as we cannot put the transform.sequence +// anywhere that would let us apply a transform to the top-level module. +// Encapsulating in another module does not help as the data layout is not +// currently applied to anything other than the top-level module.
+transform.sequence failures(propagate) { +^bb1(%toplevel_module: !transform.any_op): + transform.apply_conversion_patterns to %toplevel_module { + transform.apply_conversion_patterns.func.func_to_llvm + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + {data_layout = "p:32:32:32", use_opaque_pointers = true} + } { + legal_dialects = ["llvm"], + legal_ops = ["builtin.module"], + partial_conversion + } : !transform.any_op +} diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir --- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir +++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir @@ -1,6 +1,9 @@ // RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm),convert-func-to-llvm{use-opaque-pointers=1},reconcile-unrealized-casts)" %s -split-input-file | FileCheck %s + // RUN: mlir-opt -pass-pipeline="builtin.module(func.func(convert-math-to-llvm,convert-arith-to-llvm{index-bitwidth=32}),convert-func-to-llvm{index-bitwidth=32 use-opaque-pointers=1},reconcile-unrealized-casts)" %s -split-input-file | FileCheck --check-prefix=CHECK32 %s +// RUN: mlir-opt -test-transform-dialect-interpreter %s | FileCheck --check-prefix=CHECK32 %s + // CHECK-LABEL: func @empty() { // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -447,8 +450,6 @@ cf.br ^bb1 } -// ----- - // CHECK-LABEL: func @ceilf( // CHECK-SAME: f32 func.func @ceilf(%arg0 : f32) { @@ -457,8 +458,6 @@ func.return } -// ----- - // CHECK-LABEL: func @floorf( // CHECK-SAME: f32 func.func @floorf(%arg0 : f32) { @@ -467,8 +466,6 @@ func.return } -// ----- - // Lowers `cf.assert` to a function call to `abort` if the assertion is violated. // CHECK: llvm.func @abort() // CHECK-LABEL: @assert_test_function @@ -484,8 +481,6 @@ return } -// ----- - // This should not trigger an assertion by creating an LLVM::CallOp with a // nullptr result type.
@@ -497,8 +492,6 @@ } func.func private @zero_result_func() -// ----- - // CHECK-LABEL: func @fmaf( // CHECK-SAME: %[[ARG0:.*]]: f32 // CHECK-SAME: %[[ARG1:.*]]: vector<4xf32> @@ -510,8 +503,6 @@ func.return } -// ----- - // CHECK-LABEL: func @switchi8( func.func @switchi8(%arg0 : i8) -> i32 { cf.switch %arg0 : i8, [ @@ -537,3 +528,21 @@ // CHECK-NEXT: %[[E1:.+]] = llvm.mlir.constant(42 : i32) : i32 // CHECK-NEXT: llvm.return %[[E1]] : i32 // CHECK-NEXT: } + +transform.sequence failures(propagate) { +^bb1(%toplevel_module: !transform.any_op): + %module = transform.structured.match ops{["func.func"]} in %toplevel_module + : (!transform.any_op) -> !transform.any_op + transform.apply_conversion_patterns to %module { + transform.apply_conversion_patterns.dialect_to_llvm "math" + transform.apply_conversion_patterns.dialect_to_llvm "arith" + transform.apply_conversion_patterns.dialect_to_llvm "cf" + transform.apply_conversion_patterns.func.func_to_llvm + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + {index_bitwidth = 32, use_opaque_pointers = true} + } { + legal_dialects = ["llvm"], + partial_conversion + } : !transform.any_op +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4050,6 +4050,58 @@ ], ) +td_library( + name = "FuncTransformOpsTdFiles", + srcs = [ + "include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td", + ], + includes = ["include"], + deps = [ + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "FuncTransformOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-decls"], + "include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h.inc", + ), + ( + ["-gen-op-defs"], + "include/mlir/Dialect/Func/TransformOps/FuncTransformOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td", + deps = [ + ":FuncTransformOpsTdFiles", + ], +) + +cc_library( + name = "FuncTransformOps", + srcs = [ + "lib/Dialect/Func/TransformOps/FuncTransformOps.cpp", + ], + hdrs = [ + "include/mlir/Dialect/Func/TransformOps/FuncTransformOps.h", + ], + includes = ["include"], + deps = [ + ":BytecodeOpInterface", + ":FuncDialect", + ":FuncToLLVM", + ":FuncTransformOpsIncGen", + ":IR", + ":LLVMCommonConversion", + ":LLVMDialect", + ":TransformDialect", + ], +) + cc_library( name = "AllExtensions", hdrs = ["include/mlir/InitAllExtensions.h"],