[mlir][linalg] Spell `vectorize_nd_extract` as a keyword of masked_vectorize

`vectorize_nd_extract` becomes an `OptionalAttr<UnitAttr>`. It is also added to
the op's assembly format as an extra `|`-separated clause after the
`vector_sizes` clause. It is therefore written as a bare keyword instead of
inside the attribute dictionary:

  transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op

The getter now yields an optional value. `value_or(false)` maps an absent
attribute to `false` when calling `linalg::vectorize`.

NOTE(review): the template arguments on the TableGen context lines
(`Variadic<...>`, both `DefaultValuedOptionalAttr<...>`, `custom<...>`) and the
context-line indentation were lost in the extracted copy of this patch and
have been reconstructed. Confirm them against the target tree before applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2114,7 +2114,7 @@
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        Variadic<TransformAnyParamTypeOrAnyHandle>:$vector_sizes,
-                       UnitAttr:$vectorize_nd_extract,
+                       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
                        DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
                           $scalable_sizes,
                        DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
@@ -2126,7 +2126,8 @@
         `vector_sizes` custom<DynamicIndexList>($vector_sizes,
                                                 $static_vector_sizes,
                                                 type($vector_sizes),
-                                                $scalable_sizes)
+                                                $scalable_sizes) |
+        `vectorize_nd_extract` $vectorize_nd_extract
       )
       attr-dict
       `:` type($target)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3231,7 +3231,7 @@
 
     if (failed(linalg::vectorize(rewriter, target, vectorSizes,
                                  getScalableSizes(),
-                                 getVectorizeNdExtract()))) {
+                                 getVectorizeNdExtract().value_or(false)))) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Attempted to vectorize, but failed";
     }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -28,7 +28,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
 }
 
 // -----
@@ -83,7 +83,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
 }
 
 // -----
@@ -121,7 +121,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
 }
 
 // -----
@@ -176,7 +176,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [1, 4] { vectorize_nd_extract } : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [1, 4] vectorize_nd_extract : !transform.any_op
 }
 
 // -----
@@ -226,7 +226,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [3, 3] vectorize_nd_extract : !transform.any_op
 }
 
 // -----
@@ -269,5 +269,5 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op
+  transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] vectorize_nd_extract : !transform.any_op
 }