diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,17 +101,9 @@
 }
 
 bool CompositeType::isValid(VectorType type) {
-  switch (type.getNumElements()) {
-  case 2:
-  case 3:
-  case 4:
-  case 8:
-  case 16:
-    break;
-  default:
-    return false;
-  }
-  return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType());
+  return type.getRank() == 1 &&
+         llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
+         llvm::isa<ScalarType>(type.getElementType());
 }
 
 Type CompositeType::getElementType(unsigned index) const {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Debug.h"
@@ -303,16 +304,48 @@
   type = cast<VectorType>(convertIndexElementType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
-    LLVM_DEBUG(llvm::dbgs()
-               << type << " illegal: cannot convert non-scalar element type\n");
-    return nullptr;
+    // If this is not a spec allowed scalar type, try to handle sub-byte integer
+    // types.
+    auto intType = dyn_cast<IntegerType>(type.getElementType());
+    if (!intType) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type
+                 << " illegal: cannot convert non-scalar element type\n");
+      return nullptr;
+    }
+
+    Type elementType = convertSubByteIntegerType(options, intType);
+    // Propagate a failed sub-byte conversion (null Type) instead of building a
+    // vector with a null element type.
+    if (!elementType)
+      return nullptr;
+
+    if (type.getRank() <= 1 && type.getNumElements() == 1)
+      return elementType;
+
+    if (type.getNumElements() > 4) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type << " illegal: > 4-element unimplemented\n");
+      return nullptr;
+    }
+
+    // Apply the same composite legality check as the scalar-element path
+    // below; in particular this rejects vectors whose rank is not 1.
+    auto vectorType = VectorType::get(type.getShape(), elementType);
+    if (!spirv::CompositeType::isValid(vectorType)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << type << " illegal: not a valid composite type\n");
+      return nullptr;
+    }
+    return vectorType;
   }
 
   if (type.getRank() <= 1 && type.getNumElements() == 1)
     return convertScalarType(targetEnv, options, scalarType, storageClass);
 
   if (!spirv::CompositeType::isValid(type)) {
-    LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << type << " illegal: not a valid composite type\n");
     return nullptr;
   }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -990,9 +990,9 @@
   return %0: f64
 }
 
-// CHECK-LABEL: @trunci4
+// CHECK-LABEL: @trunci4_scalar
 // CHECK-SAME: %[[ARG:.*]]: i32
-func.func @trunci4(%arg0 : i32) -> i4 {
+func.func @trunci4_scalar(%arg0 : i32) -> i4 {
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
   %0 = arith.trunci %arg0 : i32 to i4
@@ -1001,8 +1001,19 @@
   return %0 : i4
 }
 
-// CHECK-LABEL: @zexti4
-func.func @zexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @trunci4_vector
+// CHECK-SAME: %[[ARG:.*]]: vector<2xi32>
+func.func @trunci4_vector(%arg0 : vector<2xi32>) -> vector<2xi4> {
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<2xi32>
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : vector<2xi32>
+  %0 = arith.trunci %arg0 : vector<2xi32> to vector<2xi4>
+  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[AND]] : vector<2xi32> to vector<2xi4>
+  // CHECK: return %[[RET]] : vector<2xi4>
+  return %0 : vector<2xi4>
+}
+
+// CHECK-LABEL: @zexti4_scalar
+func.func @zexti4_scalar(%arg0: i4) -> i32 {
   // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
   // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
   // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
@@ -1011,8 +1022,18 @@
   return %0 : i32
 }
 
-// CHECK-LABEL: @sexti4
-func.func @sexti4(%arg0: i4) -> i32 {
+// CHECK-LABEL: @zexti4_vector
+func.func @zexti4_vector(%arg0: vector<3xi4>) -> vector<3xi32> {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : vector<3xi4> to vector<3xi32>
+  // CHECK: %[[MASK:.+]] = spirv.Constant dense<15> : vector<3xi32>
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : vector<3xi32>
+  %0 = arith.extui %arg0 : vector<3xi4> to vector<3xi32>
+  // CHECK: return %[[AND]] : vector<3xi32>
+  return %0 : vector<3xi32>
+}
+
+// CHECK-LABEL: @sexti4_scalar
+func.func @sexti4_scalar(%arg0: i4) -> i32 {
   // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : i4 to i32
   // CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
   // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
@@ -1022,6 +1043,17 @@
   return %0 : i32
 }
 
+// CHECK-LABEL: @sexti4_vector
+func.func @sexti4_vector(%arg0: vector<4xi4>) -> vector<4xi32> {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<4xi4> to vector<4xi32>
+  // CHECK: %[[SIZE:.+]] = spirv.Constant dense<28> : vector<4xi32>
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : vector<4xi32>, vector<4xi32>
+  %0 = arith.extsi %arg0 : vector<4xi4> to vector<4xi32>
+  // CHECK: return %[[SR]] : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
 } // end module
 
 // -----
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -105,6 +105,9 @@
 // CHECK: spirv.func @integer4(%{{.+}}: i32)
 func.func @integer4(%arg0: i4) { return }
 
+// CHECK: spirv.func @v3i4(%{{.+}}: vector<3xi32>)
+func.func @v3i4(%arg0: vector<3xi4>) { return }
+
 } // end module
 
 // -----