diff --git a/mlir/include/mlir/Dialect/Arith/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Arith/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Arith/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.h b/mlir/include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.h
@@ -0,0 +1,29 @@
+//===- ArithTransformOps.h - Arith transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMOPS_ARITHTRANSFORMOPS_H
+#define MLIR_DIALECT_ARITH_TRANSFORMOPS_ARITHTRANSFORMOPS_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Arith/TransformOps/ArithTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+/// Adds to `registry` the transform dialect extension that registers the
+/// transform ops generated from ArithTransformOps.td.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMOPS_ARITHTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.td b/mlir/include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.td
@@ -0,0 +1,29 @@
+//===- ArithTransformOps.td - Arith transformation ops ----*- tablegen -*--===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARITH_TRANSFORM_OPS
+#define ARITH_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/IR/OpBase.td"
+
+def ApplyArithToLLVMConversionPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.arith.arith_to_llvm",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface,
+                               ["verifyTypeConverter"]>]> {
+  let description = [{
+    Collects patterns that convert arith dialect ops to LLVM dialect ops. These
+    patterns require an "LLVMTypeConverter".
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+#endif // ARITH_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Arith/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS ArithTransformOps.td)
+mlir_tablegen(ArithTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(ArithTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRArithTransformOpsIncGen)
+
+add_mlir_doc(ArithTransformOps ArithTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/TransformOps/ArithTransformOps.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -135,6 +136,7 @@
   // clang-format on
 
   // Register all dialect extensions.
+  arith::registerTransformDialectExtension(registry);
   affine::registerTransformDialectExtension(registry);
   bufferization::registerTransformDialectExtension(registry);
   gpu::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/Arith/CMakeLists.txt b/mlir/lib/Dialect/Arith/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
+add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Arith/TransformOps/ArithTransformOps.cpp b/mlir/lib/Dialect/Arith/TransformOps/ArithTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/TransformOps/ArithTransformOps.cpp
@@ -0,0 +1,74 @@
+//===- ArithTransformOps.cpp - Implementation of Arith transform ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/TransformOps/ArithTransformOps.h"
+
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Apply...ConversionPatternsOp
+//===----------------------------------------------------------------------===//
+
+/// Adds the arith-to-LLVM conversion patterns to `patterns`. The downcast of
+/// `typeConverter` is unchecked, so it is only valid for type converter
+/// builders accepted by verifyTypeConverter below.
+void transform::ApplyArithToLLVMConversionPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  arith::populateArithToLLVMConversionPatterns(
+      static_cast<LLVMTypeConverter &>(typeConverter), patterns);
+}
+
+/// Emits an op error unless `builder` reports "LLVMTypeConverter" as its type
+/// converter type.
+LogicalResult
+transform::ApplyArithToLLVMConversionPatternsOp::verifyTypeConverter(
+    transform::TypeConverterBuilderOpInterface builder) {
+  if (builder.getTypeConverterType() != "LLVMTypeConverter")
+    return emitOpError("expected LLVMTypeConverter");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that registers the ops generated from
+/// ArithTransformOps.td and declares the LLVM dialect as generated by them.
+class ArithTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          ArithTransformDialectExtension> {
+public:
+  using Base::Base;
+
+  void init() {
+    declareGeneratedDialect<LLVM::LLVMDialect>();
+
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/Arith/TransformOps/ArithTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Arith/TransformOps/ArithTransformOps.cpp.inc"
+
+/// Adds ArithTransformDialectExtension to `registry`.
+void mlir::arith::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<ArithTransformDialectExtension>();
+}
diff --git a/mlir/lib/Dialect/Arith/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Arith/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRArithTransformOps
+  ArithTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/TransformOps
+
+  DEPENDS
+  MLIRArithTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRArithToLLVM
+  MLIRIR
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRTransformDialect
+)
diff --git a/mlir/test/Dialect/Arith/transform-op-arith-to-llvm.mlir b/mlir/test/Dialect/Arith/transform-op-arith-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/transform-op-arith-to-llvm.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @lower_to_llvm
+// CHECK-NOT: arith.fptosi
+// CHECK: llvm.fptosi
+func.func @lower_to_llvm(%arg0: f32) -> i32 {
+  %0 = arith.fptosi %arg0: f32 to i32
+  return %0 : i32
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_conversion_patterns to %0 {
+    transform.apply_conversion_patterns.arith.arith_to_llvm
+  } with type_converter {
+    transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+  } {legal_dialects = ["func", "llvm"]} : !transform.any_op
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7905,6 +7905,7 @@
         ":ArithDialect",
         ":ArithToLLVM",
         ":ArithToSPIRV",
+        ":ArithTransformOps",
         ":ArithTransforms",
         ":ArithValueBoundsOpInterfaceImpl",
         ":ArmNeonDialect",
@@ -10975,6 +10976,57 @@
     ],
 )
 
+td_library(
+    name = "ArithTransformOpsTdFiles",
+    srcs = [
+        "include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.td",
+    ],
+    includes = ["include"],
+    deps = [
+        ":TransformDialectTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "ArithTransformOpsIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.td",
+    deps = [
+        ":ArithTransformOpsTdFiles",
+    ],
+)
+
+cc_library(
+    name = "ArithTransformOps",
+    srcs = [
+        "lib/Dialect/Arith/TransformOps/ArithTransformOps.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Arith/TransformOps/ArithTransformOps.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":ArithDialect",
+        ":ArithToLLVM",
+        ":ArithTransformOpsIncGen",
+        ":IR",
+        ":LLVMCommonConversion",
+        ":LLVMDialect",
+        ":TransformDialect",
+    ],
+)
+
 cc_library(
     name = "ArithUtils",
     srcs = glob([