diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -991,6 +991,9 @@ loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + // Treat memrefs on i1 as i8. + if (srcBits == 1) + srcBits = 8; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -1117,6 +1120,11 @@ spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + + // Treat memrefs on i1 as i8. + bool isInt1 = srcBits == 1; + if (isInt1) + srcBits = 8; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -1156,8 +1164,14 @@ rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - Value storeVal = - shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter); + Value storeVal = storeOperands.value(); + if (isInt1) { + Value zero = rewriter.create(loc, 0, dstType); + Value one = rewriter.create(loc, 1, dstType); + storeVal = + rewriter.create(loc, dstType, storeVal, one, zero); + } + storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Optional scope = getAtomicOpScope(memrefType); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -451,6 +451,15 @@ }); addConversion([this](MemRefType memRefType) { + if (auto elemType = memRefType.getElementType().dyn_cast()) { + // Treat i1 on memrefs as i8. + if (elemType.getWidth() == 1) { + return convertMemrefType( + targetEnv, + MemRefType::get(memRefType.getShape(), + IntegerType::get(memRefType.getContext(), 8))); + } + } return convertMemrefType(targetEnv, memRefType); }); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -911,12 +911,33 @@ // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. -// TODO: Test i1 and i64 types. +// TODO: Test i64 types. module attributes { spv.target_env = #spv.target_env< #spv.vce, {}> } { +// CHECK-LABEL: @load_i1 +func @load_i1(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + %0 = memref.load %arg0[] : memref + return +} + // CHECK-LABEL: @load_i8 func @load_i8(%arg0: memref) { // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 @@ -982,6 +1003,31 @@ return } +// CHECK-LABEL: @store_i1 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) +func @store_i1(%arg0: memref, %value: i1) { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32 + // CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[] : memref + return +} + // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) func @store_i8(%arg0: memref, %value: i8) { diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -253,20 +253,6 @@ // ----- -// Check that boolean memref is not supported at the moment. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) -func @memref_type(%arg0: memref<3xi1>) { - return -} - -} // end module - -// ----- - // Check that using non-32-bit scalar types in interface storage classes // requires special capability and extension: convert them to 32-bit if not // satisfied. @@ -274,6 +260,10 @@ spv.target_env = #spv.target_env<#spv.vce, {}> } { +// CHECK-LABEL: spv.func @memref_1bit_type +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +func @memref_1bit_type(%arg0: memref<3xi1>) { return } + // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }