diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -20,6 +20,7 @@ MLIRBufferizationTransforms MLIRFuncDialect MLIRFunctionInterfaces + MLIRIndexDialect MLIRIR MLIRLinalgDialect MLIRLinalgTransforms diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" @@ -36,6 +37,7 @@ declareGeneratedDialect(); declareGeneratedDialect(); + declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir --- a/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir +++ b/mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir @@ -1,6 +1,13 @@ -// RUN: mlir-opt %s -o - -test-lower-to-llvm -cse -split-input-file | FileCheck %s // Note: We run CSE here to make the pattern matching more direct. +// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s + +// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \ +// RUN: -test-transform-dialect-erase-schedule -cse \ +// RUN: | FileCheck %s + +module attributes {transform.target_tag="payload"} { + // Check that we properly lower to llvm memref operations that require to be // expanded first, like `memref.subview`. func.func @subview(%0 : memref<64x4xf32, strided<[4, 1], offset: 0>>, %arg0 : index, %arg1 : index, %arg2 : index) @@ -43,3 +50,15 @@ to memref> return %1 : memref> } + +} // transform payload + +module @named_inclusion_in_named attributes { transform.with_named_sequence } { + transform.named_sequence private @lower_to_cpu(!transform.any_op {transform.consumed}) -> !transform.any_op + + transform.sequence failures(propagate) { + ^bb1(%toplevel_module: !transform.any_op): + %m2 = transform.include @lower_to_cpu failures(suppress) (%toplevel_module) + : (!transform.any_op) -> (!transform.any_op) + } +} diff --git a/mlir/test/Dialect/LLVM/lower-to-llvm-transform-symbol-def.mlir b/mlir/test/Dialect/LLVM/lower-to-llvm-transform-symbol-def.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVM/lower-to-llvm-transform-symbol-def.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt %s + +/// Schedule to lower to LLVM. +module @lower_module_to_cpu attributes { transform.with_named_sequence } { + +transform.named_sequence @lower_to_cpu( + %module: !transform.any_op {transform.consumed}) -> !transform.any_op { + + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op + %f = transform.apply_registered_pass "convert-vector-to-scf" to %func : (!transform.any_op) -> !transform.any_op + %f2 = transform.apply_registered_pass "convert-linalg-to-loops" to %f : (!transform.any_op) -> !transform.any_op + %f3 = transform.apply_registered_pass "convert-scf-to-cf" to %f2 : (!transform.any_op) -> !transform.any_op + %f4 = transform.apply_registered_pass "expand-strided-metadata" to %f3 : (!transform.any_op) -> !transform.any_op + %f5 = transform.apply_registered_pass "lower-affine" to %f4 : (!transform.any_op) -> !transform.any_op + + transform.apply_conversion_patterns to %f5 { + transform.apply_conversion_patterns.dialect_to_llvm "math" + transform.apply_conversion_patterns.vector.vector_to_llvm + transform.apply_conversion_patterns.dialect_to_llvm "memref" + transform.apply_conversion_patterns.func.func_to_llvm + transform.apply_conversion_patterns.dialect_to_llvm "index" + transform.apply_conversion_patterns.dialect_to_llvm "arith" + transform.apply_conversion_patterns.dialect_to_llvm "cf" + } with type_converter { + transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter + {index_bitwidth = 64, + use_bare_ptr = false, + use_bare_ptr_memref_call_conv = false, + use_opaque_pointers = true} + } { + legal_dialects = ["llvm"], + partial_conversion + } : !transform.any_op + + %m2 = transform.apply_registered_pass "reconcile-unrealized-casts" to %module : (!transform.any_op) -> !transform.any_op + transform.yield %m2 : !transform.any_op +} + +} // transform module diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9816,6 +9816,7 @@ ":FuncDialect", ":FunctionInterfaces", ":GPUDialect", + ":IndexDialect", ":IR", ":LinalgDialect", ":LinalgMatchOpsIncGen", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel @@ -24,6 +24,7 @@ include = ["**/*.mlir"], exclude = [ "IRDL/*.irdl.mlir", + "LLVM/*-symbol-def.mlir", "Transform/*-source.mlir", "Transform/*-symbol-def.mlir", "Transform/*-symbol-decl-and-schedule.mlir",