diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -10,11 +10,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
 
 #include <optional>
@@ -182,6 +183,17 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
+class AtomicRMWOpPattern final
+    : public OpConversionPattern<memref::AtomicRMWOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Removed a deallocation if it is a supported allocation. Currently only
 /// removes deallocation if the memory space is workgroup memory.
 class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
@@ -303,6 +315,62 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
+                                    OpAdaptor adaptor,
+                                    ConversionPatternRewriter &rewriter) const {
+  if (atomicOp.getType().isa<FloatType>())
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "unimplemented floating-point case");
+
+  auto memrefType = atomicOp.getMemref().getType().cast<MemRefType>();
+  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
+  if (!scope)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "unsupported memref memory space");
+
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  Type resultType = typeConverter.convertType(atomicOp.getType());
+  if (!resultType)
+    return rewriter.notifyMatchFailure(atomicOp,
+                                       "failed to convert result type");
+
+  auto loc = atomicOp.getLoc();
+  Value ptr =
+      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
+                           adaptor.getIndices(), loc, rewriter);
+
+  if (!ptr)
+    return failure();
+
+#define ATOMIC_CASE(kind, spirvOp)                                            \
+  case arith::AtomicRMWKind::kind:                                            \
+    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                              \
+        atomicOp, resultType, ptr, *scope,                                    \
+        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());          \
+    break
+
+  switch (atomicOp.getKind()) {
+    ATOMIC_CASE(addi, AtomicIAddOp);
+    ATOMIC_CASE(maxs, AtomicSMaxOp);
+    ATOMIC_CASE(maxu, AtomicUMaxOp);
+    ATOMIC_CASE(mins, AtomicSMinOp);
+    ATOMIC_CASE(minu, AtomicUMinOp);
+    ATOMIC_CASE(ori, AtomicOrOp);
+    ATOMIC_CASE(andi, AtomicAndOp);
+  default:
+    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
+  }
+
+#undef ATOMIC_CASE
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
@@ -656,9 +724,9 @@
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern,
-               IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
-               StoreOpPattern>(typeConverter,
-                               patterns.getContext());
+  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
+               LoadOpPattern, StoreOpPattern>(
+      typeConverter, patterns.getContext());
 }
 } // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/atomic.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s
+
+module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>} {
+
+// CHECK: func.func @atomic_addi_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_addi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicIAdd "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "addi" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+// CHECK: func.func @atomic_maxs_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_maxs_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicSMax "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "maxs" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+  return %0: i32
+}
+
+// CHECK: func.func @atomic_maxu_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_maxu_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicUMax "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "maxu" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+// CHECK: func.func @atomic_mins_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_mins_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicSMin "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "mins" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+  return %0: i32
+}
+
+// CHECK: func.func @atomic_minu_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_minu_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicUMin "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "minu" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+// CHECK: func.func @atomic_ori_workgroup
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_ori_workgroup(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<Workgroup>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicOr "Workgroup" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, Workgroup>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "ori" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<Workgroup>>) -> i32
+  return %0: i32
+}
+
+// CHECK: func.func @atomic_andi_storage_buffer
+// CHECK-SAME: (%[[VAL:.+]]: i32,
+func.func @atomic_andi_storage_buffer(%value: i32, %memref: memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>, %i0: index, %i1: index, %i2: index) -> i32 {
+  // CHECK: %[[AC:.+]] = spirv.AccessChain
+  // CHECK: %[[ATOMIC:.+]] = spirv.AtomicAnd "Device" "AcquireRelease" %[[AC]], %[[VAL]] : !spirv.ptr<i32, StorageBuffer>
+  // CHECK: return %[[ATOMIC]]
+  %0 = memref.atomic_rmw "andi" %value, %memref[%i0, %i1, %i2] : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32
+  return %0: i32
+}
+
+}
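
Note on the expected output: the scope operand in the checks above comes from getAtomicOpScope on the memref's storage class (StorageBuffer maps to "Device", Workgroup to "Workgroup"), while the "AcquireRelease" semantics are hardcoded by the pattern. As a rough sketch of the round trip for the addi case (SSA names and the access-chain index arithmetic below are illustrative, not the literal pass output):

  // Input:
  %0 = memref.atomic_rmw "addi" %value, %memref[%i0, %i1, %i2]
      : (i32, memref<2x3x4xi32, #spirv.storage_class<StorageBuffer>>) -> i32

  // After -convert-memref-to-spirv, approximately: the indices are folded
  // into a spirv.AccessChain to the i32 element, and the RMW becomes a
  // single SPIR-V atomic on that pointer.
  %ptr = spirv.AccessChain %base[%c0, %linearized] : !spirv.ptr<..., StorageBuffer>, i32, i32
  %0 = spirv.AtomicIAdd "Device" "AcquireRelease" %ptr, %value : !spirv.ptr<i32, StorageBuffer>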