diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,55 @@
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+/// Returns the offset of the value in `targetBits` representation. `srcIdx` is
+/// an index into a 1-D array whose elements have `sourceBits` each. When the
+/// array is accessed as if its elements had `targetBits`, multiple source
+/// values are loaded at the same time. The method returns the bit offset at
+/// which the value indexed by `srcIdx` sits within the loaded value. For
+/// example, if `sourceBits` is 8 and `targetBits` is 32, the x-th element is
+/// located at (x % 4) * 8, because one i32 holds four 8-bit elements.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr idxAttr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+  auto srcBitsValue =
+      builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
+  auto m = builder.create<spirv::SModOp>(loc, srcIdx, idx);
+  return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
+}
+
+/// Returns an adjusted spirv::AccessChainOp. Based on the
+/// extensions/capabilities, certain integer bitwidths `sourceBits` might not
+/// be supported. During conversion, if a memref of an unsupported type is
+/// used, loads/stores to this memref need to be modified to use a supported
+/// higher bitwidth `targetBits` and to extract the required bits. When
+/// accessing a 1-D array (spv.array or spv.rt_array), the last index is
+/// modified to load the bits needed. The extraction of the actual bits needed
+/// is handled separately. Note that this only works for a 1-D tensor.
+static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
+                                          spirv::AccessChainOp op,
+                                          int sourceBits, int targetBits,
+                                          OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  const auto loc = op.getLoc();
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr attr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
+  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
+  auto indices = llvm::to_vector<4>(op.indices());
+  // There are two indices if this is a 1-D tensor.
+  assert(indices.size() == 2);
+  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  Type t = typeConverter.convertType(op.component_ptr().getType());
+  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -205,6 +254,16 @@
 };
 
 /// Converts std.load to spv.Load.
+class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+public:
+  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
+/// Converts std.load to spv.Load.
 class LoadOpPattern final : public SPIRVOpLowering<LoadOp> {
 public:
   using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
 
@@ -529,12 +588,78 @@
 //===----------------------------------------------------------------------===//
 
 LogicalResult
+IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  LoadOpOperandAdaptor loadOperands(operands);
+  auto loc = loadOp.getLoc();
+  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+  if (!memrefType.getElementType().isSignlessInteger())
+    return failure();
+  spirv::AccessChainOp accessChainOp =
+      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+                           loadOperands.indices(), loc, rewriter);
+
+  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
+  auto dstType = typeConverter.convertType(memrefType)
+                     .cast<spirv::PointerType>()
+                     .getPointeeType()
+                     .cast<spirv::StructType>()
+                     .getElementType(0)
+                     .cast<spirv::ArrayType>()
+                     .getElementType();
+  int dstBits = dstType.getIntOrFloatBitWidth();
+  assert(dstBits % srcBits == 0);
+
+  // If the rewritten load op has the same bit width, use the loaded value
+  // directly.
+  if (srcBits == dstBits) {
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
+                                               accessChainOp.getResult());
+    return success();
+  }
+
+  // Assume that getElementPtr() produces a linearized access chain; even for
+  // a scalar memref it still returns one. If the access were not linearized,
+  // the offsets computed below would be wrong.
+  assert(accessChainOp.indices().size() == 2);
+  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
+                                                   srcBits, dstBits, rewriter);
+  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
+      loc, dstType, adjustedPtr,
+      loadOp.getAttrOfType<IntegerAttr>(
+          spirv::attributeName<spirv::MemoryAccess>()),
+      loadOp.getAttrOfType<IntegerAttr>("alignment"));
+
+  // Shift the bits to the rightmost.
+  // ____XXXX________ -> ____________XXXX
+  Value lastDim = accessChainOp.getOperation()->getOperand(
+      accessChainOp.getNumOperands() - 1);
+  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
+  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
+      loc, spvLoadOp.getType(), spvLoadOp, offset);
+
+  // Apply the mask to extract the corresponding bits.
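+  // With srcBits == 8, the mask below is 255 (0xFF): only the byte that was
+  // just shifted into the low bits survives the BitwiseAnd.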
+  Value mask = rewriter.create<spirv::ConstantOp>(
+      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
+  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
+  rewriter.replaceOp(loadOp, result);
+
+  assert(accessChainOp.use_empty());
+  rewriter.eraseOp(accessChainOp);
+
+  return success();
+}
+
+LogicalResult
 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
   LoadOpOperandAdaptor loadOperands(operands);
-  auto loadPtr = spirv::getElementPtr(
-      typeConverter, loadOp.memref().getType().cast<MemRefType>(),
-      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
+  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+  if (memrefType.getElementType().isSignlessInteger())
+    return failure();
+  auto loadPtr =
+      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+                           loadOperands.indices(), loadOp.getLoc(), rewriter);
   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
   return success();
 }
@@ -642,8 +767,8 @@
       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
-      CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern,
-      SelectOpPattern, StoreOpPattern,
+      CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
+      ReturnOpPattern, SelectOpPattern, StoreOpPattern,
       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
       TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>, XOrOpPattern>(
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -619,3 +619,115 @@
 }
 
 } // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if non-32-bit types are
+// emulated via 32-bit types.
+// TODO: Test i64 type.
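+// For example, an i8 element at index i is read from 32-bit word i / 4 and
+// extracted starting at bit offset (i % 4) * 8.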
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @load_i8
+func @load_i8(%arg0: memref<i8>) {
+  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
+  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
+  // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
+  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  %0 = load %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @load_i16
+// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32)
+func @load_i16(%arg0: memref<10xi16>, %index : index) {
+  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  // CHECK: %[[FLAT_IDX:.+]] = spv.IMul %[[ONE]], %[[ARG1]] : i32
+  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[TWO1:.+]] = spv.constant 2 : i32
+  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[FLAT_IDX]], %[[TWO1]] : i32
+  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
+  // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32
+  // CHECK: %[[SIXTEEN:.+]] = spv.constant 16 : i32
+  // CHECK: %[[IDX:.+]] = spv.SMod %[[FLAT_IDX]], %[[TWO2]] : i32
+  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[SIXTEEN]] : i32
+  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32
+  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  %0 = load %arg0[%index] : memref<10xi16>
+  return
+}
+
+// CHECK-LABEL: @load_i32
+func @load_i32(%arg0: memref<i32>) {
+  // CHECK-NOT: spv.SDiv
+  // CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<i32>
+  return
+}
+
+// CHECK-LABEL: @load_f32
+func @load_f32(%arg0: memref<f32>) {
+  // CHECK-NOT: spv.SDiv
+  // CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<f32>
+  return
+}
+
+} // end module
+
+// -----
+
+// Check that access chain indices are properly adjusted if non-16/32-bit types
+// are emulated via 32-bit types.
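+// Int16 is in the capability list below, so i16 loads lower directly and only
+// i8 requires emulation.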
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader, Int16], [SPV_KHR_storage_buffer_storage_class]>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK-LABEL: @load_i8
+func @load_i8(%arg0: memref<i8>) {
+  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32
+  // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32
+  // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]]
+  // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]]
+  // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32
+  // CHECK: %[[EIGHT:.+]] = spv.constant 8 : i32
+  // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32
+  // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32
+  // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32
+  // CHECK: %[[MASK:.+]] = spv.constant 255 : i32
+  // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32
+  %0 = load %arg0[] : memref<i8>
+  return
+}
+
+// CHECK-LABEL: @load_i16
+func @load_i16(%arg0: memref<i16>) {
+  // CHECK-NOT: spv.SDiv
+  // CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<i16>
+  return
+}
+
+} // end module
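
For reference, here is a minimal host-side C++ sketch of the arithmetic the emitted SPIR-V performs for an i8 load emulated via i32. It is not part of the patch; the function name `emulatedLoadI8`, the `std::vector` buffer, and the packing assumption (element x at bit offset (x % 4) * 8, as in getOffsetForBitwidth) are illustrative only:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the emitted op sequence: spv.SDiv (word index), spv.AccessChain +
// spv.Load (word load), spv.SMod + spv.IMul (bit offset),
// spv.ShiftRightArithmetic (shift), and spv.BitwiseAnd (mask).
uint32_t emulatedLoadI8(const std::vector<uint32_t> &words, int i) {
  const int srcBits = 8, dstBits = 32;
  assert(dstBits % srcBits == 0);
  const int elemsPerWord = dstBits / srcBits;  // four i8 elements per i32
  uint32_t word = words[i / elemsPerWord];     // adjusted access chain index
  int offset = (i % elemsPerWord) * srcBits;   // getOffsetForBitwidth
  uint32_t mask = (1u << srcBits) - 1;         // 255, as in the CHECK lines
  // Shift the addressed byte into the low bits, then mask off the rest. The
  // mask zero-extends the byte even though the generated shift is arithmetic.
  return (word >> offset) & mask;
}
```

For example, with a single word {0x04030201}, emulatedLoadI8(words, 2) reads word 0, shifts right by 16, and returns 3.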