diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -97,6 +97,25 @@ return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Returns the bit offset of the input value inside its i32 word. For example, +// if `bits` equals 8, the x-th element is located at (x % 4) * 8, because four +// elements are packed in one i32 and each element has 8 bits. +static Value getOffsetOfInt(spirv::AccessChainOp op, int bits, + ConversionPatternRewriter &rewriter) { + assert(32 % bits == 0); + // Only works for a linearized buffer (access chain with exactly two indices). + assert(op.indices().size() == 2); + const auto loc = op.getLoc(); + Type i32Type = rewriter.getIntegerType(32); + auto idx = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(32 / bits)); + auto bitsValue = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(bits)); + auto lastDim = op.getOperation()->getOperand(op.getNumOperands() - 1); + auto m = rewriter.create(loc, lastDim, idx); + return rewriter.create(loc, i32Type, m, bitsValue); +} + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -490,10 +509,30 @@ LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); - auto loadPtr = spirv::getElementPtr( - typeConverter, loadOp.memref().getType().cast(), - loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); - rewriter.replaceOpWithNewOp(loadOp, loadPtr); + auto loc = loadOp.getLoc(); + auto memrefType = loadOp.memref().getType().cast(); + spirv::AccessChainOp loadPtr = + spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), 
loadOperands.indices(), loc, rewriter); + Value spvLoadOp = rewriter.create(loc, loadPtr.getResult()); + + // If the converted element type is not a signless integer or has the same bit + // width as the element, use the loaded value directly. Otherwise, extract the + // corresponding bits. + Value result; + auto bits = memrefType.getElementType().getIntOrFloatBitWidth(); + Type convertedType = typeConverter.convertType(memrefType.getElementType()); + if (!convertedType.isSignlessInteger() || + (bits == convertedType.getIntOrFloatBitWidth())) { + result = spvLoadOp; + } else { + // Assume that it is unconditionally converted to i32 type. NOTE(review): the loaded word is only shifted by `offset` below, not masked to `bits` bits -- confirm this is intended. + assert(convertedType.getIntOrFloatBitWidth() == 32); + Value offset = getOffsetOfInt(loadPtr, bits, rewriter); + result = rewriter.create( + loc, rewriter.getIntegerType(32), spvLoadOp, offset); + } + rewriter.replaceOp(loadOp, result); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -602,6 +602,18 @@ if (!ptrLoc) { ptrLoc = zero; } + // If the converted type does not have the same bit width as the base type, + // divide the index so that it addresses the converted element holding it. 
+ int bits = baseType.getElementType().getIntOrFloatBitWidth(); + int convertedBits = typeConverter.convertType(baseType.getElementType()) + .getIntOrFloatBitWidth(); + if (bits != convertedBits) { + assert(convertedBits % bits == 0); + auto divisor = builder.create( + loc, builder.getIntegerType(convertedBits), + builder.getI32IntegerAttr(convertedBits / bits)); + ptrLoc = builder.create(loc, ptrLoc, divisor); + } linearizedIndices.push_back(ptrLoc); return builder.create(loc, basePtr, linearizedIndices); } diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -610,3 +610,40 @@ } } // end module + +// ----- + +// Check that non-32-bit integer types are converted to 32-bit types if the +// corresponding capabilities are not available. +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + +// CHECK-LABEL: @load +func @load(%arg0: memref, %arg1: memref<10xi16>, %arg2: memref, + %arg3: memref) { + // CHECK: spv.SDiv + // CHECK: spv.AccessChain + // CHECK: spv.Load + // CHECK: spv.ShiftRightArithmetic + %0 = load %arg0[] : memref + // CHECK: spv.SDiv + // CHECK: spv.AccessChain + // CHECK: spv.Load + // CHECK: spv.ShiftRightArithmetic + %cst0 = constant 0 : index + %1 = load %arg1[%cst0] : memref<10xi16> + // CHECK-NOT: spv.SDiv + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %2 = load %arg2[] : memref + // CHECK: spv.Load + // CHECK-NOT: spv.ShiftRightArithmetic + %3 = load %arg3[] : memref + return +} + +} // end module