Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -813,6 +813,16 @@
   return Value();
 }
 
+/// Fold an ExtractOp whose vector operand is produced by a BroadcastOp.
+/// Returns the broadcast source when its type equals the extract result
+/// type (scalar or vector); returns a null Value otherwise.
+static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  auto broadcastOp = extractOp.vector().getDefiningOp<BroadcastOp>();
+  if (!broadcastOp || extractOp.getType() != broadcastOp.getSourceType())
+    return Value();
+  return broadcastOp.source();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -820,6 +830,8 @@
     return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
+  if (auto val = foldExtractFromBroadcast(*this))
+    return val;
   return OpFoldResult();
 }
 
Index: mlir/test/Dialect/Vector/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Vector/canonicalize.mlir
+++ mlir/test/Dialect/Vector/canonicalize.mlir
@@ -348,6 +348,42 @@
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK-SAME: %[[A:.*]]: f32
+// CHECK: return %[[A]] : f32
+func @fold_extract_broadcast(%a : f32) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_vector
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>
+// CHECK: return %[[A]] : vector<4xf32>
+func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// Negative test for extract_op folding when the type of broadcast source
+// doesn't match the type of vector.extract.
+// CHECK-LABEL: fold_extract_broadcast_negative
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf32> to vector<1x2x4xf32>
+// CHECK: %[[R:.*]] = vector.extract %[[B]][0, 1, 2] : vector<1x2x4xf32>
+// CHECK: return %[[R]] : f32
+func @fold_extract_broadcast_negative(%a : vector<4xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func @fold_vector_transfers(%A: memref) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = constant 0 : index