diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" @@ -115,8 +114,14 @@ // Type Conversion //===----------------------------------------------------------------------===// +static spirv::ScalarType getIndexType(MLIRContext *ctx, + const SPIRVConversionOptions &options) { + return cast( + IntegerType::get(ctx, options.use64bitIndex ? 64 : 32)); +} + Type SPIRVTypeConverter::getIndexType() const { - return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32); + return ::getIndexType(getContext(), options); } MLIRContext *SPIRVTypeConverter::getContext() const { @@ -242,12 +247,32 @@ intType.getSignedness()); } +/// Returns a type with the same shape but with any index element type converted +/// to the matching integer type. This is a noop when the element type is not +/// the index type. +static ShapedType +convertIndexElementType(ShapedType type, + const SPIRVConversionOptions &options) { + Type indexType = dyn_cast(type.getElementType()); + if (!indexType) + return type; + + return type.clone(getIndexType(type.getContext(), options)); +} + /// Converts a vector `type` to a suitable type under the given `targetEnv`. static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional storageClass = {}) { - auto scalarType = type.getElementType().cast(); + type = cast(convertIndexElementType(type, options)); + auto scalarType = dyn_cast_or_null(type.getElementType()); + if (!scalarType) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert non-scalar element type\n"); + return nullptr; + } + if (type.getRank() <= 1 && type.getNumElements() == 1) return convertScalarType(targetEnv, options, scalarType, storageClass); @@ -290,7 +315,8 @@ return nullptr; } - auto scalarType = type.getElementType().dyn_cast(); + type = cast(convertIndexElementType(type, options)); + auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot convert non-scalar element type\n"); @@ -396,6 +422,9 @@ } else if (auto scalarType = elementType.dyn_cast()) { arrayElemType = convertScalarType(targetEnv, options, scalarType, storageClass); + } else if (auto indexType = elementType.dyn_cast()) { + type = convertIndexElementType(type, options).cast(); + arrayElemType = type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -115,7 +115,8 @@ // Index type //===----------------------------------------------------------------------===// -// The index type is always converted into i32. +// The index type is always converted into i32 or i64, with i32 being the +// default. module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { @@ -223,6 +224,10 @@ // CHECK-SAME: %{{.+}}: i32 func.func @one_element_vector(%arg0: vector<1xi8>) { return } +// CHECK-LABEL: spirv.func @index_vector +// CHECK-SAME: %{{.*}}: vector<4xi32> +func.func @index_vector(%arg0: vector<4xindex>) { return } + } // end module // ----- @@ -313,6 +318,14 @@ %arg1: memref<4x8xi1, #spirv.storage_class> ) { return } +// CHECK-LABEL: func @memref_index_type +// CHECK-SAME: !spirv.ptr [0])>, StorageBuffer> +// CHECK-SAME: !spirv.ptr)>, Function> +func.func @memref_index_type( + %arg0: memref<4xindex, #spirv.storage_class>, + %arg1: memref<4xindex, #spirv.storage_class> +) { return } + } // end module // ----- @@ -819,6 +832,11 @@ %arg2: tensor<8x4xf16> ) { return } + +// CHECK-LABEL: spirv.func @index_tensor_type +// CHECK-SAME: %{{.*}}: !spirv.array<20 x i32> +func.func @index_tensor_type(%arg0: tensor<4x5xindex>) { return } + } // end module // -----