diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "llvm/Support/Debug.h"
@@ -85,15 +86,27 @@
                           offset);
 }
 
-/// Returns true if the allocations of type `t` can be lowered to SPIR-V.
-static bool isAllocationSupported(MemRefType t) {
-  // Currently only support workgroup local memory allocations with static
-  // shape and int or float or vector of int or float element type.
-  if (!(t.hasStaticShape() &&
-        SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt()))
+/// Returns true if the allocations of memref `type` generated from `allocOp`
+/// can be lowered to SPIR-V.
+static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
+  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
+    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
+            spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
+      return false;
+  } else if (isa<memref::AllocaOp>(allocOp)) {
+    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
+            spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
+      return false;
+  } else {
     return false;
-  Type elementType = t.getElementType();
+  }
+
+  // Currently only support static shape and int or float or vector of int or
+  // float element type.
+  if (!type.hasStaticShape())
+    return false;
+
+  Type elementType = type.getElementType();
   if (auto vecType = elementType.dyn_cast<VectorType>())
     elementType = vecType.getElementType();
   return elementType.isIntOrFloat();
@@ -102,10 +115,10 @@
 /// Returns the scope to use for atomic operations use for emulating store
 /// operations of unsupported integer bitwidths, based on the memref
 /// type. Returns None on failure.
-static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
+static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
   Optional<spirv::StorageClass> storageClass =
       SPIRVTypeConverter::getStorageClassForMemorySpace(
-          t.getMemorySpaceAsInt());
+          type.getMemorySpaceAsInt());
   if (!storageClass)
     return {};
   switch (*storageClass) {
@@ -149,6 +162,16 @@
 
 namespace {
 
+/// Converts memref.alloca to SPIR-V Function variables.
+class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
 /// to Workgroup memory when the size is constant. Note that this pattern needs
 /// to be applied in a pass that runs at least at spv.module scope since it wil
@@ -215,6 +238,25 @@
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// AllocaOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter) const {
+  MemRefType allocType = allocaOp.getType();
+  if (!isAllocationSupported(allocaOp, allocType))
+    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");
+
+  // Get the SPIR-V type for the allocation.
+  Type spirvType = getTypeConverter()->convertType(allocType);
+  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
+                                                 spirv::StorageClass::Function,
+                                                 /*initializer=*/nullptr);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp
 //===----------------------------------------------------------------------===//
@@ -223,8 +265,8 @@
 AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
   MemRefType allocType = operation.getType();
-  if (!isAllocationSupported(allocType))
-    return operation.emitError("unhandled allocation type");
+  if (!isAllocationSupported(operation, allocType))
+    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
 
   // Get the SPIR-V type for the allocation.
   Type spirvType = getTypeConverter()->convertType(allocType);
@@ -262,8 +304,8 @@
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
   MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
-  if (!isAllocationSupported(deallocType))
-    return operation.emitError("unhandled deallocation type");
+  if (!isAllocationSupported(operation, deallocType))
+    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
   rewriter.eraseOp(operation);
   return success();
 }
@@ -505,8 +547,9 @@
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
-               IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<AllocaOpPattern, AllocOpPattern, DeallocOpPattern, IntLoadOpPattern,
+           IntStoreOpPattern, LoadOpPattern, StoreOpPattern>(
+          typeConverter, patterns.getContext());
 }
 } // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
--- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir
@@ -100,10 +100,12 @@
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
 } {
-  func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : index) {
-    // expected-error @+1 {{unhandled allocation type}}
+  // CHECK-LABEL: func @alloc_dynamic_size
+  func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
+    // CHECK: memref.alloc
     %0 = memref.alloc(%arg0) : memref<4x?xf32, 3>
-    return
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3>
+    return %1: f32
   }
 }
@@ -114,10 +116,12 @@
    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
 } {
-  func.func @alloc_dealloc_mem() {
-    // expected-error @+1 {{unhandled allocation type}}
+  // CHECK-LABEL: func @alloc_unsupported_memory_space
+  func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
+    // CHECK: memref.alloc
     %0 = memref.alloc() : memref<4x5xf32>
-    return
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+    return %1: f32
   }
 }
@@ -129,8 +133,9 @@
    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
 } {
-  func.func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) {
-    // expected-error @+1 {{unhandled deallocation type}}
+  // CHECK-LABEL: func @dealloc_dynamic_size
+  func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) {
+    // CHECK: memref.dealloc
     memref.dealloc %arg0 : memref<4x?xf32, 3>
     return
   }
 }
@@ -143,8 +148,9 @@
    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
 } {
-  func.func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) {
-    // expected-error @+1 {{unhandled deallocation type}}
+  // CHECK-LABEL: func @dealloc_unsupported_memory_space
+  func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) {
+    // CHECK: memref.dealloc
     memref.dealloc %arg0 : memref<4x5xf32>
     return
   }
 }
diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
new file
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -split-input-file -convert-memref-to-spirv -canonicalize -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}>} {
+  func.func @alloc_function_variable(%arg0 : index, %arg1 : index) {
+    %0 = memref.alloca() : memref<4x5xf32, 6>
+    %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6>
+    memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6>
+    return
+  }
+}
+
+// CHECK-LABEL: func @alloc_function_variable
+// CHECK: %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[VAR]]
+// CHECK: %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32
+// CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[VAR]]
+// CHECK: spv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32
+
+
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}>} {
+  func.func @two_allocs() {
+    %0 = memref.alloca() : memref<4x5xf32, 6>
+    %1 = memref.alloca() : memref<2x3xi32, 6>
+    return
+  }
+}
+
+// CHECK-LABEL: func @two_allocs
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<6 x i32, stride=4>)>, Function>
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<20 x f32, stride=4>)>, Function>
+
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}>} {
+  func.func @two_allocs_vector() {
+    %0 = memref.alloca() : memref<4xvector<4xf32>, 6>
+    %1 = memref.alloca() : memref<2xvector<2xi32>, 6>
+    return
+  }
+}
+
+// CHECK-LABEL: func @two_allocs_vector
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<2 x vector<2xi32>, stride=8>)>, Function>
+// CHECK-DAG: spv.Variable : !spv.ptr<!spv.struct<(!spv.array<4 x vector<4xf32>, stride=16>)>, Function>
+
+
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}>} {
+  // CHECK-LABEL: func @alloc_dynamic_size
+  func.func @alloc_dynamic_size(%arg0 : index) -> f32 {
+    // CHECK: memref.alloca
+    %0 = memref.alloca(%arg0) : memref<4x?xf32, 6>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6>
+    return %1: f32
+  }
+}
+
+// -----
+
+module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.0, [Shader], []>, {}>} {
+  // CHECK-LABEL: func @alloc_unsupported_memory_space
+  func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 {
+    // CHECK: memref.alloca
+    %0 = memref.alloca() : memref<4x5xf32>
+    %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32>
+    return %1: f32
+  }
+}
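
For context, not part of the patch: the updated populateMemRefToSPIRVPatterns entry point is normally driven through a dialect-conversion pass, as the RUN lines above do via -convert-memref-to-spirv. Below is a minimal C++ sketch of that driver side, assuming the stock SPIRVConversionTarget / applyPartialConversion setup; the helper name runMemRefToSPIRV is illustrative, not an API from this patch. It also shows why the switch from emitError to notifyMatchFailure matters: unsupported allocs and allocas now simply fail to match and survive partial conversion, which is what the revised tests check.

// Illustrative sketch only -- not part of the patch. Shows how the patterns
// registered by populateMemRefToSPIRVPatterns are typically applied.
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper; mirrors what the convert-memref-to-spirv pass does.
static LogicalResult runMemRefToSPIRV(ModuleOp module) {
  // Derive the conversion target from the module's spv.target_env attribute,
  // like the ones attached to the test modules above.
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(module);
  std::unique_ptr<ConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);

  SPIRVTypeConverter typeConverter(targetAttr);
  RewritePatternSet patterns(module.getContext());
  // Now also registers AllocaOpPattern: memref.alloca -> spv.Variable with
  // Function storage class.
  populateMemRefToSPIRVPatterns(typeConverter, patterns);

  // Partial conversion: patterns that bail out via notifyMatchFailure
  // (dynamic shapes, unsupported memory spaces) leave the memref ops intact
  // instead of emitting errors and aborting the whole conversion.
  return applyPartialConversion(module, *target, std::move(patterns));
}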