diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_SPIRV_TRANSFORMS_SPIRVCONVERSION_H #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Transforms/DialectConversion.h" @@ -56,6 +57,10 @@ /// default. unsigned boolNumBits{8}; + /// Whether the target has the Kernel capability, i.e., it uses + /// OpenCL-flavored SPIR-V rather than GLSL/Vulkan-flavored SPIR-V. + bool hasKernelCapability{false}; + // Note: we need this instead of inline initializers because of // https://bugs.llvm.org/show_bug.cgi?id=36684 Options() @@ -156,6 +161,10 @@ ValueRange indices, Location loc, OpBuilder &builder); +spirv::PtrAccessChainOp getElementPtrDirect(SPIRVTypeConverter &typeConverter, + MemRefType baseType, Value basePtr, + ValueRange indices, Location loc, + OpBuilder &builder); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -73,7 +73,7 @@ // There are two elements if this is a 1-D tensor. assert(indices.size() == 2); indices.back() = builder.create(loc, lastDim, idx); - Type t = typeConverter.convertType(op.component_ptr().getType()); + Type t = typeConverter.convertType(op.getResult().getType()); return builder.create(loc, t, op.base_ptr(), indices); } @@ -192,7 +192,7 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts memref.load to spv.Load. +/// Converts memref.load to spv.Load + spv.AccessChain. 
class IntLoadOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -202,7 +202,17 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts memref.load to spv.Load. +/// Converts memref.load to spv.Load + spv.PtrAccessChain. +class IntLoadOpPtrPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts memref.load to spv.Load + spv.AccessChain. class LoadOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -212,6 +222,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts memref.load to spv.Load + spv.PtrAccessChain. +class LoadOpPtrPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts memref.store to spv.Store on integers. class IntStoreOpPattern final : public OpConversionPattern { public: @@ -222,6 +242,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts memref.store to spv.Store + spv.PtrAccessChain on integers. +class IntStoreOpPtrPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts memref.store to spv.Store. class StoreOpPattern final : public OpConversionPattern { public: @@ -232,6 +262,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts memref.store to spv.Store + spv.PtrAccessChain. 
+class StoreOpPtrPattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -406,6 +446,52 @@ return success(); } +LogicalResult IntLoadOpPtrPattern::matchAndRewrite( + memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = loadOp.getLoc(); + auto memrefType = loadOp.getMemref().getType().cast(); + if (!memrefType.getElementType().isSignlessInteger()) + return failure(); + + auto &typeConverter = *getTypeConverter(); + spirv::PtrAccessChainOp accessChainOp = + spirv::getElementPtrDirect(typeConverter, memrefType, adaptor.getMemref(), + adaptor.getIndices(), loc, rewriter); + if (!accessChainOp) + return failure(); + + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; + Type pointeeType = typeConverter.convertType(memrefType) + .cast() + .getPointeeType(); + bool isScalar = spirv::ScalarType::classof(pointeeType); + bool isComposite = spirv::CompositeType::classof(pointeeType); + if (!(isScalar ^ isComposite)) + return rewriter.notifyMatchFailure( + loadOp, "Designated pointer type needs to be scalar or composite."); + Type dstType = pointeeType; + int dstBits = dstType.getIntOrFloatBitWidth(); + assert(dstBits % srcBits == 0); + + // TODO: Add support for bit-width emulation on PtrAccessChain. + if (srcBits != dstBits) { + return rewriter.notifyMatchFailure( + loadOp, "bit-width emulation not supported for kernel capability yet."); + } + // If the rewritten load op has the same bit width, use the loaded value + // directly. 
+ Value loadVal = + rewriter.create(loc, accessChainOp.getResult()); + if (isBool) + loadVal = castIntNToBool(loc, loadVal, rewriter); + rewriter.replaceOp(loadOp, loadVal); + return success(); +} + LogicalResult LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -423,6 +509,23 @@ return success(); } +LogicalResult +LoadOpPtrPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto memrefType = loadOp.getMemref().getType().cast(); + if (memrefType.getElementType().isSignlessInteger()) + return failure(); + auto loadPtr = spirv::getElementPtrDirect( + *getTypeConverter(), memrefType, adaptor.getMemref(), + adaptor.getIndices(), loadOp.getLoc(), rewriter); + + if (!loadPtr) + return failure(); + + rewriter.replaceOpWithNewOp(loadOp, loadPtr); + return success(); +} + LogicalResult IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -518,6 +621,54 @@ return success(); } +LogicalResult IntStoreOpPtrPattern::matchAndRewrite( + memref::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto memrefType = storeOp.getMemref().getType().cast(); + if (!memrefType.getElementType().isSignlessInteger()) + return failure(); + + auto loc = storeOp.getLoc(); + auto &typeConverter = *getTypeConverter(); + spirv::PtrAccessChainOp accessChainOp = + spirv::getElementPtrDirect(typeConverter, memrefType, adaptor.getMemref(), + adaptor.getIndices(), loc, rewriter); + if (!accessChainOp) + return failure(); + + int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; + + Type pointeeType = typeConverter.convertType(memrefType) + .cast() + .getPointeeType(); + bool isScalar = spirv::ScalarType::classof(pointeeType); + bool isComposite = 
spirv::CompositeType::classof(pointeeType); + if (!(isScalar ^ isComposite)) + return rewriter.notifyMatchFailure( + storeOp, "Designated pointer type needs to be scalar or composite."); + Type dstType = pointeeType; + + int dstBits = dstType.getIntOrFloatBitWidth(); + assert(dstBits % srcBits == 0); + + // TODO: Add support for bit-width emulation on PtrAccessChain. + if (srcBits != dstBits) { + return rewriter.notifyMatchFailure( + storeOp, + "bit-width emulation not supported for kernel capability yet."); + } + Value storeVal = adaptor.getValue(); + if (isBool) + storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); + rewriter.replaceOpWithNewOp( + storeOp, accessChainOp.getResult(), storeVal); + return success(); +} + LogicalResult StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -536,6 +687,23 @@ return success(); } +LogicalResult +StoreOpPtrPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto memrefType = storeOp.getMemref().getType().cast(); + if (memrefType.getElementType().isSignlessInteger()) + return failure(); + auto storePtr = spirv::getElementPtrDirect( + *getTypeConverter(), memrefType, adaptor.getMemref(), + adaptor.getIndices(), storeOp.getLoc(), rewriter); + + if (!storePtr) + return failure(); + + rewriter.replaceOpWithNewOp(storeOp, storePtr, + adaptor.getValue()); + return success(); +} //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -543,9 +711,14 @@ namespace mlir { void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns - .add( - typeConverter, patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext()); + bool hasKernelCapability = typeConverter.getOptions().hasKernelCapability; + 
if (hasKernelCapability) { + patterns.add(typeConverter, patterns.getContext()); + } else + patterns.add(typeConverter, patterns.getContext()); } } // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -338,6 +338,10 @@ return nullptr; } + if (options.hasKernelCapability) { + return spirv::PointerType::get(arrayElemType, storageClass); + } + if (!type.hasStaticShape()) { int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); @@ -397,6 +401,10 @@ return nullptr; } + if (options.hasKernelCapability) { + return spirv::PointerType::get(arrayElemType, storageClass); + } + if (!type.hasStaticShape()) { int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); @@ -420,6 +428,9 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, Options options) : targetEnv(targetAttr), options(options) { + // Check whether this is an OpenCL-flavored (Kernel capability) target. + this->options.hasKernelCapability = + this->targetEnv.allows(spirv::Capability::Kernel); // Add conversions. The order matters here: later ones will be tried earlier. // Allow all SPIR-V dialect specific types. This assumes all builtin types @@ -746,6 +757,35 @@ return builder.create(loc, basePtr, linearizedIndices); } +spirv::PtrAccessChainOp mlir::spirv::getElementPtrDirect( + SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, + ValueRange indices, Location loc, OpBuilder &builder) { + // Get strides and offset of the MemRefType and verify they are static. 
+ + int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(baseType, strides, offset)) || + llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) || + offset == MemRefType::getDynamicStrideOrOffset()) { + return nullptr; + } + + auto indexType = typeConverter.getIndexType(); + + SmallVector linearizedIndices; + auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); + + Value linearIndex; + if (baseType.getRank() == 0) { + linearIndex = zero; + } else { + linearIndex = + linearizeIndex(indices, strides, offset, indexType, loc, builder); + } + return builder.create(loc, basePtr, linearIndex, + linearizedIndices); +} + //===----------------------------------------------------------------------===// // SPIR-V ConversionTarget //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir --- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir @@ -9,7 +9,7 @@ // CHECK: spv.func // CHECK-SAME: {{%.*}}: f32 // CHECK-NOT: spv.interface_var_abi - // CHECK-SAME: {{%.*}}: !spv.ptr)>, CrossWorkgroup> + // CHECK-SAME: {{%.*}}: !spv.ptr // CHECK-NOT: spv.interface_var_abi // CHECK-SAME: spv.entry_point_abi = #spv.entry_point_abi : vector<3xi32>> gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spv.storage_class>) kernel diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -109,6 +109,111 @@ // ----- +// Check that with the Kernel capability and proper compute and storage extensions, we don't need to +// perform special tricks. 
+ +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, #spv.resource_limits<>> +} { + +// CHECK-LABEL: @load_store_zero_rank_float +func.func @load_store_zero_rank_float(%arg0: memref>, %arg1: memref>) { + // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr + // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.PtrAccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "CrossWorkgroup" %{{.*}} : f32 + %0 = memref.load %arg0[] : memref> + // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.PtrAccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "CrossWorkgroup" %{{.*}} : f32 + memref.store %0, %arg1[] : memref> + return +} + +// CHECK-LABEL: @load_store_zero_rank_int +func.func @load_store_zero_rank_int(%arg0: memref>, %arg1: memref>) { + // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr + // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.PtrAccessChain [[ARG0]][ + // CHECK-SAME: [[ZERO1]] + // CHECK-SAME: ] : + // CHECK: spv.Load "CrossWorkgroup" %{{.*}} : i32 + %0 = memref.load %arg0[] : memref> + // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 + // CHECK: spv.PtrAccessChain [[ARG1]][ + // CHECK-SAME: [[ZERO2]] + // CHECK-SAME: ] : + // CHECK: spv.Store "CrossWorkgroup" %{{.*}} : i32 + memref.store %0, %arg1[] : memref> + return +} + +// CHECK-LABEL: func @load_store_unknown_dim +func.func @load_store_unknown_dim(%i: index, %source: memref>, %dest: memref>) { + // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr + // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref> to !spv.ptr + // CHECK: %[[AC0:.+]] = spv.PtrAccessChain %[[SRC]] + 
// CHECK: spv.Load "CrossWorkgroup" %[[AC0]] + %0 = memref.load %source[%i] : memref> + // CHECK: %[[AC1:.+]] = spv.PtrAccessChain %[[DST]] + // CHECK: spv.Store "CrossWorkgroup" %[[AC1]] + memref.store %0, %dest[%i]: memref> + return +} + +// CHECK-LABEL: func @load_i1 +// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1, #spv.storage_class>, %[[IDX:.+]]: index) +func.func @load_i1(%src: memref<4xi1, #spv.storage_class>, %i : index) -> i1 { + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1, #spv.storage_class> to !spv.ptr + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32 + // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32 + // CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[SRC_CAST]][%[[ADD]]] + // CHECK: %[[VAL:.+]] = spv.Load "CrossWorkgroup" %[[ADDR]] : i8 + // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8 + // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8 + %0 = memref.load %src[%i] : memref<4xi1, #spv.storage_class> + // CHECK: return %[[BOOL]] + return %0: i1 +} + +// CHECK-LABEL: func @store_i1 +// CHECK-SAME: %[[DST:.+]]: memref<4xi1, #spv.storage_class>, +// CHECK-SAME: %[[IDX:.+]]: index +func.func @store_i1(%dst: memref<4xi1, #spv.storage_class>, %i: index) { + %true = arith.constant true + // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1, #spv.storage_class> to !spv.ptr + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32 + // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32 + // 
CHECK: %[[ADDR:.+]] = spv.PtrAccessChain %[[DST_CAST]][%[[ADD]]] + // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8 + // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8 + // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8 + // CHECK: spv.Store "CrossWorkgroup" %[[ADDR]], %[[RES]] : i8 + memref.store %true, %dst[%i]: memref<4xi1, #spv.storage_class> + return +} + +} // end module + +// ----- + // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. // TODO: Test i64 types.