diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --test-transform-dialect-erase-schedule --test-lower-to-llvm --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @matmul_tensors
+func.func @matmul_tensors(
+  %arg0: tensor<2x4xf32>, %arg1: tensor<4x6xf32>, %arg2: tensor<2x6xf32>)
+    -> tensor<2x6xf32> {
+// CHECK-NOT: linalg
+// CHECK: llvm.intr.fmuladd{{.*}}
+  %0 = linalg.matmul  ins(%arg0, %arg1: tensor<2x4xf32>, tensor<4x6xf32>)
+                     outs(%arg2: tensor<2x6xf32>)
+    -> tensor<2x6xf32>
+  return %0 : tensor<2x6xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %module_op
+  %1, %loops:3 = transform.structured.tile %0 [2, 2, 2]
+  %2 = get_closest_isolated_parent %1 : (!pdl.operation) -> !pdl.operation
+  transform.structured.vectorize %2
+  transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %module_op
+    {bufferize_function_boundaries = true}
+  %func = transform.structured.match ops{["func.func"]} in %module_op
+  transform.vector.lower_vectors %func {multireduction_lowering = "innerreduce"}
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -4,6 +4,7 @@
 add_subdirectory(Func)
 add_subdirectory(GPU)
 add_subdirectory(Linalg)
+add_subdirectory(LLVM)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
 add_subdirectory(NVGPU)
diff --git a/mlir/test/lib/Dialect/LLVM/CMakeLists.txt b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVM/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRLLVMTestPasses
+  TestLowerToLLVM.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRAffineToStandard
+  MLIRFuncDialect
+  MLIRFuncToLLVM
+  MLIRIndexToLLVM
+  MLIRIR
+  MLIRLinalgToLLVM
+  MLIRLinalgTransforms
+  MLIRLLVMDialect
+  MLIRMathToLLVM
+  MLIRMemRefToLLVM
+  MLIRPass
+  MLIRReconcileUnrealizedCasts
+  MLIRSCFToControlFlow
+  MLIRTransforms
+  MLIRVectorToLLVM
+  MLIRVectorToSCF
+  )
diff --git a/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp b/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp
@@ -0,0 +1,106 @@
+//===- TestLowerToLLVM.cpp - Test lowering to LLVM as a sink pass ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing the lowering to LLVM as a generally
+// usable sink pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+/// Lowers a module to the LLVM dialect by running a fixed sequence of
+/// conversion passes; intended as a generally usable sink pass in tests.
+struct TestLowerToLLVM
+    : public PassWrapper<TestLowerToLLVM, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLowerToLLVM)
+
+  TestLowerToLLVM() = default;
+  TestLowerToLLVM(const TestLowerToLLVM &pass) : PassWrapper(pass) {}
+
+  StringRef getArgument() const final { return "test-lower-to-llvm"; }
+  StringRef getDescription() const final {
+    return "Test lowering to LLVM as a generally usable sink pass";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect>();
+  }
+
+  Option<bool> reassociateFPReductions{
+      *this, "reassociate-fp-reductions",
+      llvm::cl::desc("Allow reassociation of FP reductions"),
+      llvm::cl::init(false)};
+
+  void runOnOperation() final;
+};
+} // namespace
+
+// Runs the lowering pipeline; on failure, dumps the IR and fails the pass.
+void TestLowerToLLVM::runOnOperation() {
+  // TODO: it is feasible to scope lowering at arbitrary level and introduce
+  // unrealized casts, but there needs to be the final module-wise cleanup in
+  // the end. Keep module-level for now.
+  PassManager pm(&getContext());
+
+  // Blanket-convert any remaining high-level vector ops to loops if any remain.
+  pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
+  // Blanket-convert any remaining linalg ops to loops if any remain.
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  // Blanket-convert any remaining affine ops if any remain.
+  pm.addPass(createLowerAffinePass());
+  // Convert SCF to CF (always needed).
+  pm.addPass(createConvertSCFToCFPass());
+  // Sprinkle some cleanups.
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+  // Blanket-convert any remaining linalg ops to LLVM if any remain.
+  pm.addPass(createConvertLinalgToLLVMPass());
+  // Convert vector to LLVM (always needed).
+  pm.addPass(createConvertVectorToLLVMPass(
+      // TODO: add more options on a per-need basis.
+      LowerVectorToLLVMOptions().enableReassociateFPReductions(
+          reassociateFPReductions)));
+  // Convert Math to LLVM (always needed).
+  pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
+  // Convert MemRef to LLVM (always needed).
+  pm.addPass(createMemRefToLLVMConversionPass());
+  // Convert Func to LLVM (always needed).
+  pm.addPass(createConvertFuncToLLVMPass());
+  // Convert Index to LLVM (always needed).
+  pm.addPass(createConvertIndexToLLVMPass());
+  // Convert remaining unrealized_casts (always needed).
+  pm.addPass(createReconcileUnrealizedCastsPass());
+  if (failed(pm.run(getOperation()))) {
+    getOperation()->dump();
+    return signalPassFailure();
+  }
+}
+
+namespace mlir {
+namespace test {
+void registerTestLowerToLLVM() { PassRegistration<TestLowerToLLVM>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -38,6 +38,7 @@
     MLIRTestTransforms
     MLIRTilingInterfaceTestPasses
     MLIRVectorTestPasses
+    MLIRLLVMTestPasses
     )
 endif()
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -98,6 +98,7 @@
 void registerTestLoopFusion();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
+void registerTestLowerToLLVM();
 void registerTestMatchReductionPass();
 void registerTestMathAlgebraicSimplificationPass();
 void registerTestMathPolynomialApproximationPass();
@@ -201,6 +202,7 @@
   mlir::test::registerTestLoopFusion();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
+  mlir::test::registerTestLowerToLLVM();
   mlir::test::registerTestMatchReductionPass();
   mlir::test::registerTestMathAlgebraicSimplificationPass();
   mlir::test::registerTestMathPolynomialApproximationPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6873,6 +6873,7 @@
         "//mlir/test:TestGPU",
         "//mlir/test:TestIR",
         "//mlir/test:TestLinalg",
+        "//mlir/test:TestLLVM",
         "//mlir/test:TestMath",
         "//mlir/test:TestMemRef",
         "//mlir/test:TestNVGPU",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -583,6 +583,29 @@
     ],
 )
 
+cc_library(
+    name = "TestLLVM",
+    srcs = glob(["lib/Dialect/LLVM/*.cpp"]),
+    deps = [
+        "//mlir:AffineToStandard",
+        "//mlir:FuncDialect",
+        "//mlir:FuncToLLVM",
+        "//mlir:IndexToLLVM",
+        "//mlir:IR",
+        "//mlir:LinalgToLLVM",
+        "//mlir:LinalgTransforms",
+        "//mlir:LLVMDialect",
+        "//mlir:MathToLLVM",
+        "//mlir:MemRefToLLVM",
+        "//mlir:Pass",
+        "//mlir:ReconcileUnrealizedCasts",
+        "//mlir:SCFToControlFlow",
+        "//mlir:Transforms",
+        "//mlir:VectorToLLVM",
+        "//mlir:VectorToSCF",
+    ],
+)
+
 cc_library(
     name = "TestMath",
     srcs = glob(["lib/Dialect/Math/*.cpp"]),