diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -110,47 +110,48 @@
   Value greaterRankOperand =
       rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
 
-  // Allocate stack memory for the broadcasted extent tensor.
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{greaterRank});
-
-  // Copy extents from greater operand that are not challenged.
   Value rankDiff =
       rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
-  rewriter.create<scf::ForOp>(
-      loc, zero, rankDiff, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-        Value extent = b.create<ExtractElementOp>(
-            loc, greaterRankOperand, ValueRange{iv});
-        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-        b.create<scf::YieldOp>(loc);
-      });
-
-  // Determine remaining broadcasted extents.
-  rewriter.create<scf::ForOp>(
-      loc, rankDiff, greaterRank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
-        Value greaterOperandExtent =
-            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
-        Value greaterOperandExtentIsOne =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
+  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
+      op, getExtentTensorType(op.getContext()), ValueRange{greaterRank},
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value outputDimension = args[0];
+        Value isUnchallengedDimension = b.create<CmpIOp>(
+            loc, CmpIPredicate::ult, outputDimension, rankDiff);
+        Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+            loc, greaterRankOperand, outputDimension);
+        // The initial dimensions of the greater-rank operand are unchallenged,
+        // so we can take them as-is. Otherwise, we need to do a comparison.
+        // We need an actual branch here (instead of a select) because the
+        // lesser-rank operand might be rank 0, so any extract_element would be
+        // invalid.
         auto ifOp = b.create<IfOp>(
-            loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
+            loc, TypeRange{indexTy}, isUnchallengedDimension,
             [&](OpBuilder &b, Location loc) {
-              Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
-                  loc, lesserRankOperand, ValueRange{ivShifted});
-              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
+              b.create<scf::YieldOp>(loc, greaterRankOperandExtent);
             },
             [&](OpBuilder &b, Location loc) {
-              b.create<scf::YieldOp>(loc, greaterOperandExtent);
+              // The broadcasting logic is:
+              // - if one extent (here we arbitrarily choose the extent from
+              //   the greater-rank operand) is equal to 1, then take the
+              //   extent from the other operand
+              // - otherwise, take the extent as-is.
+              // Note that this logic remains correct in the presence of
+              // dimensions of zero extent.
+              Value lesserRankOperandDimension =
+                  b.create<SubIOp>(loc, indexTy, outputDimension, rankDiff);
+              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+                  loc, lesserRankOperand,
+                  ValueRange{lesserRankOperandDimension});
+              Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+                  loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+              Value broadcastedExtent = b.create<SelectOp>(
+                  loc, greaterRankOperandExtentIsOne, lesserRankOperandExtent,
+                  greaterRankOperandExtent);
+              b.create<scf::YieldOp>(loc, broadcastedExtent);
             });
-        Value extent = ifOp.getResult(0);
-        b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
-        b.create<scf::YieldOp>(loc);
+        b.create<mlir::YieldOp>(loc, ifOp.getResult(0));
       });
-
-  // Load broadcasted shape as an extent tensor.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op, mem);
   return success();
 }
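Note (reviewer sketch, not part of the patch): the per-dimension rule that the
dynamic_tensor_from_elements body above emits can be written out as standalone
C++. The broadcastShapes helper below is hypothetical and, like the lowering
itself, assumes the operands are broadcast-compatible; it does not verify
compatibility.

  #include <cassert>
  #include <cstdint>
  #include <vector>

  // For each output dimension: leading (unchallenged) dimensions of the
  // greater-rank operand are taken as-is; elsewhere, if the greater-rank
  // extent is 1, take the lesser-rank extent, otherwise keep the greater-rank
  // extent. This mirrors the scf.if / select structure built above.
  std::vector<int64_t> broadcastShapes(const std::vector<int64_t> &greater,
                                       const std::vector<int64_t> &lesser) {
    assert(greater.size() >= lesser.size());
    size_t rankDiff = greater.size() - lesser.size();
    std::vector<int64_t> result(greater.size());
    for (size_t i = 0; i < greater.size(); ++i) {
      if (i < rankDiff) {
        // Unchallenged leading dimension of the greater-rank operand.
        result[i] = greater[i];
      } else {
        // Challenged dimension: matches the select in the else-branch.
        int64_t g = greater[i], l = lesser[i - rankDiff];
        result[i] = (g == 1) ? l : g;
      }
    }
    return result;
  }

  int main() {
    // [3, 2, 1] broadcast with [1, 4] yields [3, 2, 4]; the rule also stays
    // correct for dimensions of zero extent, e.g. g == 0 is kept as 0.
    assert((broadcastShapes({3, 2, 1}, {1, 4}) ==
            std::vector<int64_t>{3, 2, 4}));
  }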
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -305,39 +305,39 @@
 
 // -----
 
-// CHECK-LABEL: @broadcast_unknown_extents
-// CHECK-SAME: (%[[LHS:.*]]: tensor<?xindex>, %[[RHS:.*]]: tensor<?xindex>)
+// CHECK-LABEL: func @broadcast_unknown_extents(
+// CHECK-SAME:  %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME:  %[[RHS:.*]]: tensor<?xindex>) {
 func @broadcast_unknown_extents(%a : tensor<?xindex>, %b : tensor<?xindex>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
-  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-  // CHECK:   %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
-  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
-  // CHECK:     scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:   } else {
-  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:   }
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+  // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
+  // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
+  // CHECK:   %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
+  // CHECK:   %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
+  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
+  // CHECK:     %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:     %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
+  // CHECK:     scf.yield %[[BROADCASTED_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   yield %[[OUTPUT_EXTENT:.*]] : index
+  // CHECK: } : tensor<?xindex>
+  // CHECK: return
+  // CHECK: }
   %0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return
 
@@ -345,39 +345,39 @@
 
 // -----
 
-// CHECK-LABEL: @broadcast_known_different_extents
-// CHECK-SAME: (%[[LHS:.*]]: tensor<2xindex>, %[[RHS:.*]]: tensor<3xindex>)
+// CHECK-LABEL: func @broadcast_known_different_extents(
+// CHECK-SAME:  %[[LHS:.*]]: tensor<2xindex>,
+// CHECK-SAME:  %[[RHS:.*]]: tensor<3xindex>) {
 func @broadcast_known_different_extents(%a : tensor<2xindex>, %b : tensor<3xindex>) {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
-  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
-  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
-  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
-  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
-  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
-  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
-  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
-  // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
-  // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-  // CHECK:   %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
-  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
-  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
-  // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
-  // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
-  // CHECK:     scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:   } else {
-  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
-  // CHECK:   }
-  // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[BROADCASTED:.*]] = tensor_load %[[MEM]] : memref<?xindex>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<2xindex>
+  // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<3xindex>
+  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+  // CHECK: %[[ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<2xindex> to tensor<?xindex>
+  // CHECK: %[[ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<3xindex> to tensor<?xindex>
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_LHS]], %[[ERASED_RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[ERASED_RHS]], %[[ERASED_LHS]] : tensor<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+  // CHECK: %[[RESULT:.*]] = dynamic_tensor_from_elements %[[GREATER_RANK]] {
+  // CHECK: ^bb0(%[[OUTPUT_DIMENSION:.*]]: index):
+  // CHECK:   %[[IS_UNCHALLENGED_DIMENSION:.*]] = cmpi "ult", %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[OUTPUT_DIMENSION]]] : tensor<?xindex>
+  // CHECK:   %[[OUTPUT_EXTENT:.*]] = scf.if %[[IS_UNCHALLENGED_DIMENSION]] -> (index) {
+  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
+  // CHECK:   } else {
+  // CHECK:     %[[LESSER_RANK_OPERAND_DIMENSION:.*]] = subi %[[OUTPUT_DIMENSION]], %[[RANK_DIFF]] : index
+  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[LESSER_RANK_OPERAND_DIMENSION]]] : tensor<?xindex>
+  // CHECK:     %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK:     %[[BROADCASTED_EXTENT:.*]] = select %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT]], %[[GREATER_RANK_OPERAND_EXTENT]] : index
+  // CHECK:     scf.yield %[[BROADCASTED_EXTENT]] : index
+  // CHECK:   }
+  // CHECK:   yield %[[OUTPUT_EXTENT:.*]] : index
+  // CHECK: } : tensor<?xindex>
+  // CHECK: return
+  // CHECK: }
   %0 = shape.broadcast %a, %b : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
   return