diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md --- a/mlir/docs/Passes.md +++ b/mlir/docs/Passes.md @@ -39,3 +39,7 @@ ## `spv` Dialect Passes [include "SPIRVPasses.md"] + +## `standard` Dialect Passes + +[include "StandardPasses.md"] diff --git a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/StandardOps/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(MLIRStandardTransformsIncGen) + +add_mlir_doc(Passes -gen-pass-doc StandardPasses ./) diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -0,0 +1,28 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors in the +// StandardOps dialect transformation library.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ + +#include <memory> + +namespace mlir { + +class Pass; + +/// Creates an instance of the ExpandAtomic pass. +std::unique_ptr<Pass> createExpandAtomicPass(); + +} // end namespace mlir + +#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -0,0 +1,19 @@ +//===-- Passes.td - StandardOps pass definition file -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES +#define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def ExpandAtomic : FunctionPass<"expand-atomic"> { + let summary = "Expands AtomicRMWOp into GenericAtomicRMWOp."; + let constructor = "mlir::createExpandAtomicPass()"; +} + +#endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/LoopOps/Passes.h" #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Transforms/LocationSnapshot.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/ViewOpGraph.h" @@ -86,6 +87,10 @@ // SPIR-V #define GEN_PASS_REGISTRATION #include "mlir/Dialect/SPIRV/Passes.h.inc" + + // 
Standard +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc" } } // namespace mlir diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2642,113 +2642,6 @@ } }; -/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be -/// retried until it succeeds in atomically storing a new value into memory. -/// -/// +---------------------------------+ -/// | | -/// | | -/// | br loop(%loaded) | -/// +---------------------------------+ -/// | -/// -------| | -/// | v v -/// | +--------------------------------+ -/// | | loop(%loaded): | -/// | | | -/// | | %pair = cmpxchg | -/// | | %ok = %pair[0] | -/// | | %new = %pair[1] | -/// | | cond_br %ok, end, loop(%new) | -/// | +--------------------------------+ -/// | | | -/// |----------- | -/// v -/// +--------------------------------+ -/// | end: | -/// | | -/// +--------------------------------+ -/// -struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering { - using Base::Base; - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto atomicOp = cast(op); - auto maybeKind = matchSimpleAtomicOp(atomicOp); - if (maybeKind) - return failure(); - - LLVM::FCmpPredicate predicate; - switch (atomicOp.kind()) { - case AtomicRMWKind::maxf: - predicate = LLVM::FCmpPredicate::ogt; - break; - case AtomicRMWKind::minf: - predicate = LLVM::FCmpPredicate::olt; - break; - default: - return failure(); - } - - OperandAdaptor adaptor(operands); - auto loc = op->getLoc(); - auto valueType = adaptor.value().getType().cast(); - - // Split the block into initial, loop, and ending parts. 
- auto *initBlock = rewriter.getInsertionBlock(); - auto initPosition = rewriter.getInsertionPoint(); - auto *loopBlock = rewriter.splitBlock(initBlock, initPosition); - auto loopArgument = loopBlock->addArgument(valueType); - auto loopPosition = rewriter.getInsertionPoint(); - auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition); - - // Compute the loaded value and branch to the loop block. - rewriter.setInsertionPointToEnd(initBlock); - auto memRefType = atomicOp.getMemRefType(); - auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter, getModule()); - Value init = rewriter.create(loc, dataPtr); - rewriter.create(loc, init, loopBlock); - - // Prepare the body of the loop block. - rewriter.setInsertionPointToStart(loopBlock); - auto predicateI64 = - rewriter.getI64IntegerAttr(static_cast(predicate)); - auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); - auto lhs = loopArgument; - auto rhs = adaptor.value(); - auto cmp = - rewriter.create(loc, boolType, predicateI64, lhs, rhs); - auto select = rewriter.create(loc, cmp, lhs, rhs); - - // Prepare the epilog of the loop block. - rewriter.setInsertionPointToEnd(loopBlock); - // Append the cmpxchg op to the end of the loop block. - auto successOrdering = LLVM::AtomicOrdering::acq_rel; - auto failureOrdering = LLVM::AtomicOrdering::monotonic; - auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); - auto cmpxchg = rewriter.create( - loc, pairType, dataPtr, loopArgument, select, successOrdering, - failureOrdering); - // Extract the %new_loaded and %ok values from the pair. - Value newLoaded = rewriter.create( - loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); - Value ok = rewriter.create( - loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); - - // Conditionally branch to the end or back to the loop depending on %ok. 
- rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(), - loopBlock, newLoaded); - - // The 'result' of the atomic_rmw op is the newly loaded value. - rewriter.replaceOp(op, {newLoaded}); - - return success(); - } -}; - /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be /// retried until it succeeds in atomically storing a new value into memory. /// @@ -2884,7 +2777,6 @@ AddIOpLowering, AllocaOpLowering, AndOpLowering, - AtomicCmpXchgOpLowering, AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -19,3 +19,5 @@ MLIRViewLikeInterface LLVMSupport ) + +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRStandardOpsTransforms + ExpandAtomic.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps/Transforms + + DEPENDS + MLIRStandardTransformsIncGen + ) +target_link_libraries(MLIRStandardOpsTransforms + PUBLIC + MLIRIR + MLIRPass + MLIRStandardOps + MLIRSupport + LLVMSupport + ) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandAtomic.cpp @@ -0,0 +1,93 @@ +//===- ExpandAtomic.cpp - Expand AtomicRMWOp into GenericAtomicRMWOp ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements expansion of AtomicRMWOp into GenericAtomicRMWOp. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with +/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to +/// `generic_atomic_rmw` with the expanded code. +/// +/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 +/// +/// will be lowered to +/// +/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> { +/// ^bb0(%current: f32): +/// %cmp = cmpf "ogt", %current, %fval : f32 +/// %new_value = select %cmp, %current, %fval : f32 +/// atomic_yield %new_value : f32 +/// } +struct AtomicRMWOpConverter : public OpRewritePattern<AtomicRMWOp> { + // Only the "maxf"/"minf" kinds are expanded; all others fail to match. + using OpRewritePattern<AtomicRMWOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(AtomicRMWOp op, + PatternRewriter &rewriter) const final { + CmpFPredicate predicate; + switch (op.kind()) { + case AtomicRMWKind::maxf: + predicate = CmpFPredicate::OGT; + break; + case AtomicRMWKind::minf: + predicate = CmpFPredicate::OLT; + break; + default: + return failure(); + } + + auto loc = op.getLoc(); + auto genericOp = + rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices()); + OpBuilder bodyBuilder = OpBuilder::atBlockEnd(genericOp.getBody()); + + Value lhs = genericOp.getCurrentValue(); + Value rhs = op.value(); + Value cmp = bodyBuilder.create<CmpFOp>(loc, 
predicate, lhs, rhs); + Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs); + bodyBuilder.create<AtomicYieldOp>(loc, select); + 
+ rewriter.replaceOp(op, genericOp.getResult()); + return success(); + } +}; + +struct ExpandAtomic : public ExpandAtomicBase<ExpandAtomic> { + void runOnFunction() override { + OwningRewritePatternList patterns; + patterns.insert<AtomicRMWOpConverter>(&getContext()); + + ConversionTarget target(getContext()); + target.addLegalOp<GenericAtomicRMWOp>(); + target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) { + return op.kind() != AtomicRMWKind::maxf && + op.kind() != AtomicRMWKind::minf; + }); + if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr<Pass> mlir::createExpandAtomicPass() { + return std::make_unique<ExpandAtomic>(); +} diff --git a/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/StandardOps/Transforms/PassDetail.h @@ -0,0 +1,23 @@ +//===- PassDetail.h - StandardOps Pass class details ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class AtomicRMWOp; + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc" + +} // end namespace mlir + +#endif // DIALECT_STANDARD_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1086,25 +1086,6 @@ // ----- -// CHECK-LABEL: func @cmpxchg -func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 { - %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 - // CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*"> - // CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float) - // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float): - // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float - // CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float - // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float - // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }"> - // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }"> - // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float) - // CHECK-NEXT: ^bb2: - return %x : f32 - // CHECK-NEXT: llvm.return %[[new]] -} - -// ----- - // CHECK-LABEL: func @generic_atomic_rmw func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 { %x = generic_atomic_rmw %I[%i] : memref<10xf32> { diff --git a/mlir/test/Dialect/Standard/expand-atomic.mlir 
b/mlir/test/Dialect/Standard/expand-atomic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Standard/expand-atomic.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -expand-atomic -split-input-file | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: func @atomic_rmw_to_generic +// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index) +func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { + %x = atomic_rmw "maxf" %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %x : f32 +} +// CHECK: [[RESULT:%.*]] = std.generic_atomic_rmw [[F]]{{\[}}[[i]]] : memref<10xf32> { +// CHECK: ^bb0([[CUR_VAL:%.*]]: f32): +// CHECK: [[CMP:%.*]] = cmpf "ogt", [[CUR_VAL]], [[f]] : f32 +// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32 +// CHECK: atomic_yield [[SELECT]] : f32 +// CHECK: } +// CHECK: return [[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @atomic_rmw_no_conversion +func @atomic_rmw_no_conversion(%F: memref<10xf32>, %f: f32, %i: index) -> f32 { + %x = atomic_rmw "addf" %f, %F[%i] : (f32, memref<10xf32>) -> f32 + return %x : f32 +} +// CHECK-NOT: generic_atomic_rmw