diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -39,25 +39,17 @@
     // Find smaller and greater rank and extent tensor.
     Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
     Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
-    Value lhsSmaller =
+    Value lhsRankULE =
         rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
     Type indexTy = rewriter.getIndexType();
-    Type extentTensorTy = op.lhs().getType();
-    auto ifOp = rewriter.create<scf::IfOp>(
-        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
-        lhsSmaller,
-        [&](OpBuilder &b, Location loc) {
-          b.create<scf::YieldOp>(
-              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
-        },
-        [&](OpBuilder &b, Location loc) {
-          b.create<scf::YieldOp>(
-              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
-        });
-    Value lesserRank = ifOp.getResult(0);
-    Value lesserRankOperand = ifOp.getResult(1);
-    Value greaterRank = ifOp.getResult(2);
-    Value greaterRankOperand = ifOp.getResult(3);
+    Value lesserRank =
+        rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+    Value greaterRank =
+        rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+    Value lesserRankOperand =
+        rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+    Value greaterRankOperand =
+        rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());
     Value rankDiff =
         rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,27 +90,19 @@
   Value one = rewriter.create<ConstantIndexOp>(loc, 1);

   // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsSmaller =
+  Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
+  Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
+  Value lhsRankULE =
       rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
   Type indexTy = rewriter.getIndexType();
-  Type extentTensorTy = op.getType();
-  auto ifOp = rewriter.create<scf::IfOp>(
-      loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
-      lhsSmaller,
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{lhsRank, transformed.lhs(),
-                                               rhsRank, transformed.rhs()});
-      },
-      [&](OpBuilder &b, Location loc) {
-        b.create<scf::YieldOp>(loc, ValueRange{rhsRank, transformed.rhs(),
-                                               lhsRank, transformed.lhs()});
-      });
-  Value smallerRank = ifOp.getResult(0);
-  Value smallerOperand = ifOp.getResult(1);
-  Value greaterRank = ifOp.getResult(2);
-  Value greaterOperand = ifOp.getResult(3);
+  Value lesserRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
+  Value greaterRank =
+      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
+  Value lesserRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, op.lhs(), op.rhs());
+  Value greaterRankOperand =
+      rewriter.create<SelectOp>(loc, lhsRankULE, op.rhs(), op.lhs());

   // Allocate stack memory for the broadcasted extent tensor.
   Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
@@ -118,11 +110,11 @@
   // Copy extents from greater operand that are not challenged.
   Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, smallerRank);
+      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
   rewriter.create<scf::ForOp>(loc, zero, rankDiff, one, llvm::None,
       [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
         Value extent = b.create<ExtractElementOp>(
-            loc, greaterOperand, ValueRange{iv});
+            loc, greaterRankOperand, ValueRange{iv});
         b.create<StoreOp>(loc, extent, mem, ValueRange{iv});
         b.create<scf::YieldOp>(loc);
       });
