diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,49 @@
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+/// Returns the offset of input value in `targetBits` integer representation.
+/// For example, if `elementBits` equals to 8 and `targetBits` equals to 32, the
+/// x-th element is located at (x % 4) * 8. Because there are four elements in
+/// one i32, and one element has 8 bits.
+static Value getOffsetOfInt(Location loc, Value lastDim, int elementBits,
+                            int targetBits, OpBuilder &builder) {
+  assert(targetBits % elementBits == 0);
+  // Only works for a linearized buffer.
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr attr =
+      builder.getIntegerAttr(targetType, targetBits / elementBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
+  auto elemBitsValue = builder.create<spirv::ConstantOp>(
+      loc, targetType, builder.getIntegerAttr(targetType, elementBits));
+  auto m = builder.create<spirv::SModOp>(loc, lastDim, idx);
+  return builder.create<spirv::IMulOp>(loc, targetType, m, elemBitsValue);
+}
+
+/// Returns an adjusted spirv::AccessChainOp to access the corresponding
+/// `targetBits` integer representation element. One original element is an
+/// `elementBits`-bit integer. The method adjusts the last index so that it
+/// addresses the `targetBits`-bit word containing the element. Note that this
+/// only works for a scalar or 1-D tensor.
+static Value convertToTargetAccessChain(SPIRVTypeConverter &typeConverter,
+                                        spirv::AccessChainOp op,
+                                        int elementBits, int targetBits,
+                                        OpBuilder &builder) {
+  assert(targetBits % elementBits == 0);
+  const auto loc = op.getLoc();
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr attr =
+      builder.getIntegerAttr(targetType, targetBits / elementBits);
+  auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
+  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
+  auto indices = llvm::to_vector<4>(op.indices());
+  // There are two elements if this is a 1-D tensor.
+  assert(indices.size() <= 2);
+  if (indices.size() > 1)
+    indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
+  Type t = typeConverter.convertType(op.component_ptr().getType());
+  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -204,6 +247,16 @@
   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.load to spv.Load.
+class IntLoadOpPattern final : public SPIRVOpLowering<LoadOp> {
+public:
+  using SPIRVOpLowering<LoadOp>::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Converts std.load to spv.Load.
class LoadOpPattern final : public SPIRVOpLowering { public: @@ -528,13 +580,70 @@ // LoadOp //===----------------------------------------------------------------------===// +LogicalResult +IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + LoadOpOperandAdaptor loadOperands(operands); + auto loc = loadOp.getLoc(); + auto memrefType = loadOp.memref().getType().cast(); + if (!memrefType.getElementType().isSignlessInteger()) + return failure(); + spirv::AccessChainOp accessChainOp = + spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), + loadOperands.indices(), loc, rewriter); + + int bits = memrefType.getElementType().getIntOrFloatBitWidth(); + auto convertedType = typeConverter.convertType(memrefType) + .cast() + .getPointeeType() + .cast() + .getElementType(0) + .cast() + .getElementType(); + int convertedBits = convertedType.getIntOrFloatBitWidth(); + + // If the rewrited load op has the same bit width, use the loading value + // directly. Otherwise, extract corresponding bits out. + Value result; + if (bits == convertedBits) { + result = rewriter.create(loc, accessChainOp.getResult()); + } else { + // If the converted type does not have the same bit width as the base type, + // we need to adjust the index to make it access the corresponding element. 
+ assert(convertedBits % bits == 0); + Value i32AccessChainOp = convertToTargetAccessChain( + typeConverter, accessChainOp, bits, convertedBits, rewriter); + Value spvLoadOp = rewriter.create( + loc, convertedType, i32AccessChainOp, + loadOp.getAttrOfType( + spirv::attributeName()), + loadOp.getAttrOfType("alignment")); + + Value lastDim = accessChainOp.getOperation()->getOperand( + accessChainOp.getNumOperands() - 1); + Value offset = getOffsetOfInt(loc, lastDim, bits, convertedBits, rewriter); + result = rewriter.create(loc, convertedType, + spvLoadOp, offset); + auto mask = rewriter.create( + loc, convertedType, + rewriter.getIntegerAttr(convertedType, (1 << bits) - 1)); + result = + rewriter.create(loc, convertedType, result, mask); + } + rewriter.replaceOp(loadOp, result); + return success(); +} + LogicalResult LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); - auto loadPtr = spirv::getElementPtr( - typeConverter, loadOp.memref().getType().cast(), - loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + auto memrefType = loadOp.memref().getType().cast(); + if (memrefType.getElementType().isSignlessInteger()) + return failure(); + auto loadPtr = + spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), + loadOperands.indices(), loadOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } @@ -642,8 +751,8 @@ BitwiseOpPattern, BitwiseOpPattern, BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern, - CmpFOpPattern, CmpIOpPattern, LoadOpPattern, ReturnOpPattern, - SelectOpPattern, StoreOpPattern, + CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern, + ReturnOpPattern, SelectOpPattern, StoreOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, XOrOpPattern>( diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir 
b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -619,3 +619,112 @@ } } // end module + +// ----- + +// Check that access chain indices are properly adjusted if non-32-bit types are +// emulated via 32-bit types. +// TODO: Test i64 type. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: @load_i8 +func @load_i8(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32 + // CHECK: %[[FOUR3:.+]] = spv.constant 4 : i32 + // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[FOUR3]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 + // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + %0 = load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_i16 +func @load_i16(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 + // CHECK: %[[TWO1:.+]] = spv.constant 2 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[TWO1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[TWO2:.+]] = spv.constant 2 : i32 + // CHECK: %[[TWO3:.+]] = spv.constant 2 : i32 + // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[TWO2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[TWO3]] : i32 + // CHECK: 
%[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.constant 65535 : i32 + // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + %0 = load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_i32 +func @load_i32(%arg0: memref) { + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %0 = load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_f32 +func @load_f32(%arg0: memref) { + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %0 = load %arg0[] : memref + return +} + +} // end module + +// ----- + +// Check that access chain indices are properly adjusted if non-16/32-bit types +// are emulated via 32-bit types. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: @load_i8 +func @load_i8(%arg0: memref) { + // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.constant 4 : i32 + // CHECK: %[[FOUR3:.+]] = spv.constant 4 : i32 + // CHECK: %[[IDX:.+]] = spv.SMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[FOUR3]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.constant 255 : i32 + // CHECK: spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + %0 = load %arg0[] : memref + return +} + +// CHECK-LABEL: @load_i16 +func @load_i16(%arg0: memref) { + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %0 = load %arg0[] : memref + return +} + +} // end module