diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1288,8 +1288,8 @@ if (!sourceType.hasStaticShape()) return {}; - auto resultType = op.result().getType().dyn_cast(); - if (!resultType || !resultType.hasStaticShape()) + auto resultType = op.result().getType().cast(); + if (!resultType.hasStaticShape()) return {}; // Control the size. Sice the way to get a new constant collects each element, @@ -1328,8 +1328,7 @@ outValues.reserve(sourceType.getNumElements()); sliceElements( elems.begin(), counts, offsets, sizes, strides, &outValues); - return DenseElementsAttr::get(op.result().getType().cast(), - outValues); + return DenseElementsAttr::get(resultType, outValues); } if (auto elems = attr.dyn_cast()) { @@ -1337,8 +1336,7 @@ outValues.reserve(sourceType.getNumElements()); sliceElements( elems.begin(), counts, offsets, sizes, strides, &outValues); - return DenseElementsAttr::get(op.result().getType().cast(), - outValues); + return DenseElementsAttr::get(resultType, outValues); } return {};