Index: mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -285,6 +285,35 @@
   }
 };
 
+/// Converts std.alloca of a statically shaped memref in the Function storage
+/// class to a function-scope spv.Variable.
+class AllocaOpPattern final : public SPIRVOpLowering<AllocaOp> {
+public:
+  using SPIRVOpLowering<AllocaOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(AllocaOp operation, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType allocType = operation.getType();
+    if (!allocType.hasStaticShape())
+      return operation.emitError("unhandled allocation type");
+    auto storage = SPIRVTypeConverter::getStorageClassForMemorySpace(
+        allocType.getMemorySpace());
+    // spv.Variable only models Function storage; a Private allocation would
+    // need a module-level spv.globalVariable instead.
+    if (!storage || *storage != spirv::StorageClass::Function)
+      return operation.emitError("unhandled allocation memory space");
+    Type spirvType = typeConverter.convertType(allocType);
+    if (!spirvType)
+      return operation.emitError("unable to convert allocation type");
+    rewriter.replaceOpWithNewOp<spirv::VariableOp>(
+        operation, spirvType,
+        rewriter.getI32IntegerAttr(static_cast<int32_t>(*storage)),
+        /*initializer=*/nullptr);
+    return success();
+  }
+};
+
 /// Converts unary and binary standard operations to SPIR-V operations.
template class UnaryAndBinaryOpPattern final : public SPIRVOpLowering { @@ -1031,7 +1060,7 @@ UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, UnaryAndBinaryOpPattern, - AllocOpPattern, DeallocOpPattern, + AllocOpPattern, DeallocOpPattern, AllocaOpPattern, BitwiseOpPattern, BitwiseOpPattern, BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, Index: mlir/test/Conversion/StandardToSPIRV/alloc.mlir =================================================================== --- mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -142,3 +142,29 @@ return } } + +// ----- + +// Alloca in memory space 6 (Function) lowers to a spv.Variable. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> + } +{ + func @alloca_function_mem(%arg0 : index, %arg1 : index) { + %0 = alloca() : memref<4x5xf32, 6> + %1 = load %0[%arg0, %arg1] : memref<4x5xf32, 6> + store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6> + return + } +} +// CHECK: func @alloca_function_mem +// CHECK-NOT: alloca +// CHECK: %[[PTR:.+]] = spv.Variable : !spv.ptr [0]>, Function> +// CHECK: %[[LOADPTR:.+]] = spv.AccessChain %[[PTR]] +// CHECK: %[[VAL:.+]] = spv.Load "Function" %[[LOADPTR]] : f32 +// CHECK: %[[STOREPTR:.+]] = spv.AccessChain %[[PTR]] +// CHECK: spv.Store "Function" %[[STOREPTR]], %[[VAL]] : f32 +// CHECK: spv.Return