mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
 } // namespace

 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
     IsBroadcastableOp op, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands, not
   // on shapes.
   IsBroadcastableOp::Adaptor transformed(operands);
-  if (transformed.lhs().getType().isa<ShapeType>() ||
-      transformed.rhs().getType().isa<ShapeType>())
+  if (!llvm::all_of(op.shapes(),
+                    [](Value v) { return !v.getType().isa<ShapeType>(); }))
     return failure();
frgossen: They can be shapes, can't they?
          In that case, don't we want to make the pattern fail normally?

tpopp: Thanks for catching that. Done.
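The new guard reads as a double negation: the pattern fails unless every operand is an extent tensor. For readers, an equivalent formulation with llvm::any_of (illustration only; the patch itself keeps the llvm::all_of form):

    // Illustration only, not part of the patch: bail out as soon as any
    // operand is a ShapeType rather than a tensor<?xindex>.
    if (llvm::any_of(op.shapes(),
                     [](Value v) { return v.getType().isa<ShapeType>(); }))
      return failure();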
   auto loc = op.getLoc();
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  ImplicitLocOpBuilder lb(loc, rewriter);
+  Value zero = lb.create<ConstantIndexOp>(0);
+  Value one = lb.create<ConstantIndexOp>(1);
+  Type indexTy = lb.getIndexType();
+
+  // Save all the ranks for bounds checking. Because this is a tensor
+  // representing the shape extents, the rank is the extent of the only
+  // dimension in the tensor.
+  SmallVector<Value> ranks, rankDiffs;
+  llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) {
+                       return lb.create<DimOp>(v, zero);
+                     }));
+
+  // Find the maximum rank.
+  Value maxRank = ranks.front();
+  for (Value v : llvm::drop_begin(ranks, 1)) {
+    Value rankIsGreater = lb.create<CmpIOp>(CmpIPredicate::ugt, v, maxRank);
+    maxRank = lb.create<SelectOp>(rankIsGreater, v, maxRank);
+  }
+
+  // Calculate the difference of ranks and the maximum rank for later offsets.
+  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
+                       return lb.create<SubIOp>(indexTy, maxRank, v);
+                     }));
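For intuition, the rank bookkeeping above emits index arithmetic that mirrors the following scalar C++. This is an illustrative sketch only; computeRankDiffs is a hypothetical name, not part of the patch:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Sketch of the maxRank / rankDiffs computation the lowering emits.
    // For operand ranks {3, 2, 1}, maxRank is 3 and the result is
    // {0, 1, 2}: each operand's dims are right-aligned against the
    // maxRank-long result shape, offset by its rank difference.
    std::vector<int64_t> computeRankDiffs(const std::vector<int64_t> &ranks) {
      int64_t maxRank = *std::max_element(ranks.begin(), ranks.end());
      std::vector<int64_t> rankDiffs;
      for (int64_t rank : ranks)
        rankDiffs.push_back(maxRank - rank);
      return rankDiffs;
    }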
-  // Find smaller and greater rank and extent tensor.
-  Value lhsRank = rewriter.create<DimOp>(loc, transformed.lhs(), zero);
-  Value rhsRank = rewriter.create<DimOp>(loc, transformed.rhs(), zero);
-  Value lhsRankULE =
-      rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
-  Type indexTy = rewriter.getIndexType();
-  Value lesserRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, lhsRank, rhsRank);
-  Value greaterRank =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rhsRank, lhsRank);
-  auto erasedRankType =
-      RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);
-  Value rankErasedLhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.lhs());
-  Value rankErasedRhs =
-      rewriter.create<tensor::CastOp>(loc, erasedRankType, transformed.rhs());
-  Value lesserRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedLhs, rankErasedRhs);
-  Value greaterRankOperand =
-      rewriter.create<SelectOp>(loc, lhsRankULE, rankErasedRhs, rankErasedLhs);
-  Value rankDiff =
-      rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
   Type i1Ty = rewriter.getI1Type();
-  Value init =
-      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
-
-  // Determine if all overlapping extents are broadcastable.
-  auto reduceResult = rewriter.create<ForOp>(
-      loc, rankDiff, greaterRank, one, ValueRange{init},
+  Value trueVal =
+      rewriter.create<ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));
+
+  auto reduceResult = lb.create<ForOp>(
+      loc, zero, maxRank, one, ValueRange{trueVal},
       [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
-        Value greaterRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, greaterRankOperand, ValueRange{iv});
-        Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
-            loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
-        Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
-        Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
-            loc, lesserRankOperand, ValueRange{ivShifted});
-        Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
-            loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
-        Value extentsAreEqual =
-            b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
-                             lesserRankOperandExtent);
-        Value broadcastableExtents = b.create<AndOp>(
-            loc, iterArgs[0],
-            b.create<OrOp>(loc,
-                           b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
-                                          lesserRankOperandExtentIsOne),
-                           extentsAreEqual));
-        b.create<scf::YieldOp>(loc, broadcastableExtents);
+        // Find a non-1 dim, if it exists. Note that the first part of this
+        // could reuse the Broadcast lowering entirely, but we redo the work
+        // here to make optimizations easier between the two loops.
+        Value broadcastedDim = getBroadcastedDim(
+            ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv);
+
+        Value broadcastable = iterArgs[0];
+        for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) {
+          Value shape, rankDiff;
+          std::tie(shape, rankDiff) = tup;

herhut: nit: std::tie?

+          Value outOfBounds =
+              b.create<CmpIOp>(loc, CmpIPredicate::ult, iv, rankDiff);
+          broadcastable =
+              b.create<IfOp>(
+                   loc, TypeRange{i1Ty}, outOfBounds,
+                   [&](OpBuilder &b, Location loc) {
+                     // Non-existent dimensions are always broadcastable.
+                     b.create<scf::YieldOp>(loc, broadcastable);
+                   },
+                   [&](OpBuilder &b, Location loc) {
+                     // Every value needs to be either 1, or the same non-1
+                     // value to be broadcastable in this dim.
+                     Value operandDimension =
+                         b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+                     Value dimensionExtent = b.create<tensor::ExtractOp>(
+                         loc, shape, ValueRange{operandDimension});
+                     Value equalOne = b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                                       dimensionExtent, one);
+                     Value equalBroadcasted =
+                         b.create<CmpIOp>(loc, CmpIPredicate::eq,
+                                          dimensionExtent, broadcastedDim);
+                     Value result = b.create<AndOp>(
+                         loc, broadcastable,
+                         b.create<OrOp>(loc, equalOne, equalBroadcasted));
+                     b.create<scf::YieldOp>(loc, result);
+                   })
+                  .getResult(0);
+        }
+        b.create<scf::YieldOp>(loc, broadcastable);

herhut: Why do you sometimes use iv and sometimes outputDimension?

       });
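The scf.for over output dimensions, with an inner scf.if per operand, encodes the usual n-ary broadcasting rule. A plain C++ sketch of the per-dimension predicate the emitted IR computes (illustrative only; isBroadcastableDim is a hypothetical name, and broadcastedDim stands for the non-1 extent that getBroadcastedDim finds):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch of the check for one output dimension i, counted from the
    // left of the padded, maxRank-long shape. An operand whose rank is
    // too small for this dimension (i < rankDiff) is skipped: missing
    // dims are always broadcastable.
    bool isBroadcastableDim(const std::vector<std::vector<int64_t>> &shapes,
                            const std::vector<int64_t> &rankDiffs,
                            int64_t broadcastedDim, int64_t i) {
      bool broadcastable = true;
      for (std::size_t s = 0, e = shapes.size(); s < e; ++s) {
        if (i < rankDiffs[s])
          continue; // non-existent dimension, always broadcastable
        int64_t extent = shapes[s][i - rankDiffs[s]];
        // Every extent must be 1 or equal the non-1 broadcasted extent.
        broadcastable =
            broadcastable && (extent == 1 || extent == broadcastedDim);
      }
      return broadcastable;
    }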
   rewriter.replaceOp(op, reduceResult.results().front());
   return success();
 }
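A usage sketch tying the two helpers above together (again illustrative; it assumes both sketches are in the same translation unit):

    #include <cassert>

    int main() {
      // Shapes {2,3,4} and {3,1}: ranks {3,2}, so rankDiffs = {0,1}.
      std::vector<std::vector<int64_t>> shapes = {{2, 3, 4}, {3, 1}};
      std::vector<int64_t> rankDiffs = computeRankDiffs({3, 2});
      // Dim 2: extents 4 and 1, with broadcasted extent 4 -> broadcastable.
      assert(isBroadcastableDim(shapes, rankDiffs, /*broadcastedDim=*/4,
                                /*i=*/2));
      // {2,3} vs {4,3} fails at dim 0: 4 is neither 1 nor the candidate 2.
      assert(!isBroadcastableDim({{2, 3}, {4, 3}}, {0, 0},
                                 /*broadcastedDim=*/2, /*i=*/0));
    }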
 namespace {
 class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {