diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -77,11 +77,37 @@
   return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
 }
 
-/// Returns the shifted `targetBits`-bit value with the given offset.
+/// Casts the given `srcBool` into an integer of `dstType`.
+static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
+                            OpBuilder &builder) {
+  assert(srcBool.getType().isInteger(1));
+  if (dstType.isInteger(1))
+    return srcBool;
+  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
+  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
+  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
+}
+
+/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
+/// to the type destination type, and masked.
 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
-                        int targetBits, OpBuilder &builder) {
-  Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
-  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), result,
+                        OpBuilder &builder) {
+  IntegerType dstType = cast<IntegerType>(mask.getType());
+  int targetBits = static_cast<int>(dstType.getWidth());
+  int valueBits = value.getType().getIntOrFloatBitWidth();
+  assert(valueBits <= targetBits);
+
+  if (valueBits == 1) {
+    value = castBoolToIntN(loc, value, dstType, builder);
+  } else {
+    if (valueBits < targetBits) {
+      value = builder.create<spirv::UConvertOp>(
+          loc, builder.getIntegerType(targetBits), value);
+    }
+
+    value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+  }
+  return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
                                                    offset);
 }
 
@@ -136,17 +162,6 @@
   return builder.create<spirv::IEqualOp>(loc, srcInt, one);
 }
 
-/// Casts the given `srcBool` into an integer of `dstType`.
-static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
-                            OpBuilder &builder) {
-  assert(srcBool.getType().isInteger(1));
-  if (dstType.isInteger(1))
-    return srcBool;
-  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
-  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
-  return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
-}
-
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -553,7 +568,8 @@
                   ConversionPatternRewriter &rewriter) const {
   auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
   if (!memrefType.getElementType().isSignlessInteger())
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "element type is not a signless int");
 
   auto loc = storeOp.getLoc();
   auto &typeConverter = *getTypeConverter();
@@ -562,7 +578,8 @@
       adaptor.getIndices(), loc, rewriter);
 
   if (!accessChain)
-    return failure();
+    return rewriter.notifyMatchFailure(
+        storeOp, "failed to convert element pointer type");
 
   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
 
@@ -576,23 +593,28 @@
                                      "failed to convert memref type");
 
   Type pointeeType = pointerType.getPointeeType();
-  Type dstType;
+  IntegerType dstType;
   if (typeConverter.allows(spirv::Capability::Kernel)) {
     if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
-      dstType = arrayType.getElementType();
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
     else
-      dstType = pointeeType;
+      dstType = dyn_cast<IntegerType>(pointeeType);
   } else {
     // For Vulkan we need to extract element from wrapping struct and array.
     Type structElemType =
         cast<spirv::StructType>(pointeeType).getElementType(0);
     if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
-      dstType = arrayType.getElementType();
+      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
     else
-      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
+      dstType = dyn_cast<IntegerType>(
+          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
   }
 
-  int dstBits = dstType.getIntOrFloatBitWidth();
+  if (!dstType)
+    return rewriter.notifyMatchFailure(
+        storeOp, "failed to determine destination element type");
+
+  int dstBits = static_cast<int>(dstType.getWidth());
   assert(dstBits % srcBits == 0);
 
   if (srcBits == dstBits) {
@@ -612,17 +634,17 @@
   if (!accessChainOp)
     return failure();
 
-  // Since there are multi threads in the processing, the emulation will be done
-  // with atomic operations. E.g., if the storing value is i8, rewrite the
-  // StoreOp to
+  // Since there are multiple threads in the processing, the emulation will be
+  // done with atomic operations. E.g., if the stored value is i8, rewrite the
+  // StoreOp to:
   // 1) load a 32-bit integer
-  // 2) clear 8 bits in the loading value
-  // 3) store 32-bit value back
-  // 4) load a 32-bit integer
-  // 5) modify 8 bits in the loading value
-  // 6) store 32-bit value back
-  // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
-  // 4 to step 6 are done by AtomicOr as another atomic step.
+  // 2) clear 8 bits in the loaded value
+  // 3) set 8 bits in the loaded value
+  // 4) store 32-bit value back
+  //
+  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
+  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
+  // step.
   assert(accessChainOp.getIndices().size() == 2);
   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
@@ -635,15 +657,13 @@
       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
 
-  Value storeVal = adaptor.getValue();
-  if (isBool)
-    storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
-  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
+  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                    srcBits, dstBits, rewriter);
   std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
   if (!scope)
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");
+
   Value result = rewriter.create<spirv::AtomicAndOp>(
       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
       clearBitsMask);
@@ -740,13 +760,13 @@
                   ConversionPatternRewriter &rewriter) const {
   auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
   if (memrefType.getElementType().isSignlessInteger())
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "signless int");
   auto storePtr = spirv::getElementPtr(
       *getTypeConverter(), memrefType, adaptor.getMemref(),
       adaptor.getIndices(), storeOp.getLoc(), rewriter);
 
   if (!storePtr)
-    return failure();
+    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");
 
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               adaptor.getValue());
diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
--- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir
@@ -119,8 +119,7 @@
   // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
   // CHECK: %[[ONE:.+]] = spirv.Constant 1 : i32
   // CHECK: %[[CASTED_ARG1:.+]] = spirv.Select %[[ARG1]], %[[ONE]], %[[ZERO]] : i1, i32
-  // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32
-  // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CASTED_ARG1]], %[[OFFSET]] : i32, i32
   // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
   // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
   // CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
@@ -270,3 +269,96 @@
 }
 
 } // end module
+
+// -----
+
+// Check that we can access i8 storage with i8 types available but without
+// 8-bit storage capabilities.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Shader, Int8], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @load_i8
+// INDEX64-LABEL: @load_i8
+func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 {
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  // CHECK: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
+  // CHECK: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  // CHECK: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]]
+  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+  // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
+  // CHECK: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
+  // CHECK: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  // CHECK: %[[MASK:.+]] = spirv.Constant 255 : i32
+  // CHECK: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // CHECK: %[[T2:.+]] = spirv.Constant 24 : i32
+  // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  // CHECK: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
+  // CHECK: return %[[CAST]] : i8
+
+  // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
+  // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
+  // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
+  // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64
+  // INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32
+  // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
+  // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
+  // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
+  // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64
+  // INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32
+  // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  // INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32
+  // INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32
+  // INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32
+  // INDEX64: %[[CAST:.+]] = spirv.UConvert %[[SR]] : i32 to i8
+  // INDEX64: return %[[CAST]] : i8
+  %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
+  return %0 : i8
+}
+
+// CHECK-LABEL: @store_i8
+// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
+// INDEX64-LABEL: @store_i8
+// INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8)
+func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) {
+  // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32
+  // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32
+  // CHECK: %[[EIGHT:.+]] = spirv.Constant 8 : i32
+  // CHECK: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i32
+  // CHECK: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i32
+  // CHECK: %[[MASK1:.+]] = spirv.Constant 255 : i32
+  // CHECK: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  // CHECK: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  // CHECK: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
+  // CHECK: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
+  // CHECK: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  // CHECK: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i32
+  // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]]
+  // CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  // CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+
+  // INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64
+  // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64
+  // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64
+  // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64
+  // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64
+  // INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32
+  // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64
+  // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32
+  // INDEX64: %[[ARG1_CAST:.+]] = spirv.UConvert %[[ARG1]] : i8 to i32
+  // INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32
+  // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64
+  // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64
+  // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64
+  // INDEX64: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  // INDEX64: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>>
+  return
+}
+
+} // end module