diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2236,8 +2236,8 @@
   if (parentOp) {
     if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
       Value castSource = castOp.getSource();
-      auto castSourceType = castSource.getType().cast<RankedTensorType>();
-      if (castSourceType.hasStaticShape())
+      auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
+      if (castSourceType && castSourceType.hasStaticShape())
         sourceShape = castSourceType.getShape();
     }
   }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -47,7 +47,6 @@
 
 // -----
 
-
 // CHECK-LABEL: func @tensor.cast(
 func.func @tensor.cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
   -> tensor<3x?xf32>
@@ -68,6 +67,30 @@
 
 // -----
 
+// CHECK-LABEL: func @tensor.cast.unranked(
+func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : tensor<*xf32>)
+  -> tensor<*xf32>
+{
+  // CHECK: tensor.cast
+  // CHECK: tensor.cast
+  // CHECK: tensor.cast
+  %ta = tensor.cast %a : tensor<*xf32> to tensor<?x?xf32>
+  %tb = tensor.cast %b : tensor<*xf32> to tensor<?x?xf32>
+  %tc = tensor.cast %c : tensor<*xf32> to tensor<?x?xf32>
+
+  // CHECK: linalg.matmul ins({{.*}}tensor<?x?xf32>, tensor<?x?xf32>)
+  // CHECK-SAME: outs({{.*}}tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  // CHECK: tensor.cast
+  %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<*xf32>
+
+  return %1: tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @linalg_effects(
 // CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
 // CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
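
Illustration, not part of the patch: the crash fixed above comes from Type::cast<RankedTensorType>(), which asserts when the source of the tensor.cast is unranked (e.g. tensor<*xf32>); dyn_cast returns a null type instead, which the added null check handles. Below is a minimal standalone sketch of the same dyn_cast-and-check idiom, pulled into a helper; the function name refineFromCastSource is hypothetical and only the MLIR APIs already used in the patch are assumed.

// Reviewer sketch, not part of the patch.
#include "mlir/Dialect/Tensor/IR/Tensor.h" // tensor::CastOp
#include "mlir/IR/BuiltinTypes.h"          // RankedTensorType

using namespace mlir;

// If `src` is produced by a tensor.cast whose source has a static shape,
// copy that shape into sourceShape and report success. cast<RankedTensorType>()
// would assert on an unranked source such as tensor<*xf32>, so the null
// check after dyn_cast is what prevents the crash.
static bool refineFromCastSource(Value src, ArrayRef<int64_t> &sourceShape) {
  auto castOp = src.getDefiningOp<tensor::CastOp>();
  if (!castOp)
    return false;
  auto castSourceType =
      castOp.getSource().getType().dyn_cast<RankedTensorType>();
  if (!castSourceType || !castSourceType.hasStaticShape())
    return false;
  sourceShape = castSourceType.getShape();
  return true;
}

With this guard, an unranked cast source simply leaves sourceShape unchanged instead of tripping the cast assertion, which is exactly the behavior the new @tensor.cast.unranked test checks: the canonicalizer leaves the casts and the matmul on tensor<?x?xf32> in place.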