Index: mlir/lib/Dialect/Vector/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorOps.cpp +++ mlir/lib/Dialect/Vector/VectorOps.cpp @@ -813,6 +813,37 @@ return Value(); } +/// Fold extractOp whose source vector comes from a BroadcastOp. +static Value foldExtractFromBroadcast(ExtractOp extractOp) { + auto broadcastOp = extractOp.vector().getDefiningOp(); + if (!broadcastOp) + return Value(); + if (extractOp.getType() == broadcastOp.getSourceType()) + return broadcastOp.source(); + auto getRank = [](Type type) { + return type.isa() ? type.cast().getRank() : 0; + }; + unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType()); + unsigned extractResultRank = getRank(extractOp.getType()); + if (extractResultRank < broadcasrSrcRank) { + auto extractPos = extractVector(extractOp.position()); + unsigned rankDiff = broadcasrSrcRank - extractResultRank; + extractPos.erase( + extractPos.begin(), + std::next(extractPos.begin(), extractPos.size() - rankDiff)); + extractOp.setOperand(broadcastOp.source()); + // OpBuilder is only used as a helper to build an I64ArrayAttr. + OpBuilder b(extractOp.getContext()); + extractOp.setAttr(ExtractOp::getPositionAttrName(), + b.getI64ArrayAttr(extractPos)); + return extractOp.getResult(); + } + // TODO: In case the rank of the broadcast source is lower than the rank of + // the extract result this can be combined into a new broadcast op. This needs + // to be added as a canonicalization pattern if needed. 
+ return Value(); +} + OpFoldResult ExtractOp::fold(ArrayRef) { if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); @@ -820,6 +851,8 @@ return getResult(); if (auto val = foldExtractOpFromInsertChainAndTranspose(*this)) return val; + if (auto val = foldExtractFromBroadcast(*this)) + return val; return OpFoldResult(); } Index: mlir/test/Dialect/Vector/canonicalize.mlir =================================================================== --- mlir/test/Dialect/Vector/canonicalize.mlir +++ mlir/test/Dialect/Vector/canonicalize.mlir @@ -348,6 +348,54 @@ // ----- +// CHECK-LABEL: fold_extract_broadcast +// CHECK-SAME: %[[A:.*]]: f32 +// CHECK: return %[[A]] : f32 +func @fold_extract_broadcast(%a : f32) -> f32 { + %b = vector.broadcast %a : f32 to vector<1x2x4xf32> + %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32> + return %r : f32 +} + +// ----- + +// CHECK-LABEL: fold_extract_broadcast_vector +// CHECK-SAME: %[[A:.*]]: vector<4xf32> +// CHECK: return %[[A]] : vector<4xf32> +func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> { + %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32> + %r = vector.extract %b[0, 1] : vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: fold_extract_broadcast +// CHECK-SAME: %[[A:.*]]: vector<4xf32> +// CHECK: %[[R:.*]] = vector.extract %[[A]][2] : vector<4xf32> +// CHECK: return %[[R]] : f32 +func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 { + %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32> + %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32> + return %r : f32 +} + +// ----- + +// Negative test for extract_op folding when the type of broadcast source +// doesn't match the type of vector.extract. 
+// CHECK-LABEL: fold_extract_broadcast_negative +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32> +// CHECK: %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32> +// CHECK: return %[[R]] : vector<4xf32> +func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> { + %b = vector.broadcast %a : f32 to vector<1x2x4xf32> + %r = vector.extract %b[0, 1] : vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + // CHECK-LABEL: fold_vector_transfers func @fold_vector_transfers(%A: memref) -> (vector<4x8xf32>, vector<4x9xf32>) { %c0 = constant 0 : index