diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -99,10 +99,16 @@ rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank); Value greaterRank = rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank); + auto erasedRankType = + RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); + Value rankErasedLhs = + rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.lhs()); + Value rankErasedRhs = + rewriter.create<TensorCastOp>(loc, erasedRankType, transformed.rhs()); Value lesserRankOperand = - rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs()); + rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs); Value greaterRankOperand = - rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs()); + rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs); // Allocate stack memory for the broadcasted extent tensor. 
Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy); diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -305,9 +305,9 @@ // ----- -// CHECK-LABEL: @broadcast +// CHECK-LABEL: @broadcast_unknown_extents // CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>) -func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) { +func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) { // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex> @@ -315,8 +315,10 @@ // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index - // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex> - // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex> + // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex> + // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex> + // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex> + // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex> // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex> // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] { @@ -340,3 +342,43 @@ : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex> return } + +// ----- + +// CHECK-LABEL: @broadcast_known_different_extents +// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: 
tensor<3xindex>) +func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) { + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex> + // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex> + // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index + // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index + // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index + // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex> + // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex> + // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex> + // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex> + // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex> + // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index + // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] { + // CHECK: %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex> + // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex> + // CHECK: } + // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] { + // CHECK: %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex> + // CHECK: %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index + // CHECK: %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) { + // CHECK: %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index + // CHECK: %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex> + // CHECK: scf.yield 
%[[LESSER_RANK_OPERAND_EXTENT]] : index + // CHECK: } else { + // CHECK: scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index + // CHECK: } + // CHECK: store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex> + // CHECK: } + // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex> + %0 = shape.broadcast %a, %b + : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex> + return +}