diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h
@@ -21,9 +21,12 @@
 template <typename T> class OperationPass;
 
-/// Creates a pass to convert GPU Ops to SPIR-V ops. For a gpu.func to be
-/// converted, it should have a spv.entry_point_abi attribute.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertGPUToSPIRVPass();
+/// Creates a pass to convert GPU kernel ops to corresponding SPIR-V ops. For a
+/// gpu.func to be converted, it should have a spv.entry_point_abi attribute.
+/// If `mapMemorySpace` is true, performs MemRef memory space to SPIR-V mapping
+/// according to default Vulkan rules first.
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertGPUToSPIRVPass(bool mapMemorySpace = false);
 
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRVPASS_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -69,15 +69,6 @@
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
-  /// Returns the corresponding memory space for memref given a SPIR-V storage
-  /// class.
-  static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);
-
-  /// Returns the SPIR-V storage class given a memory space for memref. Return
-  /// llvm::None if the memory space does not map to any SPIR-V storage class.
-  static Optional<spirv::StorageClass>
-  getStorageClassForMemorySpace(unsigned space);
-
   /// Returns the options controlling the SPIR-V type converter.
   const Options &getOptions() const;
 
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -204,6 +204,8 @@
     for (const auto &argType :
          enumerate(funcOp.getFunctionType().getInputs())) {
       auto convertedType = typeConverter.convertType(argType.value());
+      if (!convertedType)
+        return nullptr;
       signatureConverter.addInputs(argType.index(), convertedType);
     }
   }
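Note: the new `mapMemorySpace` flag defaults to `false`, so existing callers of createConvertGPUToSPIRVPass are unaffected. As a hedged sketch of opting in (the pipeline function name below is hypothetical, not part of this patch):

    // Hypothetical pipeline registration; only createConvertGPUToSPIRVPass()
    // and its new `mapMemorySpace` flag come from this patch.
    void buildGPUToSPIRVPipeline(mlir::OpPassManager &pm) {
      // Request the default Vulkan memory-space-to-storage-class mapping
      // before the main GPU-to-SPIR-V conversion runs.
      pm.addPass(mlir::createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
    }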
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -35,8 +35,14 @@
 /// replace it).
 ///
 /// 2) Lower the body of the spirv::ModuleOp.
-struct GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
+class GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
+public:
+  explicit GPUToSPIRVPass(bool mapMemorySpace)
+      : mapMemorySpace(mapMemorySpace) {}
   void runOnOperation() override;
+
+private:
+  bool mapMemorySpace;
 };
 } // namespace
 
@@ -44,16 +50,30 @@
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
-  SmallVector<Operation *, 1> kernelModules;
+  SmallVector<Operation *, 1> gpuModules;
   OpBuilder builder(context);
-  module.walk([&builder, &kernelModules](gpu::GPUModuleOp moduleOp) {
-    // For each kernel module (should be only 1 for now, but that is not a
-    // requirement here), clone the module for conversion because the
-    // gpu.launch function still needs the kernel module.
+  module.walk([&](gpu::GPUModuleOp moduleOp) {
+    // Clone each GPU kernel module for conversion, given that the GPU
+    // launch op still needs the original GPU kernel module.
     builder.setInsertionPoint(moduleOp.getOperation());
-    kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
+    gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
+  // Map MemRef memory space to SPIR-V storage class first if requested.
+  if (mapMemorySpace) {
+    std::unique_ptr<ConversionTarget> target =
+        spirv::getMemorySpaceToStorageClassTarget(*context);
+    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+        spirv::getDefaultVulkanStorageClassMap();
+    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+
+    RewritePatternSet patterns(context);
+    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+
+    if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
+      return signalPassFailure();
+  }
+
   auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
   std::unique_ptr<ConversionTarget> target =
       SPIRVConversionTarget::get(targetAttr);
@@ -68,10 +88,11 @@
   populateMemRefToSPIRVPatterns(typeConverter, patterns);
   populateFuncToSPIRVPatterns(typeConverter, patterns);
 
-  if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
+  if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
     return signalPassFailure();
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {
-  return std::make_unique<GPUToSPIRVPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createConvertGPUToSPIRVPass(bool mapMemorySpace) {
+  return std::make_unique<GPUToSPIRVPass>(mapMemorySpace);
 }
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -90,12 +90,12 @@
 /// can be lowered to SPIR-V.
 static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
   if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
-    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Workgroup) != type.getMemorySpaceAsInt())
+    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
       return false;
   } else if (isa<memref::AllocaOp>(allocOp)) {
-    if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
-            spirv::StorageClass::Function) != type.getMemorySpaceAsInt())
+    auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+    if (!sc || sc.getValue() != spirv::StorageClass::Function)
       return false;
   } else {
     return false;
@@ -116,12 +116,8 @@
 /// operations of unsupported integer bitwidths, based on the memref
 /// type. Returns None on failure.
 static Optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass)
-    return {};
-  switch (*storageClass) {
+  auto sc = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  switch (sc.getValue()) {
   case spirv::StorageClass::StorageBuffer:
     return spirv::Scope::Device;
   case spirv::StorageClass::Workgroup:
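The two MemRefToSPIRV helpers above now read the storage class directly off the memref's memory space attribute instead of round-tripping through the numeric mapping. A minimal sketch of that query pattern, assuming the MLIR APIs of this revision (`dyn_cast_or_null`, `llvm::Optional`); the helper name is illustrative only:

    // Sketch: querying a memref's SPIR-V storage class after this change.
    // Returns llvm::None when the memory space is absent or still numeric.
    static llvm::Optional<mlir::spirv::StorageClass>
    getStorageClass(mlir::MemRefType type) {
      auto sc =
          type.getMemorySpace().dyn_cast_or_null<mlir::spirv::StorageClassAttr>();
      if (!sc)
        return llvm::None;
      return sc.getValue();
    }

Note that `getAtomicOpScope` above dereferences `sc` without a null check; it relies on the type converter having already rejected memrefs whose memory space is not a storage class attribute.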
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/Sequence.h"
@@ -117,65 +118,6 @@
   return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
 }
 
-/// Mapping between SPIR-V storage classes to memref memory spaces.
-///
-/// Note: memref does not have a defined semantics for each memory space; it
-/// depends on the context where it is used. There are no particular reasons
-/// behind the number assignments; we try to follow NVVM conventions and
-/// largely give common storage classes a smaller number. The hope is to use
-/// symbolic memory space representation eventually after memref supports it.
-// TODO: swap Generic and StorageBuffer assignment to be more akin
-// to NVVM.
-#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                        \
-  MAP_FN(spirv::StorageClass::Generic, 1)                                     \
-  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                               \
-  MAP_FN(spirv::StorageClass::Workgroup, 3)                                   \
-  MAP_FN(spirv::StorageClass::Uniform, 4)                                     \
-  MAP_FN(spirv::StorageClass::Private, 5)                                     \
-  MAP_FN(spirv::StorageClass::Function, 6)                                    \
-  MAP_FN(spirv::StorageClass::PushConstant, 7)                                \
-  MAP_FN(spirv::StorageClass::UniformConstant, 8)                             \
-  MAP_FN(spirv::StorageClass::Input, 9)                                       \
-  MAP_FN(spirv::StorageClass::Output, 10)                                     \
-  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                             \
-  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                              \
-  MAP_FN(spirv::StorageClass::Image, 13)                                      \
-  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                            \
-  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                    \
-  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                              \
-  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                            \
-  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                      \
-  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                      \
-  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                      \
-  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                           \
-  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                            \
-  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
-
-unsigned
-SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
-#define STORAGE_SPACE_MAP_FN(storage, space)                                  \
-  case storage:                                                               \
-    return space;
-
-  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
-#undef STORAGE_SPACE_MAP_FN
-  llvm_unreachable("unhandled storage class!");
-}
-
-Optional<spirv::StorageClass>
-SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
-#define STORAGE_SPACE_MAP_FN(storage, space)                                  \
-  case space:                                                                 \
-    return storage;
-
-  switch (space) {
-    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
-  default:
-    return llvm::None;
-  }
-#undef STORAGE_SPACE_MAP_FN
-}
-
 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
   return options;
 }
@@ -184,8 +126,6 @@
   return targetEnv.getAttr().getContext();
 }
 
-#undef STORAGE_SPACE_MAP_LIST
-
 // TODO: This is a utility function that should probably be exposed by the
 // SPIR-V dialect. Keeping it local till the use case arises.
 static Optional<int64_t>
@@ -375,16 +315,8 @@
 static Type
 convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                       const SPIRVTypeConverter::Options &options,
-                      MemRefType type) {
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
-    return nullptr;
-  }
-
+                      MemRefType type,
+                      spirv::StorageClass storageClass) {
   unsigned numBoolBits = options.boolNumBits;
   if (numBoolBits != 8) {
     LLVM_DEBUG(llvm::dbgs()
@@ -407,34 +339,37 @@
   }
 
   if (!type.hasStaticShape()) {
-    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
-    return wrapInStructAndGetPointer(arrayType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
   auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
-  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType =
       spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
-  return wrapInStructAndGetPointer(arrayType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                               const SPIRVTypeConverter::Options &options,
                               MemRefType type) {
-  if (type.getElementType().isa<IntegerType>() &&
-      type.getElementTypeBitWidth() == 1) {
-    return convertBoolMemrefType(targetEnv, options, type);
+  auto attr = type.getMemorySpace().dyn_cast_or_null<spirv::StorageClassAttr>();
+  if (!attr) {
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << type
+        << " illegal: expected memory space to be a SPIR-V storage class "
+           "attribute; please use MemorySpaceToStorageClassConverter to map "
+           "numeric memory spaces beforehand\n");
+    return nullptr;
   }
+  spirv::StorageClass storageClass = attr.getValue();
 
-  Optional<spirv::StorageClass> storageClass =
-      SPIRVTypeConverter::getStorageClassForMemorySpace(
-          type.getMemorySpaceAsInt());
-  if (!storageClass) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert memory space\n");
-    return nullptr;
+  if (type.getElementType().isa<IntegerType>() &&
+      type.getElementTypeBitWidth() == 1) {
+    return convertBoolMemrefType(targetEnv, options, type, storageClass);
   }
 
   Type arrayElemType;
@@ -463,9 +398,9 @@
   }
 
   if (!type.hasStaticShape()) {
-    int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
-    return wrapInStructAndGetPointer(arrayType, *storageClass);
+    return wrapInStructAndGetPointer(arrayType, storageClass);
   }
 
   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
@@ -476,10 +411,10 @@
   }
 
   auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
-  int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
+  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
   auto arrayType =
       spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
 
-  return wrapInStructAndGetPointer(arrayType, *storageClass);
+  return wrapInStructAndGetPointer(arrayType, storageClass);
 }
 
 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
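With this change, convertMemrefType requires the memory space to already be a `#spv.storage_class` attribute and points users at MemorySpaceToStorageClassConverter otherwise. A hedged sketch of running that mapping up front, mirroring the `mapMemorySpace` branch in GPUToSPIRVPass.cpp above (the wrapper function name is illustrative, not part of this patch):

    // Sketch: apply the default Vulkan numeric-space-to-storage-class mapping
    // to a module before SPIR-V type conversion, as GPUToSPIRVPass does when
    // `mapMemorySpace` is set.
    mlir::LogicalResult mapNumericMemorySpaces(mlir::ModuleOp module) {
      mlir::MLIRContext *context = module.getContext();
      std::unique_ptr<mlir::ConversionTarget> target =
          mlir::spirv::getMemorySpaceToStorageClassTarget(*context);
      mlir::spirv::MemorySpaceToStorageClassMap memorySpaceMap =
          mlir::spirv::getDefaultVulkanStorageClassMap();
      mlir::spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);

      mlir::RewritePatternSet patterns(context);
      mlir::spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
      return mlir::applyFullConversion(module, *target, std::move(patterns));
    }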
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -274,29 +274,51 @@
 // CHECK-SAME: Private
 // CHECK-SAME: Function
 func.func @memref_mem_space(
-    %arg0: memref<4xf32, 0>,
-    %arg1: memref<4xf32, 4>,
-    %arg2: memref<4xf32, 3>,
-    %arg3: memref<4xf32, 7>,
-    %arg4: memref<4xf32, 5>,
-    %arg5: memref<4xf32, 6>
+    %arg0: memref<4xf32, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4xf32, #spv.storage_class<Uniform>>,
+    %arg2: memref<4xf32, #spv.storage_class<Workgroup>>,
+    %arg3: memref<4xf32, #spv.storage_class<PushConstant>>,
+    %arg4: memref<4xf32, #spv.storage_class<Private>>,
+    %arg5: memref<4xf32, #spv.storage_class<Function>>
 ) { return }
 
 // CHECK-LABEL: func @memref_1bit_type
 // CHECK-SAME: !spv.ptr [0])>, StorageBuffer>
 // CHECK-SAME: !spv.ptr)>, Function>
 // NOEMU-LABEL: func @memref_1bit_type
-// NOEMU-SAME: memref<4x8xi1>
-// NOEMU-SAME: memref<4x8xi1, 6>
+// NOEMU-SAME: memref<4x8xi1, #spv.storage_class<StorageBuffer>>
+// NOEMU-SAME: memref<4x8xi1, #spv.storage_class<Function>>
 func.func @memref_1bit_type(
-    %arg0: memref<4x8xi1, 0>,
-    %arg1: memref<4x8xi1, 6>
+    %arg0: memref<4x8xi1, #spv.storage_class<StorageBuffer>>,
+    %arg1: memref<4x8xi1, #spv.storage_class<Function>>
 ) { return }
 
 } // end module
 
 // -----
 
+// Reject numeric memory spaces.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>, #spv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @numeric_memref_mem_space1
+// CHECK-SAME: memref<4xf32>
+// NOEMU-LABEL: func @numeric_memref_mem_space1
+// NOEMU-SAME: memref<4xf32>
+func.func @numeric_memref_mem_space1(%arg0: memref<4xf32>) { return }
+
+// CHECK-LABEL: func @numeric_memref_mem_space2
+// CHECK-SAME: memref<4xf32, 3>
+// NOEMU-LABEL: func @numeric_memref_mem_space2
+// NOEMU-SAME: memref<4xf32, 3>
+func.func @numeric_memref_mem_space2(%arg0: memref<4xf32, 3>) { return }
+
+} // end module
+
+// -----
+
 // Check that using non-32-bit scalar types in interface storage classes
 // requires special capability and extension: convert them to 32-bit if not
 // satisfied.
@@ -308,86 +330,86 @@ // CHECK-LABEL: spv.func @memref_1bit_type // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_1bit_type -// NOEMU-SAME: memref<5xi1> -func.func @memref_1bit_type(%arg0: memref<5xi1>) { return } +// NOEMU-SAME: memref<5xi1, #spv.storage_class> +func.func @memref_1bit_type(%arg0: memref<5xi1, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_8bit_StorageBuffer -// NOEMU-SAME: memref<16xi8> -func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } +// NOEMU-SAME: memref<16xi8, #spv.storage_class> +func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_8bit_Uniform // CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func @memref_8bit_Uniform -// NOEMU-SAME: memref<16xsi8, 4> -func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, 4>) { return } +// NOEMU-SAME: memref<16xsi8, #spv.storage_class> +func.func @memref_8bit_Uniform(%arg0: memref<16xsi8, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_8bit_PushConstant // CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_8bit_PushConstant -// NOEMU-SAME: memref<16xui8, 7> -func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, 7>) { return } +// NOEMU-SAME: memref<16xui8, #spv.storage_class> +func.func @memref_8bit_PushConstant(%arg0: memref<16xui8, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_16bit_StorageBuffer -// NOEMU-SAME: memref<16xi16> -func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, 0>) { return } +// NOEMU-SAME: memref<16xi16, #spv.storage_class> +func.func @memref_16bit_StorageBuffer(%arg0: memref<16xi16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform // CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func @memref_16bit_Uniform -// NOEMU-SAME: memref<16xsi16, 4> -func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, 4>) { return } +// NOEMU-SAME: memref<16xsi16, #spv.storage_class> +func.func @memref_16bit_Uniform(%arg0: memref<16xsi16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant // CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_16bit_PushConstant -// NOEMU-SAME: memref<16xui16, 7> -func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, 7>) { return } +// NOEMU-SAME: memref<16xui16, #spv.storage_class> +func.func @memref_16bit_PushConstant(%arg0: memref<16xui16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input // CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: func @memref_16bit_Input -// NOEMU-SAME: memref<16xf16, 9> -func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } +// NOEMU-SAME: memref<16xf16, #spv.storage_class> +func.func @memref_16bit_Input(%arg3: memref<16xf16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output // CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: func @memref_16bit_Output -// NOEMU-SAME: memref<16xf16, 10> -func.func @memref_16bit_Output(%arg4: memref<16xf16, 10>) { return } +// NOEMU-SAME: memref<16xf16, #spv.storage_class> +func.func @memref_16bit_Output(%arg4: memref<16xf16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_64bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func 
@memref_64bit_StorageBuffer -// NOEMU-SAME: memref<16xi64> -func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, 0>) { return } +// NOEMU-SAME: memref<16xi64, #spv.storage_class> +func.func @memref_64bit_StorageBuffer(%arg0: memref<16xi64, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_64bit_Uniform // CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func @memref_64bit_Uniform -// NOEMU-SAME: memref<16xsi64, 4> -func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, 4>) { return } +// NOEMU-SAME: memref<16xsi64, #spv.storage_class> +func.func @memref_64bit_Uniform(%arg0: memref<16xsi64, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_64bit_PushConstant // CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_64bit_PushConstant -// NOEMU-SAME: memref<16xui64, 7> -func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, 7>) { return } +// NOEMU-SAME: memref<16xui64, #spv.storage_class> +func.func @memref_64bit_PushConstant(%arg0: memref<16xui64, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_64bit_Input // CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: func @memref_64bit_Input -// NOEMU-SAME: memref<16xf64, 9> -func.func @memref_64bit_Input(%arg3: memref<16xf64, 9>) { return } +// NOEMU-SAME: memref<16xf64, #spv.storage_class> +func.func @memref_64bit_Input(%arg3: memref<16xf64, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_64bit_Output // CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: func @memref_64bit_Output -// NOEMU-SAME: memref<16xf64, 10> -func.func @memref_64bit_Output(%arg4: memref<16xf64, 10>) { return } +// NOEMU-SAME: memref<16xf64, #spv.storage_class> +func.func @memref_64bit_Output(%arg4: memref<16xf64, #spv.storage_class>) { return } } // end module @@ -406,7 +428,7 @@ // CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: spv.func @memref_8bit_PushConstant // NOEMU-SAME: !spv.ptr [0])>, PushConstant> -func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, 7>) { return } +func.func @memref_8bit_PushConstant(%arg0: memref<16xi8, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant // CHECK-SAME: !spv.ptr [0])>, PushConstant> @@ -415,8 +437,8 @@ // NOEMU-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-SAME: !spv.ptr [0])>, PushConstant> func.func @memref_16bit_PushConstant( - %arg0: memref<16xi16, 7>, - %arg1: memref<16xf16, 7> + %arg0: memref<16xi16, #spv.storage_class>, + %arg1: memref<16xf16, #spv.storage_class> ) { return } // CHECK-LABEL: spv.func @memref_64bit_PushConstant @@ -426,8 +448,8 @@ // NOEMU-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-SAME: !spv.ptr [0])>, PushConstant> func.func @memref_64bit_PushConstant( - %arg0: memref<16xi64, 7>, - %arg1: memref<16xf64, 7> + %arg0: memref<16xi64, #spv.storage_class>, + %arg1: memref<16xf64, #spv.storage_class> ) { return } } // end module @@ -447,7 +469,7 @@ // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: spv.func @memref_8bit_StorageBuffer // NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> -func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return } +func.func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> @@ -456,8 +478,8 @@ // NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> func.func @memref_16bit_StorageBuffer( - %arg0: memref<16xi16, 0>, - %arg1: memref<16xf16, 0> + %arg0: memref<16xi16, 
#spv.storage_class>, + %arg1: memref<16xf16, #spv.storage_class> ) { return } // CHECK-LABEL: spv.func @memref_64bit_StorageBuffer @@ -467,8 +489,8 @@ // NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-SAME: !spv.ptr [0])>, StorageBuffer> func.func @memref_64bit_StorageBuffer( - %arg0: memref<16xi64, 0>, - %arg1: memref<16xf64, 0> + %arg0: memref<16xi64, #spv.storage_class>, + %arg1: memref<16xf64, #spv.storage_class> ) { return } } // end module @@ -488,7 +510,7 @@ // CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: spv.func @memref_8bit_Uniform // NOEMU-SAME: !spv.ptr [0])>, Uniform> -func.func @memref_8bit_Uniform(%arg0: memref<16xi8, 4>) { return } +func.func @memref_8bit_Uniform(%arg0: memref<16xi8, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform // CHECK-SAME: !spv.ptr [0])>, Uniform> @@ -497,8 +519,8 @@ // NOEMU-SAME: !spv.ptr [0])>, Uniform> // NOEMU-SAME: !spv.ptr [0])>, Uniform> func.func @memref_16bit_Uniform( - %arg0: memref<16xi16, 4>, - %arg1: memref<16xf16, 4> + %arg0: memref<16xi16, #spv.storage_class>, + %arg1: memref<16xf16, #spv.storage_class> ) { return } // CHECK-LABEL: spv.func @memref_64bit_Uniform @@ -508,8 +530,8 @@ // NOEMU-SAME: !spv.ptr [0])>, Uniform> // NOEMU-SAME: !spv.ptr [0])>, Uniform> func.func @memref_64bit_Uniform( - %arg0: memref<16xi64, 4>, - %arg1: memref<16xf64, 4> + %arg0: memref<16xi64, #spv.storage_class>, + %arg1: memref<16xf64, #spv.storage_class> ) { return } } // end module @@ -528,13 +550,13 @@ // CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: spv.func @memref_16bit_Input // NOEMU-SAME: !spv.ptr)>, Input> -func.func @memref_16bit_Input(%arg3: memref<16xf16, 9>) { return } +func.func @memref_16bit_Input(%arg3: memref<16xf16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output // CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: spv.func @memref_16bit_Output // NOEMU-SAME: !spv.ptr)>, Output> -func.func @memref_16bit_Output(%arg4: memref<16xi16, 10>) { return } +func.func @memref_16bit_Output(%arg4: memref<16xi16, #spv.storage_class>) { return } // CHECK-LABEL: spv.func @memref_64bit_Input // CHECK-SAME: !spv.ptr)>, Input> @@ -543,8 +565,8 @@ // NOEMU-SAME: !spv.ptr)>, Input> // NOEMU-SAME: !spv.ptr)>, Input> func.func @memref_64bit_Input( - %arg0: memref<16xi64, 9>, - %arg1: memref<16xf64, 9> + %arg0: memref<16xi64, #spv.storage_class>, + %arg1: memref<16xf64, #spv.storage_class> ) { return } // CHECK-LABEL: spv.func @memref_64bit_Output @@ -554,8 +576,8 @@ // NOEMU-SAME: !spv.ptr)>, Output> // NOEMU-SAME: !spv.ptr)>, Output> func.func @memref_64bit_Output( - %arg0: memref<16xi64, 10>, - %arg1: memref<16xf64, 10> + %arg0: memref<16xi64, #spv.storage_class>, + %arg1: memref<16xf64, #spv.storage_class> ) { return } } // end module @@ -575,22 +597,22 @@ // CHECK-SAME: !spv.array<256 x f32, stride=4> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<64 x f32, stride=4> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<88 x f32, stride=4> [0])>, StorageBuffer> - %arg0: memref<16x4xf32, offset: 0, strides: [4, 1]>, // tightly packed; row major - %arg1: memref<16x4xf32, offset: 8, strides: [4, 1]>, // offset 8 - %arg2: memref<16x4xf32, offset: 0, strides: [16, 1]>, // pad 12 after each row - %arg3: memref<16x4xf32, offset: 0, strides: [1, 16]>, // tightly packed; col major - %arg4: memref<16x4xf32, offset: 0, strides: [1, 22]>, // pad 4 after each col + %arg0: memref<16x4xf32, offset: 0, strides: [4, 1], #spv.storage_class>, // tightly packed; row major + %arg1: memref<16x4xf32, offset: 8, 
strides: [4, 1], #spv.storage_class>, // offset 8 + %arg2: memref<16x4xf32, offset: 0, strides: [16, 1], #spv.storage_class>, // pad 12 after each row + %arg3: memref<16x4xf32, offset: 0, strides: [1, 16], #spv.storage_class>, // tightly packed; col major + %arg4: memref<16x4xf32, offset: 0, strides: [1, 22], #spv.storage_class>, // pad 4 after each col // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<72 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<256 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<64 x f16, stride=2> [0])>, StorageBuffer> // CHECK-SAME: !spv.array<88 x f16, stride=2> [0])>, StorageBuffer> - %arg5: memref<16x4xf16, offset: 0, strides: [4, 1]>, - %arg6: memref<16x4xf16, offset: 8, strides: [4, 1]>, - %arg7: memref<16x4xf16, offset: 0, strides: [16, 1]>, - %arg8: memref<16x4xf16, offset: 0, strides: [1, 16]>, - %arg9: memref<16x4xf16, offset: 0, strides: [1, 22]> + %arg5: memref<16x4xf16, offset: 0, strides: [4, 1], #spv.storage_class>, + %arg6: memref<16x4xf16, offset: 8, strides: [4, 1], #spv.storage_class>, + %arg7: memref<16x4xf16, offset: 0, strides: [16, 1], #spv.storage_class>, + %arg8: memref<16x4xf16, offset: 0, strides: [1, 16], #spv.storage_class>, + %arg9: memref<16x4xf16, offset: 0, strides: [1, 22], #spv.storage_class> ) { return } } // end module @@ -610,14 +632,15 @@ // CHECK-LABEL: func @memref_1bit_type // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_1bit_type -// NOEMU-SAME: memref -func.func @memref_1bit_type(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_1bit_type(%arg0: memref>) { return } // CHECK-LABEL: func @dynamic_dim_memref // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> -func.func @dynamic_dim_memref(%arg0: memref<8x?xi32>, - %arg1: memref) { return } +func.func @dynamic_dim_memref( + %arg0: memref<8x?xi32, #spv.storage_class>, + %arg1: memref>) { return } // Check that using non-32-bit scalar types in interface storage classes // requires special capability and extension: convert them to 32-bit if not @@ -626,50 +649,50 @@ // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_8bit_StorageBuffer -// NOEMU-SAME: memref -func.func @memref_8bit_StorageBuffer(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_8bit_StorageBuffer(%arg0: memref>) { return } // CHECK-LABEL: spv.func @memref_8bit_Uniform // CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func @memref_8bit_Uniform -// NOEMU-SAME: memref -func.func @memref_8bit_Uniform(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_8bit_Uniform(%arg0: memref>) { return } // CHECK-LABEL: spv.func @memref_8bit_PushConstant // CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_8bit_PushConstant -// NOEMU-SAME: memref -func.func @memref_8bit_PushConstant(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_8bit_PushConstant(%arg0: memref>) { return } // CHECK-LABEL: spv.func @memref_16bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_16bit_StorageBuffer -// NOEMU-SAME: memref -func.func @memref_16bit_StorageBuffer(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_16bit_StorageBuffer(%arg0: memref>) { return } // CHECK-LABEL: spv.func @memref_16bit_Uniform // CHECK-SAME: !spv.ptr [0])>, Uniform> // NOEMU-LABEL: func 
@memref_16bit_Uniform -// NOEMU-SAME: memref -func.func @memref_16bit_Uniform(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_16bit_Uniform(%arg0: memref>) { return } // CHECK-LABEL: spv.func @memref_16bit_PushConstant // CHECK-SAME: !spv.ptr [0])>, PushConstant> // NOEMU-LABEL: func @memref_16bit_PushConstant -// NOEMU-SAME: memref -func.func @memref_16bit_PushConstant(%arg0: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_16bit_PushConstant(%arg0: memref>) { return } // CHECK-LABEL: spv.func @memref_16bit_Input // CHECK-SAME: !spv.ptr)>, Input> // NOEMU-LABEL: func @memref_16bit_Input -// NOEMU-SAME: memref -func.func @memref_16bit_Input(%arg3: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_16bit_Input(%arg3: memref>) { return } // CHECK-LABEL: spv.func @memref_16bit_Output // CHECK-SAME: !spv.ptr)>, Output> // NOEMU-LABEL: func @memref_16bit_Output -// NOEMU-SAME: memref -func.func @memref_16bit_Output(%arg4: memref) { return } +// NOEMU-SAME: memref> +func.func @memref_16bit_Output(%arg4: memref>) { return } } // end module @@ -684,15 +707,16 @@ // CHECK-SAME: !spv.ptr, stride=8> [0])>, StorageBuffer> // CHECK-SAME: !spv.ptr, stride=16> [0])>, Uniform> func.func @memref_vector( - %arg0: memref<4xvector<2xf32>, 0>, - %arg1: memref<4xvector<4xf32>, 4>) + %arg0: memref<4xvector<2xf32>, #spv.storage_class>, + %arg1: memref<4xvector<4xf32>, #spv.storage_class>) { return } // CHECK-LABEL: func @dynamic_dim_memref_vector // CHECK-SAME: !spv.ptr, stride=16> [0])>, StorageBuffer> // CHECK-SAME: !spv.ptr, stride=8> [0])>, StorageBuffer> -func.func @dynamic_dim_memref_vector(%arg0: memref<8x?xvector<4xi32>>, - %arg1: memref>) +func.func @dynamic_dim_memref_vector( + %arg0: memref<8x?xvector<4xi32>, #spv.storage_class>, + %arg1: memref, #spv.storage_class>) { return } } // end module @@ -705,9 +729,9 @@ } { // CHECK-LABEL: func @memref_vector_wrong_size -// CHECK-SAME: memref<4xvector<5xf32>> +// CHECK-SAME: memref<4xvector<5xf32>, #spv.storage_class> func.func @memref_vector_wrong_size( - %arg0: memref<4xvector<5xf32>, 0>) + %arg0: memref<4xvector<5xf32>, #spv.storage_class>) { return } } // end module diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir rename from mlir/test/Conversion/GPUToSPIRV/simple.mlir rename to mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir --- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/gpu-to-spirv.mlir @@ -7,7 +7,7 @@ // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), StorageBuffer>} // CHECK-SAME: {{%.*}}: !spv.ptr [0])>, StorageBuffer> {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>} // CHECK-SAME: spv.entry_point_abi = #spv.entry_point_abi : vector<3xi32>> - gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32>) kernel + gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { // CHECK: spv.Return gpu.return @@ -16,11 +16,11 @@ func.func @main() { %0 = "op"() : () -> (f32) - %1 = "op"() : () -> (memref<12xf32>) + %1 = "op"() : () -> (memref<12xf32, #spv.storage_class>) %cst = arith.constant 1 : index gpu.launch_func @kernels::@basic_module_structure blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) - args(%0 : f32, %1 : memref<12xf32>) + args(%0 : f32, %1 : memref<12xf32, #spv.storage_class>) return } } @@ -39,7 +39,7 @@ gpu.func 
@basic_module_structure_preset_ABI( %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>}, - %arg1 : memref<12xf32> + %arg1 : memref<12xf32, #spv.storage_class> {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { @@ -55,18 +55,18 @@ gpu.module @kernels { // expected-error @below {{failed to legalize operation 'gpu.func'}} // expected-remark @below {{match failure: missing 'spv.entry_point_abi' attribute}} - gpu.func @missing_entry_point_abi(%arg0 : f32, %arg1 : memref<12xf32>) kernel { + gpu.func @missing_entry_point_abi(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class>) kernel { gpu.return } } func.func @main() { %0 = "op"() : () -> (f32) - %1 = "op"() : () -> (memref<12xf32>) + %1 = "op"() : () -> (memref<12xf32, #spv.storage_class>) %cst = arith.constant 1 : index gpu.launch_func @kernels::@missing_entry_point_abi blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) - args(%0 : f32, %1 : memref<12xf32>) + args(%0 : f32, %1 : memref<12xf32, #spv.storage_class>) return } } @@ -80,7 +80,7 @@ gpu.func @missing_entry_point_abi( %arg0 : f32 {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>}, - %arg1 : memref<12xf32>) kernel + %arg1 : memref<12xf32, #spv.storage_class>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { gpu.return @@ -96,7 +96,7 @@ // expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 0}} gpu.func @missing_entry_point_abi( %arg0 : f32, - %arg1 : memref<12xf32> + %arg1 : memref<12xf32, #spv.storage_class> {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { @@ -110,7 +110,7 @@ module attributes {gpu.container_module} { gpu.module @kernels { // CHECK-LABEL: spv.func @barrier - gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel + gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { // CHECK: spv.ControlBarrier , , gpu.barrier @@ -120,11 +120,11 @@ func.func @main() { %0 = "op"() : () -> (f32) - %1 = "op"() : () -> (memref<12xf32>) + %1 = "op"() : () -> (memref<12xf32, #spv.storage_class>) %cst = arith.constant 1 : index gpu.launch_func @kernels::@barrier blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) - args(%0 : f32, %1 : memref<12xf32>) + args(%0 : f32, %1 : memref<12xf32, #spv.storage_class>) return } } diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -5,7 +5,7 @@ spv.target_env = #spv.target_env< #spv.vce, #spv.resource_limits<>> } { - func.func @load_store(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>) { + func.func @load_store(%arg0: memref<12x4xf32, #spv.storage_class>, %arg1: memref<12x4xf32, #spv.storage_class>, %arg2: memref<12x4xf32, #spv.storage_class>) { %c0 = arith.constant 0 : index %c12 = arith.constant 12 : index %0 = arith.subi %c12, %c0 : index @@ -17,7 +17,7 @@ %c1_2 = arith.constant 1 : index gpu.launch_func @kernels::@load_store_kernel blocks in (%0, %c1_2, %c1_2) threads in (%1, %c1_2, %c1_2) - args(%arg0 : memref<12x4xf32>, %arg1 : memref<12x4xf32>, %arg2 : memref<12x4xf32>, + args(%arg0 : memref<12x4xf32, #spv.storage_class>, %arg1 
: memref<12x4xf32, #spv.storage_class>, %arg2 : memref<12x4xf32, #spv.storage_class>, %c0 : index, %c0_0 : index, %c1 : index, %c1_1 : index) return } @@ -35,7 +35,7 @@ // CHECK-SAME: %[[ARG4:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 4), StorageBuffer>} // CHECK-SAME: %[[ARG5:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 5), StorageBuffer>} // CHECK-SAME: %[[ARG6:.*]]: i32 {spv.interface_var_abi = #spv.interface_var_abi<(0, 6), StorageBuffer>} - gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel + gpu.func @load_store_kernel(%arg0: memref<12x4xf32, #spv.storage_class>, %arg1: memref<12x4xf32, #spv.storage_class>, %arg2: memref<12x4xf32, #spv.storage_class>, %arg3: index, %arg4: index, %arg5: index, %arg6: index) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { // CHECK: %[[ADDRESSWORKGROUPID:.*]] = spv.mlir.addressof @[[$WORKGROUPIDVAR]] // CHECK: %[[WORKGROUPID:.*]] = spv.Load "Input" %[[ADDRESSWORKGROUPID]] @@ -69,15 +69,15 @@ // CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32 // CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}} // CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]] - %14 = memref.load %arg0[%12, %13] : memref<12x4xf32> + %14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spv.storage_class> // CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}} // CHECK-NEXT: %[[VAL2:.*]] = spv.Load "StorageBuffer" %[[PTR2]] - %15 = memref.load %arg1[%12, %13] : memref<12x4xf32> + %15 = memref.load %arg1[%12, %13] : memref<12x4xf32, #spv.storage_class> // CHECK: %[[VAL3:.*]] = spv.FAdd %[[VAL1]], %[[VAL2]] %16 = arith.addf %14, %15 : f32 // CHECK: %[[PTR3:.*]] = spv.AccessChain %[[ARG2]]{{\[}}{{%.*}}, {{%.*}}{{\]}} // CHECK-NEXT: spv.Store "StorageBuffer" %[[PTR3]], %[[VAL3]] - memref.store %16, %arg2[%12, %13] : memref<12x4xf32> + memref.store %16, %arg2[%12, %13] : memref<12x4xf32, #spv.storage_class> gpu.return } } diff --git a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir rename from mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir rename to mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-structure-opencl.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir @@ -12,7 +12,7 @@ // CHECK-SAME: {{%.*}}: !spv.ptr)>, CrossWorkgroup> // CHECK-NOT: spv.interface_var_abi // CHECK-SAME: spv.entry_point_abi = #spv.entry_point_abi : vector<3xi32>> - gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, 11>) kernel + gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class>) kernel attributes {spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>>} { gpu.return } @@ -20,11 +20,11 @@ func.func @main() { %0 = "op"() : () -> (f32) - %1 = "op"() : () -> (memref<12xf32, 11>) + %1 = "op"() : () -> (memref<12xf32, #spv.storage_class>) %cst = arith.constant 1 : index gpu.launch_func @kernels::@basic_module_structure blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) - args(%0 : f32, %1 : memref<12xf32, 11>) + args(%0 : f32, %1 : memref<12xf32, #spv.storage_class>) return } } diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir --- 
a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir +++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir @@ -44,12 +44,12 @@ // CHECK: } // CHECK: spv.Return -func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes { +func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class>, %output: memref<1xi32, #spv.storage_class>) attributes { spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>> } { linalg.generic #single_workgroup_reduction_trait - ins(%input : memref<16xi32>) - outs(%output : memref<1xi32>) { + ins(%input : memref<16xi32, #spv.storage_class>) + outs(%output : memref<1xi32, #spv.storage_class>) { ^bb(%in: i32, %out: i32): %sum = arith.addi %in, %out : i32 linalg.yield %sum : i32 @@ -74,11 +74,11 @@ spv.target_env = #spv.target_env< #spv.vce, #spv.resource_limits<>> } { -func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) { +func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class>, %output: memref<1xi32, #spv.storage_class>) { // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} linalg.generic #single_workgroup_reduction_trait - ins(%input : memref<16xi32>) - outs(%output : memref<1xi32>) { + ins(%input : memref<16xi32, #spv.storage_class>) + outs(%output : memref<1xi32, #spv.storage_class>) { ^bb(%in: i32, %out: i32): %sum = arith.addi %in, %out : i32 linalg.yield %sum : i32 @@ -103,13 +103,13 @@ spv.target_env = #spv.target_env< #spv.vce, #spv.resource_limits<>> } { -func.func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes { +func.func @single_workgroup_reduction(%input: memref<16xi32, #spv.storage_class>, %output: memref<1xi32, #spv.storage_class>) attributes { spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>> } { // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} linalg.generic #single_workgroup_reduction_trait - ins(%input : memref<16xi32>) - outs(%output : memref<1xi32>) { + ins(%input : memref<16xi32, #spv.storage_class>) + outs(%output : memref<1xi32, #spv.storage_class>) { ^bb(%in: i32, %out: i32): %sum = arith.addi %in, %out : i32 linalg.yield %sum : i32 @@ -134,13 +134,13 @@ spv.target_env = #spv.target_env< #spv.vce, #spv.resource_limits<>> } { -func.func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi32>) attributes { +func.func @single_workgroup_reduction(%input: memref<16x8xi32, #spv.storage_class>, %output: memref<16xi32, #spv.storage_class>) attributes { spv.entry_point_abi = #spv.entry_point_abi: vector<3xi32>> } { // expected-error @+1 {{failed to legalize operation 'linalg.generic'}} linalg.generic #single_workgroup_reduction_trait - ins(%input : memref<16x8xi32>) - outs(%output : memref<16xi32>) { + ins(%input : memref<16x8xi32, #spv.storage_class>) + outs(%output : memref<16xi32, #spv.storage_class>) { ^bb(%in: i32, %out: i32): %sum = arith.addi %in, %out : i32 linalg.yield %sum : i32 diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloc.mlir @@ -6,10 +6,10 @@ } { func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { - %0 = memref.alloc() : memref<4x5xf32, 3> - %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 3> - memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 3> - memref.dealloc %0 : memref<4x5xf32, 3> + %0 = memref.alloc() : 
memref<4x5xf32, #spv.storage_class> + %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class> + memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class> + memref.dealloc %0 : memref<4x5xf32, #spv.storage_class> return } } @@ -31,10 +31,10 @@ } { func.func @alloc_dealloc_workgroup_mem(%arg0 : index, %arg1 : index) { - %0 = memref.alloc() : memref<4x5xi16, 3> - %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, 3> - memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, 3> - memref.dealloc %0 : memref<4x5xi16, 3> + %0 = memref.alloc() : memref<4x5xi16, #spv.storage_class> + %1 = memref.load %0[%arg0, %arg1] : memref<4x5xi16, #spv.storage_class> + memref.store %1, %0[%arg0, %arg1] : memref<4x5xi16, #spv.storage_class> + memref.dealloc %0 : memref<4x5xi16, #spv.storage_class> return } } @@ -60,8 +60,8 @@ } { func.func @two_allocs() { - %0 = memref.alloc() : memref<4x5xf32, 3> - %1 = memref.alloc() : memref<2x3xi32, 3> + %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class> + %1 = memref.alloc() : memref<2x3xi32, #spv.storage_class> return } } @@ -80,8 +80,8 @@ } { func.func @two_allocs_vector() { - %0 = memref.alloc() : memref<4xvector<4xf32>, 3> - %1 = memref.alloc() : memref<2xvector<2xi32>, 3> + %0 = memref.alloc() : memref<4xvector<4xf32>, #spv.storage_class> + %1 = memref.alloc() : memref<2xvector<2xi32>, #spv.storage_class> return } } @@ -103,8 +103,8 @@ // CHECK-LABEL: func @alloc_dynamic_size func.func @alloc_dynamic_size(%arg0 : index) -> f32 { // CHECK: memref.alloc - %0 = memref.alloc(%arg0) : memref<4x?xf32, 3> - %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 3> + %0 = memref.alloc(%arg0) : memref<4x?xf32, #spv.storage_class> + %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spv.storage_class> return %1: f32 } } @@ -119,8 +119,8 @@ // CHECK-LABEL: func @alloc_unsupported_memory_space func.func @alloc_unsupported_memory_space(%arg0: index) -> f32 { // CHECK: memref.alloc - %0 = memref.alloc() : memref<4x5xf32> - %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32> + %0 = memref.alloc() : memref<4x5xf32, #spv.storage_class> + %1 = memref.load %0[%arg0, %arg0] : memref<4x5xf32, #spv.storage_class> return %1: f32 } } @@ -134,9 +134,9 @@ } { // CHECK-LABEL: func @dealloc_dynamic_size - func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, 3>) { + func.func @dealloc_dynamic_size(%arg0 : memref<4x?xf32, #spv.storage_class>) { // CHECK: memref.dealloc - memref.dealloc %arg0 : memref<4x?xf32, 3> + memref.dealloc %arg0 : memref<4x?xf32, #spv.storage_class> return } } @@ -149,9 +149,9 @@ } { // CHECK-LABEL: func @dealloc_unsupported_memory_space - func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32>) { + func.func @dealloc_unsupported_memory_space(%arg0 : memref<4x5xf32, #spv.storage_class>) { // CHECK: memref.dealloc - memref.dealloc %arg0 : memref<4x5xf32> + memref.dealloc %arg0 : memref<4x5xf32, #spv.storage_class> return } } diff --git a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/alloca.mlir @@ -2,9 +2,9 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>>} { func.func @alloc_function_variable(%arg0 : index, %arg1 : index) { - %0 = memref.alloca() : memref<4x5xf32, 6> - %1 = memref.load %0[%arg0, %arg1] : memref<4x5xf32, 6> - memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, 6> + %0 = memref.alloca() : memref<4x5xf32, #spv.storage_class> + %1 = 
memref.load %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class> + memref.store %1, %0[%arg0, %arg1] : memref<4x5xf32, #spv.storage_class> return } } @@ -21,8 +21,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>>} { func.func @two_allocs() { - %0 = memref.alloca() : memref<4x5xf32, 6> - %1 = memref.alloca() : memref<2x3xi32, 6> + %0 = memref.alloca() : memref<4x5xf32, #spv.storage_class> + %1 = memref.alloca() : memref<2x3xi32, #spv.storage_class> return } } @@ -35,8 +35,8 @@ module attributes {spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>>} { func.func @two_allocs_vector() { - %0 = memref.alloca() : memref<4xvector<4xf32>, 6> - %1 = memref.alloca() : memref<2xvector<2xi32>, 6> + %0 = memref.alloca() : memref<4xvector<4xf32>, #spv.storage_class> + %1 = memref.alloca() : memref<2xvector<2xi32>, #spv.storage_class> return } } @@ -52,8 +52,8 @@ // CHECK-LABEL: func @alloc_dynamic_size func.func @alloc_dynamic_size(%arg0 : index) -> f32 { // CHECK: memref.alloca - %0 = memref.alloca(%arg0) : memref<4x?xf32, 6> - %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, 6> + %0 = memref.alloca(%arg0) : memref<4x?xf32, #spv.storage_class> + %1 = memref.load %0[%arg0, %arg0] : memref<4x?xf32, #spv.storage_class> return %1: f32 } } diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -15,60 +15,60 @@ } { // CHECK-LABEL: @load_store_zero_rank_float -func.func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { - // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> - // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> +func.func @load_store_zero_rank_float(%arg0: memref>, %arg1: memref>) { + // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr [0])>, StorageBuffer> // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ // CHECK-SAME: [[ZERO1]], [[ZERO1]] // CHECK-SAME: ] : // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 - %0 = memref.load %arg0[] : memref + %0 = memref.load %arg0[] : memref> // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG1]][ // CHECK-SAME: [[ZERO2]], [[ZERO2]] // CHECK-SAME: ] : // CHECK: spv.Store "StorageBuffer" %{{.*}} : f32 - memref.store %0, %arg1[] : memref + memref.store %0, %arg1[] : memref> return } // CHECK-LABEL: @load_store_zero_rank_int -func.func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { - // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> - // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> +func.func @load_store_zero_rank_int(%arg0: memref>, %arg1: memref>) { + // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr [0])>, StorageBuffer> // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ // CHECK-SAME: [[ZERO1]], [[ZERO1]] // CHECK-SAME: ] : // CHECK: spv.Load "StorageBuffer" 
%{{.*}} : i32 - %0 = memref.load %arg0[] : memref + %0 = memref.load %arg0[] : memref> // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG1]][ // CHECK-SAME: [[ZERO2]], [[ZERO2]] // CHECK-SAME: ] : // CHECK: spv.Store "StorageBuffer" %{{.*}} : i32 - memref.store %0, %arg1[] : memref + memref.store %0, %arg1[] : memref> return } // CHECK-LABEL: func @load_store_unknown_dim -func.func @load_store_unknown_dim(%i: index, %source: memref, %dest: memref) { - // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> - // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> +func.func @load_store_unknown_dim(%i: index, %source: memref>, %dest: memref>) { + // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr [0])>, StorageBuffer> // CHECK: %[[AC0:.+]] = spv.AccessChain %[[SRC]] // CHECK: spv.Load "StorageBuffer" %[[AC0]] - %0 = memref.load %source[%i] : memref + %0 = memref.load %source[%i] : memref> // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]] // CHECK: spv.Store "StorageBuffer" %[[AC1]] - memref.store %0, %dest[%i]: memref + memref.store %0, %dest[%i]: memref> return } // CHECK-LABEL: func @load_i1 -// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index) -func.func @load_i1(%src: memref<4xi1>, %i : index) -> i1 { - // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class>, %[[IDX:.+]]: index) +func.func @load_i1(%src: memref<4xi1, #spv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class> to !spv.ptr [0])>, StorageBuffer> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 @@ -79,17 +79,17 @@ // CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[ADDR]] : i8 // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8 // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8 - %0 = memref.load %src[%i] : memref<4xi1> + %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class> // CHECK: return %[[BOOL]] return %0: i1 } // CHECK-LABEL: func @store_i1 -// CHECK-SAME: %[[DST:.+]]: memref<4xi1>, +// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class>, // CHECK-SAME: %[[IDX:.+]]: index -func.func @store_i1(%dst: memref<4xi1>, %i: index) { +func.func @store_i1(%dst: memref<4xi1, #spv.storage_class>, %i: index) { %true = arith.constant true - // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class> to !spv.ptr [0])>, StorageBuffer> // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 @@ -101,7 +101,7 @@ // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8 // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8 // CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8 - memref.store %true, %dst[%i]: memref<4xi1> + memref.store %true, %dst[%i]: memref<4xi1, 
#spv.storage_class> return } @@ -118,7 +118,7 @@ } { // CHECK-LABEL: @load_i1 -func.func @load_i1(%arg0: memref) -> i1 { +func.func @load_i1(%arg0: memref>) -> i1 { // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 @@ -138,12 +138,12 @@ // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 // CHECK: %[[RES:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 // CHECK: return %[[RES]] - %0 = memref.load %arg0[] : memref + %0 = memref.load %arg0[] : memref> return %0 : i1 } // CHECK-LABEL: @load_i8 -func.func @load_i8(%arg0: memref) { +func.func @load_i8(%arg0: memref>) { // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 @@ -159,13 +159,13 @@ // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 - %0 = memref.load %arg0[] : memref + %0 = memref.load %arg0[] : memref> return } // CHECK-LABEL: @load_i16 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: index) -func.func @load_i16(%arg0: memref<10xi16>, %index : index) { +func.func @load_i16(%arg0: memref<10xi16, #spv.storage_class>, %index : index) { // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32 @@ -186,31 +186,31 @@ // CHECK: %[[T2:.+]] = spv.Constant 16 : i32 // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 - %0 = memref.load %arg0[%index] : memref<10xi16> + %0 = memref.load %arg0[%index] : memref<10xi16, #spv.storage_class> return } // CHECK-LABEL: @load_i32 -func.func @load_i32(%arg0: memref) { +func.func @load_i32(%arg0: memref>) { // CHECK-NOT: spv.SDiv // CHECK: spv.Load // CHECK-NOT: spv.ShiftRightArithmetic - %0 = memref.load %arg0[] : memref + %0 = memref.load %arg0[] : memref> return } // CHECK-LABEL: @load_f32 -func.func @load_f32(%arg0: memref) { +func.func @load_f32(%arg0: memref>) { // CHECK-NOT: spv.SDiv // CHECK: spv.Load // CHECK-NOT: spv.ShiftRightArithmetic - %0 = memref.load %arg0[] : memref + %0 = memref.load %arg0[] : memref> return } // CHECK-LABEL: @store_i1 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) -func.func @store_i1(%arg0: memref, %value: i1) { +func.func @store_i1(%arg0: memref>, %value: i1) { // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 @@ -230,13 +230,13 @@ // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] - memref.store %value, %arg0[] : memref + memref.store %value, %arg0[] : memref> return } // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) -func.func @store_i8(%arg0: memref, %value: i8) { +func.func @store_i8(%arg0: memref>, %value: i8) { // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 @@ -254,13 +254,13 @@ // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], 
   // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i8>
+  memref.store %value, %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }

 // CHECK-LABEL: @store_i16
-// CHECK: (%[[ARG0:.+]]: memref<10xi16>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
-func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+// CHECK: (%[[ARG0:.+]]: memref<10xi16, #spv.storage_class<StorageBuffer>>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16)
+func.func @store_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
   // CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32
   // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32
@@ -283,25 +283,25 @@
   // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[%index] : memref<10xi16>
+  memref.store %value, %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }

 // CHECK-LABEL: @store_i32
-func.func @store_i32(%arg0: memref<i32>, %value: i32) {
+func.func @store_i32(%arg0: memref<i32, #spv.storage_class<StorageBuffer>>, %value: i32) {
   // CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[] : memref<i32>
+  memref.store %value, %arg0[] : memref<i32, #spv.storage_class<StorageBuffer>>
   return
 }

 // CHECK-LABEL: @store_f32
-func.func @store_f32(%arg0: memref<f32>, %value: f32) {
+func.func @store_f32(%arg0: memref<f32, #spv.storage_class<StorageBuffer>>, %value: f32) {
   // CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[] : memref<f32>
+  memref.store %value, %arg0[] : memref<f32, #spv.storage_class<StorageBuffer>>
   return
 }

@@ -318,7 +318,7 @@
 } {

 // CHECK-LABEL: @load_i8
-func.func @load_i8(%arg0: memref<i8>) {
+func.func @load_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
   // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32
   // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
@@ -334,22 +334,22 @@
   // CHECK: %[[T2:.+]] = spv.Constant 24 : i32
   // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
   // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
-  %0 = memref.load %arg0[] : memref<i8>
+  %0 = memref.load %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }

 // CHECK-LABEL: @load_i16
-func.func @load_i16(%arg0: memref<i16>) {
+func.func @load_i16(%arg0: memref<i16, #spv.storage_class<StorageBuffer>>) {
   // CHECK-NOT: spv.SDiv
   // CHECK: spv.Load
   // CHECK-NOT: spv.ShiftRightArithmetic
-  %0 = memref.load %arg0[] : memref<i16>
+  %0 = memref.load %arg0[] : memref<i16, #spv.storage_class<StorageBuffer>>
   return
 }

 // CHECK-LABEL: @store_i8
 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
-func.func @store_i8(%arg0: memref<i8>, %value: i8) {
+func.func @store_i8(%arg0: memref<i8, #spv.storage_class<StorageBuffer>>, %value: i8) {
   // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
   // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
   // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32
@@ -367,16 +367,16 @@
   // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
   // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
-  memref.store %value, %arg0[] : memref<i8>
+  memref.store %value, %arg0[] : memref<i8, #spv.storage_class<StorageBuffer>>
   return
 }

 // CHECK-LABEL: @store_i16
-func.func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+func.func @store_i16(%arg0: memref<10xi16, #spv.storage_class<StorageBuffer>>, %index: index, %value: i16) {
   // CHECK: spv.Store
   // CHECK-NOT: spv.AtomicAnd
   // CHECK-NOT: spv.AtomicOr
-  memref.store %value, %arg0[%index] : memref<10xi16>
+  memref.store %value, %arg0[%index] : memref<10xi16, #spv.storage_class<StorageBuffer>>
   return
 }
diff --git a/mlir/test/Conversion/SCFToSPIRV/for.mlir b/mlir/test/Conversion/SCFToSPIRV/for.mlir
--- a/mlir/test/Conversion/SCFToSPIRV/for.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/for.mlir
@@ -5,7 +5,7 @@
     #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spv.resource_limits<>>
 } {

-func.func @loop_kernel(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+func.func @loop_kernel(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[LB:.*]] = spv.Constant 4 : i32
   %lb = arith.constant 4 : index
   // CHECK: %[[UB:.*]] = spv.Constant 42 : i32
@@ -36,14 +36,14 @@
   // CHECK: spv.mlir.merge
   // CHECK: }
   scf.for %arg4 = %lb to %ub step %step {
-    %1 = memref.load %arg2[%arg4] : memref<10xf32>
-    memref.store %1, %arg3[%arg4] : memref<10xf32>
+    %1 = memref.load %arg2[%arg4] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+    memref.store %1, %arg3[%arg4] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
   return
 }

 // CHECK-LABEL: @loop_yield
-func.func @loop_yield(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>) {
+func.func @loop_yield(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>) {
   // CHECK: %[[LB:.*]] = spv.Constant 4 : i32
   %lb = arith.constant 4 : index
   // CHECK: %[[UB:.*]] = spv.Constant 42 : i32
@@ -78,8 +78,8 @@
   // CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
   // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
   // CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
-  memref.store %result#0, %arg3[%lb] : memref<10xf32>
-  memref.store %result#1, %arg3[%ub] : memref<10xf32>
+  memref.store %result#0, %arg3[%lb] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+  memref.store %result#1, %arg3[%ub] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
diff --git a/mlir/test/Conversion/SCFToSPIRV/if.mlir b/mlir/test/Conversion/SCFToSPIRV/if.mlir
--- a/mlir/test/Conversion/SCFToSPIRV/if.mlir
+++ b/mlir/test/Conversion/SCFToSPIRV/if.mlir
@@ -6,7 +6,7 @@
 } {

 // CHECK-LABEL: @kernel_simple_selection
-func.func @kernel_simple_selection(%arg2 : memref<10xf32>, %arg3 : i1) {
+func.func @kernel_simple_selection(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : i1) {
   %value = arith.constant 0.0 : f32
   %i = arith.constant 0 : index
@@ -20,13 +20,13 @@
   // CHECK-NEXT: spv.Return

   scf.if %arg3 {
-    memref.store %value, %arg2[%i] : memref<10xf32>
+    memref.store %value, %arg2[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
   return
 }

 // CHECK-LABEL: @kernel_nested_selection
-func.func @kernel_nested_selection(%arg3 : memref<10xf32>, %arg4 : memref<10xf32>, %arg5 : i1, %arg6 : i1) {
+func.func @kernel_nested_selection(%arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg4 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg5 : i1, %arg6 : i1) {
   %i = arith.constant 0 : index
   %j = arith.constant 9 : index
@@ -61,26 +61,26 @@
   scf.if %arg5 {
     scf.if %arg6 {
-      %value = memref.load %arg3[%i] : memref<10xf32>
-      memref.store %value, %arg4[%i] : memref<10xf32>
+      %value = memref.load %arg3[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg4[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     } else {
-      %value = memref.load %arg4[%i] : memref<10xf32>
-      memref.store %value, %arg3[%i] : memref<10xf32>
+      %value = memref.load %arg4[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg3[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     }
   } else {
     scf.if %arg6 {
-      %value = memref.load %arg3[%j] : memref<10xf32>
-      memref.store %value, %arg4[%j] : memref<10xf32>
+      %value = memref.load %arg3[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg4[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     } else {
-      %value = memref.load %arg4[%j] : memref<10xf32>
-      memref.store %value, %arg3[%j] : memref<10xf32>
+      %value = memref.load %arg4[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+      memref.store %value, %arg3[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
     }
   }
   return
 }

 // CHECK-LABEL: @simple_if_yield
-func.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) {
+func.func @simple_if_yield(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : i1) {
   // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: spv.mlir.selection {
@@ -116,15 +116,15 @@
   }
   %i = arith.constant 0 : index
   %j = arith.constant 1 : index
-  memref.store %0#0, %arg2[%i] : memref<10xf32>
-  memref.store %0#1, %arg2[%j] : memref<10xf32>
+  memref.store %0#0, %arg2[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
+  memref.store %0#1, %arg2[%j] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }

 // TODO: The transformation should only be legal if VariablePointer capability
 // is supported. This test is still useful to make sure we can handle scf op
 // result with type change.
-func.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) {
+func.func @simple_if_yield_type_change(%arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>, %arg4 : i1) {
   // CHECK-LABEL: @simple_if_yield_type_change
   // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<(!spv.array<10 x f32, stride=4> [0])>, StorageBuffer>, Function>
   // CHECK: spv.mlir.selection {
@@ -144,12 +144,12 @@
   // CHECK: spv.Return
   %i = arith.constant 0 : index
   %value = arith.constant 0.0 : f32
-  %0 = scf.if %arg4 -> (memref<10xf32>) {
-    scf.yield %arg2 : memref<10xf32>
+  %0 = scf.if %arg4 -> (memref<10xf32, #spv.storage_class<StorageBuffer>>) {
+    scf.yield %arg2 : memref<10xf32, #spv.storage_class<StorageBuffer>>
   } else {
-    scf.yield %arg3 : memref<10xf32>
+    scf.yield %arg3 : memref<10xf32, #spv.storage_class<StorageBuffer>>
   }
-  memref.store %value, %0[%i] : memref<10xf32>
+  memref.store %value, %0[%i] : memref<10xf32, #spv.storage_class<StorageBuffer>>
   return
 }
diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
--- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
+++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
@@ -75,7 +75,7 @@
   PassManager passManager(module.getContext());
   applyPassManagerCLOptions(passManager);
   passManager.addPass(createGpuKernelOutliningPass());
-  passManager.addPass(createConvertGPUToSPIRVPass());
+  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));

   OpPassManager &nestedPM = passManager.nest<spirv::ModuleOp>();
   nestedPM.addPass(spirv::createLowerABIAttributesPass());
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -47,10 +47,12 @@
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(memref::createFoldSubViewOpsPass());
-  passManager.addPass(createConvertGPUToSPIRVPass());
+
+  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));

   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
   modulePM.addPass(spirv::createLowerABIAttributesPass());
   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
+  passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());

   LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
   passManager.addPass(createMemRefToLLVMPass());
@@ -58,6 +60,7 @@
   passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
   passManager.addPass(createReconcileUnrealizedCastsPass());
   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
+
   return passManager.run(module);
 }
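
Reviewer note: for anyone wiring the new flag into a custom tool, here is a minimal sketch of a pipeline that opts into the memory-space mapping, mirroring the two runner changes above. The buildSPIRVLoweringPipeline helper is hypothetical and not part of this patch; every pass-creation call in it appears in the diff, and the usual GPU/SPIR-V transforms pass headers are assumed to be included alongside the two shown.

#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: outline GPU kernels, then convert them to SPIR-V with
// MemRef memory spaces mapped to SPIR-V storage classes per the default
// Vulkan rules, and finally run the spirv.module cleanup passes.
static void buildSPIRVLoweringPipeline(mlir::PassManager &passManager) {
  using namespace mlir;
  passManager.addPass(createGpuKernelOutliningPass());
  // New in this patch: passing true requests the memory-space-to-storage-class
  // mapping before the GPU-to-SPIR-V conversion runs.
  passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
  OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
  modulePM.addPass(spirv::createLowerABIAttributesPass());
  modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
}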