diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -8,9 +8,7 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
@@ -116,34 +114,82 @@
   void runOnOperation() override;
 };
 
-struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
-  ConvertToAtomCmpExchangeWeak(MLIRContext *context);
+struct ConvertToAtomCmpExchangeWeak : RewritePattern {
+  ConvertToAtomCmpExchangeWeak(MLIRContext *context)
+      : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
+                       context, {"spirv.AtomicCompareExchangeWeak"}) {}
+
   LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+                                PatternRewriter &rewriter) const override {
+    Value ptr = op->getOperand(0);
+    Value value = op->getOperand(1);
+    Value comparator = op->getOperand(2);
+
+    // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
+    // in memory semantics to additionally require AtomicStorage capability.
+    rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
+        op, value.getType(), ptr, spirv::Scope::Workgroup,
+        spirv::MemorySemantics::AcquireRelease |
+            spirv::MemorySemantics::AtomicCounterMemory,
+        spirv::MemorySemantics::Acquire, value, comparator);
+    return success();
+  }
 };
 
-struct ConvertToBitReverse : public RewritePattern {
-  ConvertToBitReverse(MLIRContext *context);
+struct ConvertToBitReverse : RewritePattern {
+  ConvertToBitReverse(MLIRContext *context)
+      : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
+                       {"spirv.BitReverse"}) {}
+
   LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+                                PatternRewriter &rewriter) const override {
+    Value predicate = op->getOperand(0);
+    rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
+        op, op->getResult(0).getType(), predicate);
+    return success();
+  }
 };
 
-struct ConvertToGroupNonUniformBallot : public RewritePattern {
-  ConvertToGroupNonUniformBallot(MLIRContext *context);
+struct ConvertToGroupNonUniformBallot : RewritePattern {
+  ConvertToGroupNonUniformBallot(MLIRContext *context)
+      : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
+                       context, {"spirv.GroupNonUniformBallot"}) {}
+
   LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+                                PatternRewriter &rewriter) const override {
+    Value predicate = op->getOperand(0);
+    rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
+        op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
+    return success();
+  }
 };
 
-struct ConvertToModule : public RewritePattern {
-  ConvertToModule(MLIRContext *context);
+struct ConvertToModule : RewritePattern {
+  ConvertToModule(MLIRContext *context)
+      : RewritePattern("test.convert_to_module_op", 1, context,
+                       {"spirv.module"}) {}
+
   LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
+        op, spirv::AddressingModel::PhysicalStorageBuffer64,
+        spirv::MemoryModel::Vulkan);
+    return success();
+  }
 };
 
-struct ConvertToSubgroupBallot : public RewritePattern {
-  ConvertToSubgroupBallot(MLIRContext *context);
+struct ConvertToSubgroupBallot : RewritePattern {
+  ConvertToSubgroupBallot(MLIRContext *context)
+      : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
+                       {"spirv.KHR.SubgroupBallot"}) {}
+
   LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+                                PatternRewriter &rewriter) const override {
+    Value predicate = op->getOperand(0);
+    rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
+        op, op->getResult(0).getType(), predicate);
+    return success();
+  }
 };
 
 } // namespace
@@ -170,82 +216,6 @@
     return signalPassFailure();
 }
 
-ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
-    : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
-                     context, {"spirv.AtomicCompareExchangeWeak"}) {}
-
-LogicalResult
-ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
-                                              PatternRewriter &rewriter) const {
-  Value ptr = op->getOperand(0);
-  Value value = op->getOperand(1);
-  Value comparator = op->getOperand(2);
-
-  // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
-  // in memory semantics to additionally require AtomicStorage capability.
-  rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
-      op, value.getType(), ptr, spirv::Scope::Workgroup,
-      spirv::MemorySemantics::AcquireRelease |
-          spirv::MemorySemantics::AtomicCounterMemory,
-      spirv::MemorySemantics::Acquire, value, comparator);
-  return success();
-}
-
-ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
-    : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
-                     {"spirv.BitReverse"}) {}
-
-LogicalResult
-ConvertToBitReverse::matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const {
-  Value predicate = op->getOperand(0);
-
-  rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
-      op, op->getResult(0).getType(), predicate);
-  return success();
-}
-
-ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
-    MLIRContext *context)
-    : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
-                     {"spirv.GroupNonUniformBallot"}) {}
-
-LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  Value predicate = op->getOperand(0);
-
-  rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
-      op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
-  return success();
-}
-
-ConvertToModule::ConvertToModule(MLIRContext *context)
-    : RewritePattern("test.convert_to_module_op", 1, context,
-                     {"spirv.module"}) {}
-
-LogicalResult
-ConvertToModule::matchAndRewrite(Operation *op,
-                                 PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
-      op, spirv::AddressingModel::PhysicalStorageBuffer64,
-      spirv::MemoryModel::Vulkan);
-  return success();
-}
-
-ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
-    : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
-                     {"spirv.KHR.SubgroupBallot"}) {}
-
-LogicalResult
-ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
-                                         PatternRewriter &rewriter) const {
-  Value predicate = op->getOperand(0);
-
-  rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
-      op, op->getResult(0).getType(), predicate);
-  return success();
-}
-
 namespace mlir {
 void registerConvertToTargetEnvPass() {
   PassRegistration<ConvertToTargetEnv>();