diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -23,6 +23,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" #include using namespace mlir; @@ -33,6 +34,13 @@ return (*attr.getAsValueRange().begin()).getZExtValue(); } +/// Returns the number of bits for the given scalar/vector type. +static int getNumBits(Type type) { + if (auto vectorType = type.dyn_cast()) + return vectorType.cast().getSizeInBits(); + return type.getIntOrFloatBitWidth(); +} + namespace { struct VectorBitcastConvert final @@ -46,12 +54,24 @@ if (!dstType) return failure(); - if (dstType == adaptor.getSource().getType()) + if (dstType == adaptor.getSource().getType()) { rewriter.replaceOp(bitcastOp, adaptor.getSource()); - else - rewriter.replaceOpWithNewOp(bitcastOp, dstType, - adaptor.getSource()); + return success(); + } + + // Check that the source and destination types have the same bitwidth. + // Depending on the target environment, we may need to emulate certain + // types, which can cause issues with bitcast. 
+ Type srcType = adaptor.getSource().getType(); + if (getNumBits(dstType) != getNumBits(srcType)) { + return rewriter.notifyMatchFailure( + bitcastOp, + llvm::formatv("different source ({0}) and target ({1}) bitwidth", + srcType, dstType)); + } + rewriter.replaceOpWithNewOp(bitcastOp, dstType, + adaptor.getSource()); return success(); } }; diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -16,6 +16,23 @@ // ----- +// Check that without the proper capability we fail the pattern application +// to avoid generating invalid ops. + +module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { + +// CHECK-LABEL: @bitcast +func.func @bitcast(%arg0 : vector<2xf32>, %arg1: vector<2xf16>) -> (vector<4xf16>, vector<1xf32>) { + // CHECK-COUNT-2: vector.bitcast + %0 = vector.bitcast %arg0 : vector<2xf32> to vector<4xf16> + %1 = vector.bitcast %arg1 : vector<2xf16> to vector<1xf32> + return %0, %1: vector<4xf16>, vector<1xf32> +} + +} // end module + +// ----- + module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { // CHECK-LABEL: @cl_fma