diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -146,6 +146,16 @@
   return builder.create(loc, t, op.base_ptr(), indices);
 }
 
+/// Returns the `targetBits`-bit value obtained by masking `value` with `mask`
+/// and then shifting the masked bits left by `offset`.
+static Value shiftValue(Location loc, Value value, Value offset, Value mask,
+                        int targetBits, OpBuilder &builder) {
+  Type targetType = builder.getIntegerType(targetBits);
+  Value maskedValue = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
+  return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType,
+                                                   maskedValue, offset);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -292,6 +302,18 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.store on signless-integer memrefs to SPIR-V: a plain
+/// spv.Store when the storage bitwidth matches, otherwise an emulation with
+/// spv.AtomicAnd and spv.AtomicOr on the wider storage type.
+class IntStoreOpPattern final : public SPIRVOpLowering<StoreOp> {
+public:
+  using SPIRVOpLowering<StoreOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts std.store to spv.Store.
 class StoreOpPattern final : public SPIRVOpLowering {
 public:
@@ -697,13 +719,97 @@
 //===----------------------------------------------------------------------===//
 
 LogicalResult
+IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  StoreOpOperandAdaptor storeOperands(operands);
+  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
+  // Only signless-integer element types are handled by this pattern.
+  if (!memrefType.getElementType().isSignlessInteger())
+    return failure();
+
+  auto loc = storeOp.getLoc();
+  spirv::AccessChainOp accessChainOp =
+      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
+                           storeOperands.indices(), loc, rewriter);
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  // Unwrap the converted memref type (pointer -> struct -> array) to get the
+  // integer type that is actually used for storage.
+  auto dstType = typeConverter.convertType(memrefType)
+                     .cast<spirv::PointerType>()
+                     .getPointeeType()
+                     .cast<spirv::StructType>()
+                     .getElementType(0)
+                     .cast<spirv::ArrayType>()
+                     .getElementType();
+  int dstBits = dstType.getIntOrFloatBitWidth();
+  assert(dstBits % srcBits == 0);
+
+  // Same bitwidth: no emulation is needed, emit a plain store.
+  if (srcBits == dstBits) {
+    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
+        storeOp, accessChainOp.getResult(), storeOperands.value());
+    return success();
+  }
+
+  // Since multiple threads may be processing concurrently, the emulation is
+  // done with atomic operations. E.g., if the stored value is i8, rewrite the
+  // StoreOp to
+  // 1) load a 32-bit integer
+  // 2) clear 8 bits in the loaded value
+  // 3) store the 32-bit value back
+  // 4) load a 32-bit integer
+  // 5) modify 8 bits in the loaded value
+  // 6) store the 32-bit value back
+  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
+  // are done by AtomicOr as another atomic step.
+  assert(accessChainOp.indices().size() == 2);
+  Value lastDim = accessChainOp.getOperation()->getOperand(
+      accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+
+  // Create a mask to clear the destination. E.g., if it is the second i8 in
+  // i32, 0xFFFF00FF is created. The shift is evaluated on a 64-bit literal so
+  // that it stays well defined for srcBits of 31 or more.
+  Value mask = rewriter.create<spirv::ConstantOp>(
+      loc, dstType, rewriter.getIntegerAttr(dstType, (1LL << srcBits) - 1));
+  Value clearBitsMask =
+      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
+  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
+
+  Value storeVal =
+      shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+                                                   srcBits, dstBits, rewriter);
+  // The results of the atomic ops are not used.
+  rewriter.create<spirv::AtomicAndOp>(
+      loc, dstType, adjustedPtr, spirv::Scope::Device,
+      spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+  rewriter.create<spirv::AtomicOrOp>(
+      loc, dstType, adjustedPtr, spirv::Scope::Device,
+      spirv::MemorySemantics::AcquireRelease, storeVal);
+
+  // The atomic ops are already inserted, so the original StoreOp can simply be
+  // erased. Note that rewriter.replaceOp() does not work here because it
+  // requires the number of results to be the same.
+  rewriter.eraseOp(storeOp);
+
+  assert(accessChainOp.use_empty());
+  rewriter.eraseOp(accessChainOp);
+
+  return success();
+}
+
+LogicalResult
 StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                 ConversionPatternRewriter &rewriter) const {
   StoreOpOperandAdaptor storeOperands(operands);
-  auto storePtr = spirv::getElementPtr(
-      typeConverter, storeOp.memref().getType().cast<MemRefType>(),
-      storeOperands.memref(), storeOperands.indices(), storeOp.getLoc(),
-      rewriter);
+  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
+  // Signless-integer stores are handled by IntStoreOpPattern.
+  if (memrefType.getElementType().isSignlessInteger())
+    return failure();
+  auto storePtr =
+      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
+                           storeOperands.indices(), storeOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
                                               storeOperands.value());
   return success();
@@ -769,7 +875,7 @@ BitwiseOpPattern, BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern, - ReturnOpPattern, 
SelectOpPattern, StoreOpPattern,
+      ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
       TypeCastingOpPattern,
       TypeCastingOpPattern,
       TypeCastingOpPattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -654,7 +654,7 @@
 
 // Check that access chain indices are properly adjusted if non-32-bit types are
 // emulated via 32-bit types.
-// TODO: Test i64 type.
+// TODO: Test i1 and i64 types.
 module attributes {
   spv.target_env = #spv.target_env<
     #spv.vce,
@@ -719,6 +719,70 @@
   return
 }
 
+// CHECK-LABEL: @store_i8
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @store_i8(%arg0: memref<i8>, %value: i8) {
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  store %value, %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @store_i16
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+  //     CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  //     CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[TWO:.+]] = spv.constant 2 : i32
+  //     CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.constant 65535 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG2]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  store %value, %arg0[%index] : memref<10xi16>
+  return
+}
+
+// CHECK-LABEL: @store_i32
+func @store_i32(%arg0: memref<i32>, %value: i32) {
+  //     CHECK: spv.Store
+  // CHECK-NOT: spv.AtomicAnd
+  // CHECK-NOT: spv.AtomicOr
+  store %value, %arg0[] : memref<i32>
+  return
+}
+
+// CHECK-LABEL: @store_f32
+func @store_f32(%arg0: memref<f32>, %value: f32) {
+  //     CHECK: spv.Store
+  // CHECK-NOT: spv.AtomicAnd
+  // CHECK-NOT: spv.AtomicOr
+  store %value, %arg0[] : memref<f32>
+  return
+}
+
 } // end module
 
 // -----
@@ -760,4 +824,35 @@
   return
 }
 
+// CHECK-LABEL: @store_i8
+//       CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @store_i8(%arg0: memref<i8>, %value: i8) {
+  //     CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  //     CHECK: %[[FOUR:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  //     CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR]] : i32
+  //     CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  //     CHECK: %[[MASK1:.+]] = spv.constant 255 : i32
+  //     CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32
+  //     CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[ARG1]], %[[MASK1]] : i32
+  //     CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32
+  //     CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  //     CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32
+  //     CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]]
+  //     CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]]
+  //     CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]]
+  store %value, %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @store_i16
+func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) {
+  //     CHECK: spv.Store
+  // CHECK-NOT: spv.AtomicAnd
+  // CHECK-NOT: spv.AtomicOr
+  store %value, %arg0[%index] : memref<10xi16>
+  return
+}
+
 } // end module