diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -385,7 +385,8 @@
 /*============================================================================*/
 StructBuilder::StructBuilder(Value v) : value(v) {
   assert(value != nullptr && "value cannot be null");
-  structType = value.getType().cast<LLVM::LLVMType>();
+  structType = value.getType().dyn_cast<LLVM::LLVMType>();
+  assert(structType && "expected llvm type");
 }
 
 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
@@ -2303,6 +2304,8 @@
     return matchFailure();
 
     // Create the descriptor.
+    if (!operands.front().getType().isa<LLVM::LLVMType>())
+      return matchFailure();
     MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
 
diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file
+
+#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+func @invalid_memref_cast(%arg0: memref<?x?xf64>) {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  // NOTE(review): the next line lacks the `{{...}}` form used by -verify-diagnostics expectations, and the RUN line
+  // runs no conversion pass; confirm whether `-convert-std-to-llvm` and `expected-error@+1 {{...}}` were intended.
+  // expected-error@+1: 'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values,
+  %5 = memref_cast %arg0 : memref<?x?xf64> to memref<?x?xf64, #map1>
+  %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref<?x?xf64, #map1> to memref<?x?xf64, #map1>
+  return
+}
+