NOTE(review): this patch was recovered from a copy that had lost its line breaks and its `<...>` template arguments; the op types passed to `rewriter.create<...>` below are reconstructed and should be confirmed against the upstream commit.
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -97,6 +97,46 @@
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+/// Returns the bit offset of the input value inside its i32 word. For example,
+/// if `bits` equals 8, the x-th element is located at (x % 4) * 8, because
+/// there are four elements in one i32 and each element has 8 bits.
+static Value getOffsetOfInt(spirv::AccessChainOp op, int bits,
+                            ConversionPatternRewriter &rewriter) {
+  assert(32 % bits == 0 && "bits must evenly divide 32");
+  // Only works for a linearized buffer.
+  assert(op.indices().size() == 2 && "expected exactly two indices");
+  const auto loc = op.getLoc();
+  Type i32Type = rewriter.getIntegerType(32);
+  auto idx = rewriter.create<spirv::ConstantOp>(
+      loc, i32Type, rewriter.getI32IntegerAttr(32 / bits));
+  auto bitsValue = rewriter.create<spirv::ConstantOp>(
+      loc, i32Type, rewriter.getI32IntegerAttr(bits));
+  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
+  auto m = rewriter.create<spirv::SModOp>(loc, lastDim, idx);
+  return rewriter.create<spirv::IMulOp>(loc, i32Type, m, bitsValue);
+}
+
+/// Returns an adjusted spirv::AccessChainOp that accesses the i32 element
+/// containing the requested `bits`-bit integer. The method adjusts the last
+/// index so that it addresses that i32 element. Note that this only works for
+/// a scalar or 1-D tensor.
+static Value convertToI32AccessChain(SPIRVTypeConverter &typeConverter,
+                                     spirv::AccessChainOp op, int bits,
+                                     ConversionPatternRewriter &rewriter) {
+  const auto loc = op.getLoc();
+  auto i32Type = rewriter.getIntegerType(32);
+  auto idx = rewriter.create<spirv::ConstantOp>(
+      loc, i32Type, rewriter.getI32IntegerAttr(32 / bits));
+  auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1);
+  SmallVector<Value, 4> indices(op.indices().begin(), op.indices().end());
+  // The trailing index selects the element in the linearized buffer; divide it
+  // by the number of `bits`-wide elements per i32 to get the i32 word index.
+  if (indices.size() > 1)
+    indices.back() = rewriter.create<spirv::SDivOp>(loc, lastDim, idx);
+  Type t = typeConverter.convertType(op.component_ptr().getType());
+  return rewriter.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -490,10 +530,55 @@
 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
   LoadOpOperandAdaptor loadOperands(operands);
-  auto loadPtr = spirv::getElementPtr(
-      typeConverter, loadOp.memref().getType().cast<MemRefType>(),
-      loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
-  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
+  auto loc = loadOp.getLoc();
+  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
+  Type elementType = memrefType.getElementType();
+  Type convertedType = typeConverter.convertType(elementType);
+  if (!convertedType)
+    return failure();
+
+  spirv::AccessChainOp accessChainOp =
+      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
+                           loadOperands.indices(), loc, rewriter);
+
+  // Bit-extraction is needed only when a signless integer element was widened
+  // by the type converter. The integer checks come first because
+  // getIntOrFloatBitWidth() asserts on any other type (e.g. index or vector).
+  bool needsBitExtraction =
+      elementType.isSignlessInteger() && convertedType.isSignlessInteger() &&
+      elementType.getIntOrFloatBitWidth() !=
+          convertedType.getIntOrFloatBitWidth();
+
+  Value result;
+  if (!needsBitExtraction) {
+    // No widening happened: load through the access chain as is.
+    result = rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
+  } else {
+    int bits = elementType.getIntOrFloatBitWidth();
+    // Assume that it is unconditionally converted to i32 type.
+    assert(convertedType.getIntOrFloatBitWidth() == 32);
+    assert(32 % bits == 0);
+
+    // The converted type does not have the same bit width as the base type, so
+    // adjust the index to make it access the corresponding i32 element.
+    auto i32Type = rewriter.getIntegerType(32);
+    Value i32AccessChainOp =
+        convertToI32AccessChain(typeConverter, accessChainOp, bits, rewriter);
+    Value spvLoadOp = rewriter.create<spirv::LoadOp>(
+        loc, i32Type, i32AccessChainOp,
+        loadOp.getAttrOfType<IntegerAttr>(
+            spirv::attributeName<spirv::MemoryAccess>()),
+        loadOp.getAttrOfType<IntegerAttr>("alignment"));
+
+    // Shift the requested element down to bit 0, then mask off the other bits.
+    Value offset = getOffsetOfInt(accessChainOp, bits, rewriter);
+    result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, i32Type,
+                                                            spvLoadOp, offset);
+    auto mask = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, rewriter.getI32IntegerAttr((1 << bits) - 1));
+    result = rewriter.create<spirv::BitwiseAndOp>(loc, i32Type, result, mask);
+  }
+  rewriter.replaceOp(loadOp, result);
   return success();
 }
 
NOTE(review): the `<...>` template arguments in the C++ hunks above were missing from the recovered text and are reconstructed; confirm them against the upstream commit.
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -610,3 +610,40 @@
 }
 
 } // end module
+
+// -----
+
+// Check that non-32-bit integer types are converted to 32-bit types if the
+// corresponding capabilities are not available.
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+// NOTE(review): rank-0 memref element types and #spv.vce params are reconstructed; confirm.
+// CHECK-LABEL: @load
+func @load(%arg0: memref<i8>, %arg1: memref<10xi16>, %arg2: memref<i32>,
+           %arg3: memref<f32>) {
+  // CHECK: spv.SDiv
+  // CHECK: spv.AccessChain
+  // CHECK: spv.Load
+  // CHECK: spv.ShiftRightArithmetic
+  %0 = load %arg0[] : memref<i8>
+  // CHECK: spv.SDiv
+  // CHECK: spv.AccessChain
+  // CHECK: spv.Load
+  // CHECK: spv.ShiftRightArithmetic
+  %cst0 = constant 0 : index
+  %1 = load %arg1[%cst0] : memref<10xi16>
+  // CHECK-NOT: spv.SDiv
+  // CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %2 = load %arg2[] : memref<i32>
+  // CHECK: spv.Load
+  // CHECK-NOT: spv.ShiftRightArithmetic
+  %3 = load %arg3[] : memref<f32>
+  return
+}
+
+} // end module