# Review notes for this patch. This text precedes the first "diff --git"
# header, so patch tools treat it as leading text and skip it.
#
# Summary of the change:
#   * mlir/python/CMakeLists.txt: adds dialects/_memref_transform_ops_ext.py
#     to SOURCES of the memref_transform Python extension bindings
#     (TD_FILE dialects/MemRefTransformOps.td, DIALECT_NAME transform).
#   * utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel: adds the same
#     file, mlir/dialects/_memref_transform_ops_ext.py, to the file list that
#     already holds the other *_transform_ops_ext.py files (kept alphabetical).
#   * mlir/test/python/dialects/transform_memref_ext.py (new file, 67 lines):
#     - It is a lit test: "RUN: %PYTHON %s | FileCheck %s".
#     - The @run helper runs each test body inside a fresh Context/Module and
#       prints "TEST: <name>", which the CHECK-LABEL lines anchor on. It then
#       prints the module.
#     - Each test builds a transform.SequenceOp whose target is typed
#       !transform.op<"memref.alloc">.
#     - Inside that sequence, each test emits memref.MemRefMultiBufferOp
#       (imported from mlir.dialects.transform) one of three ways:
#         - (target, 4): CHECKs a !transform.any_op result type.
#         - (result_type, target, 4): CHECKs the explicitly given result type.
#         - (target, 4, skip_analysis=True): CHECKs that skip_analysis is
#           printed after "factor = 4 : i64".
#
# NOTE(review): the file dialects/_memref_transform_ops_ext.py is referenced by
# both the CMake and Bazel hunks, but it is not created anywhere in the diff
# shown here. Confirm it is added elsewhere in this change. Otherwise both
# builds reference a missing source, and the constructor forms exercised by
# the test are presumably unavailable.
#
# NOTE(review): the third test's CHECK-SAME chain assumes the printer emits
# "factor" before "skip_analysis". That presumably holds for this op's
# attribute printing; confirm if the op's assembly format changes.
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -169,6 +169,7 @@
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
   TD_FILE dialects/MemRefTransformOps.td
   SOURCES
+    dialects/_memref_transform_ops_ext.py
     dialects/transform/memref.py
   DIALECT_NAME transform
   EXTENSION_NAME memref_transform)
diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/dialects/transform_memref_ext.py
@@ -0,0 +1,67 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import memref
+
+
+def run(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+@run
+def testMemRefMultiBufferOpCompact():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefMultiBufferOp(sequence.bodyTarget, 4)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefMultiBufferOpCompact
+    # CHECK: = transform.memref.multibuffer
+    # CHECK-SAME: factor = 4 : i64
+    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
+
+
+@run
+def testMemRefMultiBufferOpTyped():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefMultiBufferOp(
+            transform.OperationType.get("memref.alloc"), sequence.bodyTarget, 4
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefMultiBufferOpTyped
+    # CHECK: = transform.memref.multibuffer
+    # CHECK-SAME: factor = 4 : i64
+    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.op<"memref.alloc">
+
+
+@run
+def testMemRefMultiBufferOpAttributes():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.PROPAGATE,
+        [],
+        transform.OperationType.get("memref.alloc"),
+    )
+    with InsertionPoint(sequence.body):
+        memref.MemRefMultiBufferOp(sequence.bodyTarget, 4, skip_analysis=True)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testMemRefMultiBufferOpAttributes
+    # CHECK: = transform.memref.multibuffer
+    # CHECK-SAME: factor = 4 : i64
+    # CHECK-SAME: skip_analysis
+    # CHECK-SAME: (!transform.op<"memref.alloc">) -> !transform.any_op
diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -889,6 +889,7 @@
         "mlir/dialects/_bufferization_transform_ops_ext.py",
         "mlir/dialects/_gpu_transform_ops_ext.py",
         "mlir/dialects/_loop_transform_ops_ext.py",
+        "mlir/dialects/_memref_transform_ops_ext.py",
         "mlir/dialects/_structured_transform_ops_ext.py",
         "mlir/dialects/_transform_ops_ext.py",
         "mlir/dialects/_transform_pdl_extension_ops_ext.py",