@@ -132,16 +124,16 @@
       loc, rankDiff, greaterRank, one, llvm::None,
       [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
         Value greaterOperandExtent =
-            b.create<ExtractElementOp>(loc, greaterOperand, ValueRange{iv});
+            b.create<ExtractElementOp>(loc, greaterRankOperand, ValueRange{iv});
         Value greaterOperandExtentIsOne =
             b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterOperandExtent, one);
         auto ifOp = b.create<scf::IfOp>(
             loc, TypeRange{indexTy}, greaterOperandExtentIsOne,
             [&](OpBuilder &b, Location loc) {
               Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-              Value smallerOperandExtent = b.create<ExtractElementOp>(
-                  loc, smallerOperand, ValueRange{ivShifted});
-              b.create<scf::YieldOp>(loc, smallerOperandExtent);
+              Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+                  loc, lesserRankOperand, ValueRange{ivShifted});
+              b.create<scf::YieldOp>(loc, lesserRankOperandExtent);
             },
             [&](OpBuilder &b, Location loc) {
               b.create<scf::YieldOp>(loc, greaterOperandExtent);
diff --git a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
--- a/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir
@@ -7,25 +7,24 @@
 // CHECK:           %[[C0:.*]] = constant 0 : index
 // CHECK:           %[[C1:.*]] = constant 1 : index
 // CHECK:           %[[RET:.*]] = shape.const_witness true
-// CHECK:           %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-// CHECK:           %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
-// CHECK:           %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-// CHECK:             scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-// CHECK:           } else {
-// CHECK:             scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-// CHECK:           }
-// CHECK:           %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
-// CHECK:           scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
-// CHECK:             %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
-// CHECK:             %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
-// CHECK:             %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
-// CHECK:             %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
-// CHECK:             %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
-// CHECK:             %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
-// CHECK:             %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
-// CHECK:             %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
-// CHECK:             assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK:           %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK:           %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK:           %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK:           %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
+// CHECK:           %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+// CHECK:           %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
+// CHECK:           scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+// CHECK:             %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+// CHECK:             %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
+// CHECK:             %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK:             %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK:             %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LESSER_RANK_OPERAND_EXTENT]], %[[C1]] : index
+// CHECK:             %[[EXTENTS_AGREE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[LESSER_RANK_OPERAND_EXTENT]] : index
+// CHECK:             %[[OR_TMP:.*]] = or %[[GREATER_RANK_OPERAND_EXTENT_IS_ONE]], %[[LESSER_RANK_OPERAND_EXTENT_IS_ONE]] : i1
+// CHECK:             %[[BROADCAST_IS_VALID:.*]] = or %[[EXTENTS_AGREE]], %[[OR_TMP]] : i1
+// CHECK:             assert %[[BROADCAST_IS_VALID]], "invalid broadcast"
 // CHECK:           }
 // CHECK:           return %[[RET]] : !shape.witness
 // CHECK:         }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -312,27 +312,26 @@
   // CHECK: %[[C1:.*]] = constant 1 : index
   // CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
   // CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
-  // CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]]
-  // CHECK: %[[ARG:.*]]:4 = scf.if %[[LHS_SMALLER]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
-  // CHECK:   scf.yield %[[LHS_RANK]], %[[LHS]], %[[RHS_RANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: } else {
-  // CHECK:   scf.yield %[[RHS_RANK]], %[[RHS]], %[[LHS_RANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
-  // CHECK: }
-  // CHECK: %[[MEM:.*]] = alloca(%[[ARG]]#2) : memref<?xindex>
-  // CHECK: %[[RANK_DIFF:.*]] = subi %[[ARG]]#2, %[[ARG]]#0 : index
+  // CHECK: %[[LHS_RANK_ULE:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+  // CHECK: %[[GREATER_RANK:.*]] = select %[[LHS_RANK_ULE]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+  // CHECK: %[[LESSER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[LHS]], %[[RHS]] : tensor<?xindex>
+  // CHECK: %[[GREATER_RANK_OPERAND:.*]] = select %[[LHS_RANK_ULE]], %[[RHS]], %[[LHS]] : tensor<?xindex>
+  // CHECK: %[[MEM:.*]] = alloca(%[[GREATER_RANK]]) : memref<?xindex>
+  // CHECK: %[[RANK_DIFF:.*]] = subi %[[GREATER_RANK]], %[[LESSER_RANK]] : index
   // CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[RANK_DIFF]] step %[[C1]] {
-  // CHECK:   %[[EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
   // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
   // CHECK: }
-  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[ARG]]#2 step %[[C1]] {
-  // CHECK:   %[[GREATER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#3[%[[IV]]] : tensor<?xindex>
-  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_OPERAND_EXTENT]], %[[C1]] : index
+  // CHECK: scf.for %[[IV:.*]] = %[[RANK_DIFF]] to %[[GREATER_RANK]] step %[[C1]] {
+  // CHECK:   %[[GREATER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[GREATER_RANK_OPERAND]][%[[IV]]] : tensor<?xindex>
+  // CHECK:   %[[GREATER_OPERAND_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[GREATER_RANK_OPERAND_EXTENT]], %[[C1]] : index
   // CHECK:   %[[EXTENT:.*]] = scf.if %[[GREATER_OPERAND_EXTENT_IS_ONE]] -> (index) {
   // CHECK:     %[[IV_SHIFTED:.*]] = subi %[[IV]], %[[RANK_DIFF]] : index
-  // CHECK:     %[[SMALLER_OPERAND_EXTENT:.*]] = extract_element %[[ARG]]#1[%[[IV_SHIFTED]]] : tensor<?xindex>
-  // CHECK:     scf.yield %[[SMALLER_OPERAND_EXTENT]] : index
+  // CHECK:     %[[LESSER_RANK_OPERAND_EXTENT:.*]] = extract_element %[[LESSER_RANK_OPERAND]][%[[IV_SHIFTED]]] : tensor<?xindex>
+  // CHECK:     scf.yield %[[LESSER_RANK_OPERAND_EXTENT]] : index
   // CHECK:   } else {
-  // CHECK:     scf.yield %[[GREATER_OPERAND_EXTENT]] : index
+  // CHECK:     scf.yield %[[GREATER_RANK_OPERAND_EXTENT]] : index
   // CHECK:   }
   // CHECK:   store %[[EXTENT]], %[[MEM]][%[[IV]]] : memref<?xindex>
   // CHECK: }
@@ -341,4 +340,3 @@
       : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return
 }
-