diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -106,6 +106,10 @@
   static bool classof(Type type);
 
+  /// Returns true if the given vector type has a number of elements that is
+  /// allowed for SPIR-V vectors (2, 3, 4, 8, or 16).
+  static bool hasValidSize(VectorType);
+
   /// Returns true if the given vector type is valid for the SPIR-V dialect.
   static bool isValid(VectorType);
 
   /// Return the number of elements of the type. This should only be called if
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -100,18 +100,22 @@
       spirv::StructType>(type);
 }
 
-bool CompositeType::isValid(VectorType type) {
+bool CompositeType::hasValidSize(VectorType type) {
   switch (type.getNumElements()) {
   case 2:
   case 3:
   case 4:
   case 8:
   case 16:
-    break;
+    return true;
   default:
     return false;
   }
-  return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType());
+}
+
+bool CompositeType::isValid(VectorType type) {
+  return type.getRank() == 1 && hasValidSize(type) &&
+         llvm::isa<ScalarType>(type.getElementType());
 }
 
 Type CompositeType::getElementType(unsigned index) const {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -303,9 +303,29 @@
   type = cast<VectorType>(convertIndexElementType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert non-scalar element type\n");
-    return nullptr;
+    // If this is not a spec allowed scalar type, try to handle sub-byte integer
+    // types.
+    auto intType = dyn_cast<IntegerType>(type.getElementType());
+    if (!intType) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type
+                 << " illegal: cannot convert non-scalar element type\n");
+      return nullptr;
+    }
+
+    Type elementType = convertSubByteIntegerType(options, intType);
+    if (type.getRank() <= 1 && type.getNumElements() == 1)
+      return elementType;
+
+    // NOTE(review): hasValidSize only checks the element count, not the rank;
+    // confirm multi-dimensional vectors cannot reach this point.
+    if (!spirv::CompositeType::hasValidSize(type)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type << " illegal: unsupported number of elements\n");
+      return nullptr;
+    }
+
+    return VectorType::get(type.getShape(), elementType);
   }
 
   if (type.getRank() <= 1 && type.getNumElements() == 1)
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1000,9 +1000,9 @@
   return %0: f64
 }
 
-// CHECK-LABEL: @trunci4
+// CHECK-LABEL: @trunci4_scalar
 // CHECK-SAME: %[[ARG:.*]]: i32
-func.func @trunci4(%arg0 : i32) -> i4 {
+func.func @trunci4_scalar(%arg0 : i32) -> i4 {
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
   %0 = arith.trunci %arg0 : i32 to i4
@@ -1011,8 +1011,19 @@
   return %0 : i4
 }
 
-// CHECK-LABEL: @zexti4
-func.func @zexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @trunci4_vector
+// CHECK-SAME: %[[ARG:.*]]: vector<2xi32>
+func.func @trunci4_vector(%arg0 : vector<2xi32>) -> vector<2xi4> {
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<2xi32>
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : vector<2xi32>
+  %0 = arith.trunci %arg0 : vector<2xi32> to vector<2xi4>
+  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[AND]] : vector<2xi32> to vector<2xi4>
+  // CHECK: return %[[RET]] : vector<2xi4>
+  return %0 : vector<2xi4>
+}
+
+// CHECK-LABEL: @zexti4_scalar
+func.func @zexti4_scalar(%arg0: i4) -> i32 {
   // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
@@ -1021,8 +1032,18 @@
   return %0 : i32
 }
 
-// CHECK-LABEL: @sexti4
-func.func @sexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @zexti4_vector
+func.func @zexti4_vector(%arg0: vector<3xi4>) -> vector<3xi32> {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<3xi4> to vector<3xi32>
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<3xi32>
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : vector<3xi32>
+  %0 = arith.extui %arg0 : vector<3xi4> to vector<3xi32>
+  // CHECK: return %[[AND]] : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// CHECK-LABEL: @sexti4_scalar
+func.func @sexti4_scalar(%arg0: i4) -> i32 {
   // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : i4 to i32
   // CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
   // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
@@ -1032,6 +1053,17 @@
   return %0 : i32
 }
 
+// CHECK-LABEL: @sexti4_vector
+func.func @sexti4_vector(%arg0: vector<4xi4>) -> vector<4xi32> {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<4xi4> to vector<4xi32>
+  // CHECK: %[[SIZE:.+]] = spirv.Constant dense<28> : vector<4xi32>
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+  %0 = arith.extsi %arg0 : vector<4xi4> to vector<4xi32>
+  // CHECK: return %[[SR]] : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
 } // end module
 
 // -----
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -105,6 +105,9 @@
 // CHECK: spirv.func @integer4(%{{.+}}: i32)
 func.func @integer4(%arg0: i4) { return }
 
+// CHECK: spirv.func @v3i4(%{{.+}}: vector<3xi32>)
+func.func @v3i4(%arg0: vector<3xi4>) { return }
+
 } // end module
 
 // -----