diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -773,6 +773,9 @@ loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + // Treat memrefs on i1 as i8. + if (srcBits == 1) + srcBits = 8; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -894,6 +897,10 @@ spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + // Treat memrefs on i1 as i8. + bool isInt1 = srcBits == 1; + if (isInt1) + srcBits = 8; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -934,8 +941,14 @@ rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - Value storeVal = - shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter); + Value storeVal = storeOperands.value(); + if (isInt1) { + Value zero = rewriter.create(loc, 0, dstType); + Value one = rewriter.create(loc, 1, dstType); + storeVal = + rewriter.create(loc, dstType, storeVal, one, zero); + } + storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Optional scope = getAtomicOpScope(memrefType); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -437,6 +437,14 @@ }); addConversion([this](MemRefType memRefType) { + auto elemType = memRefType.getElementType(); + // Treat i1 on memrefs as i8. + if (elemType.isSignlessInteger() && elemType.getIntOrFloatBitWidth() == 1) { + return convertMemrefType( + targetEnv, + MemRefType::get(memRefType.getShape(), + IntegerType::get(8, memRefType.getContext()))); + } return convertMemrefType(targetEnv, memRefType); }); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -724,7 +724,7 @@ // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. -// TODO: Test i1 and i64 types. +// TODO: Test i64 types. module attributes { spv.target_env = #spv.target_env< #spv.vce, @@ -732,6 +732,27 @@ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> } { +// CHECK-LABEL: @load_i1 +func @load_i1(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + %0 = load %arg0[] : memref + return +} + // CHECK-LABEL: @load_i8 func @load_i8(%arg0: memref) { // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 @@ -797,6 +818,31 @@ return } +// CHECK-LABEL: @store_i1 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) +func @store_i1(%arg0: memref, %value: i1) { + // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[ZERO1:.+]] = spv.constant 0 : i32 + // CHECK: %[[ONE1:.+]] = spv.constant 1 : i32 + // CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + store %value, %arg0[] : memref + return +} + // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) func @store_i8(%arg0: memref, %value: i8) { diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -283,23 +283,6 @@ // ----- -// Check that boolean memref is not supported at the moment. -module attributes { - spv.target_env = #spv.target_env< - #spv.vce, - {max_compute_workgroup_invocations = 128 : i32, - max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> -} { - -// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) -func @memref_type(%arg0: memref<3xi1>) { - return -} - -} // end module - -// ----- - // Check that using non-32-bit scalar types in interface storage classes // requires special capability and extension: convert them to 32-bit if not // satisfied. @@ -310,6 +293,10 @@ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> } { +// CHECK-LABEL: spv.func @memref_1bit_type +// CHECK-SAME: !spv.ptr [0]>, StorageBuffer> +func @memref_1bit_type(%arg0: memref<3xi1>) { return } + // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spv.ptr [0]>, StorageBuffer> func @memref_8bit_StorageBuffer(%arg0: memref<16xi8, 0>) { return }