diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -294,10 +294,11 @@ return failure(); auto dstElementsAttr = constOp.getValue().dyn_cast(); - ShapedType dstAttrType = dstElementsAttr.getType(); if (!dstElementsAttr) return failure(); + ShapedType dstAttrType = dstElementsAttr.getType(); + // If the composite type has more than one dimensions, perform linearization. if (srcType.getRank() > 1) { if (srcType.isa()) {