[mlir][spirv] Produce i1 when lowering memref.load of i1 elements

memref<...xi1> is lowered to an array of wider integers (i8 in the @load_i1
test), so a lowered load first reads that integer back. In the path where the
source and storage bit widths are equal, the spv.Load result was previously
used directly as the replacement value, i.e. an integer instead of i1. In the
path where the bit widths differ, the loaded value was compared against one
and then turned back into an integer with castBoolToIntN.

Both paths now compare the loaded integer against one with spv.IEqual (via the
new castIntNToBool helper in the equal-width path), so the replacement value
is an i1, as the original memref.load result type requires. A new @load_i1
test covers the equal-width case; the @store_i1 AccessChain CHECK no longer
spells out the pointer type.

diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -119,6 +119,15 @@
   return {};
 }
 
+/// Casts the given `srcInt` into a boolean value by comparing it against one.
+static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
+  if (srcInt.getType().isInteger(1))
+    return srcInt;
+
+  auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
+  return builder.create<spirv::IEqualOp>(loc, srcInt, one);
+}
+
 /// Casts the given `srcBool` into an integer of `dstType`.
 static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                             OpBuilder &builder) {
@@ -299,8 +308,11 @@
   // If the rewrited load op has the same bit width, use the loading value
   // directly.
   if (srcBits == dstBits) {
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
-                                               accessChainOp.getResult());
+    Value loadVal =
+        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+    if (isBool)
+      loadVal = castIntNToBool(loc, loadVal, rewriter);
+    rewriter.replaceOp(loadOp, loadVal);
     return success();
   }
 
@@ -343,8 +355,7 @@
   if (isBool) {
     dstType = typeConverter.convertType(loadOp.getType());
     mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
-    Value isOne = rewriter.create<spirv::IEqualOp>(loc, result, mask);
-    result = castBoolToIntN(loc, isOne, dstType, rewriter);
+    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
   } else if (result.getType().getIntOrFloatBitWidth() !=
              static_cast<unsigned>(dstBits)) {
     result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -65,6 +65,25 @@
   return
 }
 
+// CHECK-LABEL: func @load_i1
+// CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index)
+func @load_i1(%src: memref<4xi1>, %i : index) -> i1 {
+  // CHECK: %[[SRC_CAST:.+]] = unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK: %[[IDX_CAST:.+]] = unrealized_conversion_cast %[[IDX]]
+  // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
+  // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
+  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[SRC_CAST]][%[[ZERO_0]], %[[ADD]]]
+  // CHECK: %[[VAL:.+]] = spv.Load "StorageBuffer" %[[ADDR]] : i8
+  // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
+  // CHECK: %[[BOOL:.+]] = spv.IEqual %[[VAL]], %[[ONE_I8]] : i8
+  %0 = memref.load %src[%i] : memref<4xi1>
+  // CHECK: return %[[BOOL]]
+  return %0: i1
+}
+
 // CHECK-LABEL: func @store_i1
 // CHECK-SAME: %[[DST:.+]]: memref<4xi1>,
 // CHECK-SAME: %[[IDX:.+]]: index
@@ -77,7 +96,7 @@
   // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32
   // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32
   // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32
-  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr<!spv.struct<(!spv.array<4 x i8, stride=1> [0])>, StorageBuffer>, i32, i32
+  // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]]
   // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8
   // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8
   // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8