diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -739,7 +739,10 @@ let options = [ Option<"boolNumBits", "bool-num-bits", "int", /*default=*/"8", - "The number of bits to store a boolean value"> + "The number of bits to store a boolean value">, + Option<"use64bitIndex", "use-64bit-index", + "bool", /*default=*/"false", + "Use 64-bit integers to convert index types"> ]; } diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -42,15 +42,13 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); - IntegerType targetType = builder.getIntegerType(targetBits); - IntegerAttr idxAttr = - builder.getIntegerAttr(targetType, targetBits / sourceBits); - auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr); - IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); - auto srcBitsValue = - builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr); + Type type = srcIdx.getType(); + IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits); + auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr); + IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits); + auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr); auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx); - return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue); + return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue); } /// Returns an adjusted spirv::AccessChainOp. Based on the @@ -58,7 +56,7 @@ /// supported.
During conversion if a memref of an unsupported type is used, /// load/stores to this memref need to be modified to use a supported higher /// bitwidth `targetBits` and extracting the required bits. For an accessing a -/// 1D array (spirv.array or spirv.rt_array), the last index is modified to load +/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load /// the bits needed. The extraction of the actual bits needed are handled /// separately. Note that this only works for a 1-D tensor. static Value @@ -67,11 +65,10 @@ int targetBits, OpBuilder &builder) { assert(targetBits % sourceBits == 0); const auto loc = op.getLoc(); - IntegerType targetType = builder.getIntegerType(targetBits); - IntegerAttr attr = - builder.getIntegerAttr(targetType, targetBits / sourceBits); - auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr); - auto lastDim = op->getOperand(op.getNumOperands() - 1); + Value lastDim = op->getOperand(op.getNumOperands() - 1); + Type type = lastDim.getType(); + IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits); + auto idx = builder.create<spirv::ConstantOp>(loc, type, attr); auto indices = llvm::to_vector<4>(op.getIndices()); // There are two elements if this is a 1-D tensor. assert(indices.size() == 2); @@ -83,9 +80,8 @@ /// Returns the shifted `targetBits`-bit value with the given offset.
static Value shiftValue(Location loc, Value value, Value offset, Value mask, int targetBits, OpBuilder &builder) { - Type targetType = builder.getIntegerType(targetBits); Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask); - return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result, + return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), result, offset); } diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp @@ -41,6 +41,7 @@ SPIRVConversionOptions options; options.boolNumBits = this->boolNumBits; + options.use64bitIndex = this->use64bitIndex; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull in diff --git a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/bitwidth-emulation.mlir @@ -1,11 +1,12 @@ // RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" -cse %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8 use-64bit-index" -cse %s -o - | FileCheck %s --check-prefix=INDEX64 // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. // TODO: Test i64 types.
module attributes { spirv.target_env = #spirv.target_env< - #spirv.vce, #spirv.resource_limits<>> + #spirv.vce, #spirv.resource_limits<>> } { // CHECK-LABEL: @load_i1 @@ -33,6 +34,7 @@ } // CHECK-LABEL: @load_i8 +// INDEX64-LABEL: @load_i8 func.func @load_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>) -> i8 { // CHECK: %[[ZERO:.+]] = spirv.Constant 0 : i32 // CHECK: %[[FOUR:.+]] = spirv.Constant 4 : i32 @@ -49,6 +51,22 @@ // CHECK: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 // CHECK: builtin.unrealized_conversion_cast %[[SR]] + + // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64 + // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64 + // INDEX64: %[[QUOTIENT:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64 + // INDEX64: %[[PTR:.+]] = spirv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] : {{.+}}, i64, i64 + // INDEX64: %[[LOAD:.+]] = spirv.Load "StorageBuffer" %[[PTR]] : i32 + // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64 + // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64 + // INDEX64: %[[BITS:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64 + // INDEX64: %[[VALUE:.+]] = spirv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i64 + // INDEX64: %[[MASK:.+]] = spirv.Constant 255 : i32 + // INDEX64: %[[T1:.+]] = spirv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // INDEX64: %[[T2:.+]] = spirv.Constant 24 : i32 + // INDEX64: %[[T3:.+]] = spirv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // INDEX64: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + // INDEX64: builtin.unrealized_conversion_cast %[[SR]] %0 = memref.load %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>> return %0 : i8 } @@ -113,6 +131,8 @@ // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) +// INDEX64-LABEL: @store_i8 +// INDEX64: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) func.func @store_i8(%arg0: memref<i8, #spirv.storage_class<StorageBuffer>>, %value: i8) { // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 //
CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] @@ -130,6 +150,23 @@ // CHECK: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] // CHECK: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] // CHECK: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + + // INDEX64-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 + // INDEX64-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] + // INDEX64: %[[ZERO:.+]] = spirv.Constant 0 : i64 + // INDEX64: %[[FOUR:.+]] = spirv.Constant 4 : i64 + // INDEX64: %[[EIGHT:.+]] = spirv.Constant 8 : i64 + // INDEX64: %[[IDX:.+]] = spirv.UMod %[[ZERO]], %[[FOUR]] : i64 + // INDEX64: %[[OFFSET:.+]] = spirv.IMul %[[IDX]], %[[EIGHT]] : i64 + // INDEX64: %[[MASK1:.+]] = spirv.Constant 255 : i32 + // INDEX64: %[[TMP1:.+]] = spirv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i64 + // INDEX64: %[[MASK:.+]] = spirv.Not %[[TMP1]] : i32 + // INDEX64: %[[CLAMPED_VAL:.+]] = spirv.BitwiseAnd %[[ARG1_CAST]], %[[MASK1]] : i32 + // INDEX64: %[[STORE_VAL:.+]] = spirv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i64 + // INDEX64: %[[ACCESS_IDX:.+]] = spirv.SDiv %[[ZERO]], %[[FOUR]] : i64 + // INDEX64: %[[PTR:.+]] = spirv.AccessChain %[[ARG0_CAST]][%[[ZERO]], %[[ACCESS_IDX]]] : {{.+}}, i64, i64 + // INDEX64: spirv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // INDEX64: spirv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] memref.store %value, %arg0[] : memref<i8, #spirv.storage_class<StorageBuffer>> return } @@ -179,7 +216,7 @@ // emulated via 32-bit types. module attributes { spirv.target_env = #spirv.target_env< - #spirv.vce, #spirv.resource_limits<>> + #spirv.vce, #spirv.resource_limits<>> } { // CHECK-LABEL: @load_i4