Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -813,6 +813,22 @@
   return Value();
 }
 
+/// Fold an ExtractOp whose vector operand is produced by a BroadcastOp, when
+/// the extracted value has the same type as the broadcast source. Broadcast
+/// only replicates its source along leading dimensions in that case, so every
+/// extracted element/sub-vector is the source itself.
+static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return Value();
+  // A scalar extracted from a broadcast of a *vector* (e.g.
+  // `vector<4xf32> to vector<2x4xf32>`) is not the source value; forwarding
+  // it would produce a type mismatch.
+  if (extractOp.getType() != broadcastOp.source().getType())
+    return Value();
+  return broadcastOp.source();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -820,6 +836,8 @@
     return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
+  if (auto val = foldExtractFromBroadcast(*this))
+    return val;
   return OpFoldResult();
 }
 
Index: mlir/test/Dialect/Vector/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Vector/canonicalize.mlir
+++ mlir/test/Dialect/Vector/canonicalize.mlir
@@ -348,6 +348,41 @@
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast_scalar
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   return %[[A]] : f32
+func @fold_extract_broadcast_scalar(%a : f32) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_vector
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// A scalar extracted from a broadcast of a vector is not the broadcast
+// source; it must not be folded.
+// CHECK-LABEL: no_fold_extract_scalar_from_vector_broadcast
+//       CHECK:   vector.broadcast
+//       CHECK:   vector.extract
+func @no_fold_extract_scalar_from_vector_broadcast(%a : vector<4xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = constant 0 : index