diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -22,8 +22,9 @@ class OwningRewritePatternList; -/// Creates an instance of the ExpandAtomic pass. -std::unique_ptr createExpandAtomicPass(); +/// Adds rewrites expanding AtomicRMWOp into GenericAtomicRMWOp to `patterns`. +void populateExpandAtomicPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx); void populateExpandMemRefReshapePattern(OwningRewritePatternList &patterns, MLIRContext *ctx); diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -11,11 +11,6 @@ include "mlir/Pass/PassBase.td" -def ExpandAtomic : FunctionPass<"expand-atomic"> { - let summary = "Expands AtomicRMWOp into GenericAtomicRMWOp."; - let constructor = "mlir::createExpandAtomicPass()"; -} - def StdBufferize : FunctionPass<"std-bufferize"> { let summary = "Bufferize the std dialect"; let constructor = "mlir::createStdBufferizePass()"; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3932,6 +3932,7 @@ XOrOpLowering, ZeroExtendIOpLowering>(converter); // clang-format on + populateExpandAtomicPattern(patterns, &converter.getContext()); } void mlir::populateStdToLLVMMemoryConversionPatterns( diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp @@ -70,25 +70,9 @@ } }; -struct ExpandAtomic : public 
ExpandAtomicBase { - void runOnFunction() override { - OwningRewritePatternList patterns; - patterns.insert(&getContext()); - - ConversionTarget target(getContext()); - target.addLegalOp(); - target.addDynamicallyLegalOp([](AtomicRMWOp op) { - return op.kind() != AtomicRMWKind::maxf && - op.kind() != AtomicRMWKind::minf; - }); - if (failed(mlir::applyPartialConversion(getFunction(), target, - std::move(patterns)))) - signalPassFailure(); - } -}; - } // namespace -std::unique_ptr mlir::createExpandAtomicPass() { - return std::make_unique(); +void mlir::populateExpandAtomicPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert(ctx); } diff --git a/mlir/test/Dialect/Standard/expand-atomic.mlir b/mlir/test/Dialect/Standard/expand-atomic.mlir --- a/mlir/test/Dialect/Standard/expand-atomic.mlir +++ b/mlir/test/Dialect/Standard/expand-atomic.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-expand-atomic -split-input-file | FileCheck %s // CHECK-LABEL: func @atomic_rmw_to_generic // CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms TestAffineLoopParametricTiling.cpp + TestExpandAtomic.cpp TestExpandMemRefReshape.cpp TestExpandTanh.cpp TestCallGraph.cpp diff --git a/mlir/test/lib/Transforms/TestExpandAtomic.cpp b/mlir/test/lib/Transforms/TestExpandAtomic.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestExpandAtomic.cpp @@ -0,0 +1,37 @@ +//===- TestExpandAtomic.cpp - Test expanding atomic ops -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a test pass for expanding atomic RMW operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { +struct TestExpandAtomicPass + : public PassWrapper { + void runOnFunction() override; +}; +} // end anonymous namespace + +void TestExpandAtomicPass::runOnFunction() { + OwningRewritePatternList patterns; + populateExpandAtomicPattern(patterns, &getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +namespace mlir { +void registerTestExpandAtomicPass() { + PassRegistration pass("test-expand-atomic", + "Test expanding atomic"); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -52,6 +52,7 @@ void registerTestDominancePass(); void registerTestDialect(DialectRegistry &); void registerTestDynamicPipelinePass(); +void registerTestExpandAtomicPass(); void registerTestExpandMemRefReshapePass(); void registerTestExpandTanhPass(); void registerTestFinalizingBufferizePass(); @@ -117,8 +118,9 @@ registerTestDynamicPipelinePass(); registerTestFinalizingBufferizePass(); registerTestFunc(); - registerTestExpandTanhPass(); + registerTestExpandAtomicPass(); registerTestExpandMemRefReshapePass(); + registerTestExpandTanhPass(); registerTestGpuMemoryPromotionPass(); registerTestInterfaces(); registerTestLinalgCodegenStrategy();