diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -15,6 +15,38 @@
 using namespace mlir::scf;
 
 namespace {
+
+/// Appends the values that `v` stands for to `unpacked`: the inputs of the
+/// unrealized_conversion_cast defining `v` if that cast does not have exactly
+/// one input (a 1:N type conversion), otherwise `v` itself.
+static void unpackUnrealizedConversionCast(Value v,
+                                           SmallVectorImpl<Value> &unpacked) {
+  if (auto cast = llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(
+          v.getDefiningOp())) {
+    if (cast.getInputs().size() != 1) {
+      // 1 : N type conversion.
+      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
+      return;
+    }
+  }
+  // 1 : 1 type conversion.
+  unpacked.push_back(v);
+}
+
+/// Packs `inputs` back into a single value of `origType`. A single input is
+/// returned unchanged (1 : 1 conversion); otherwise an
+/// unrealized_conversion_cast to `origType` is created at `loc`.
+static Value packUnrealizedConversionCast(OpBuilder &builder, Location loc,
+                                          ValueRange inputs, Type origType) {
+  if (inputs.size() == 1)
+    // 1 : 1 conversion, simply use the value from the adaptor.
+    return inputs.front();
+
+  auto convertOp =
+      builder.create<UnrealizedConversionCastOp>(loc, origType, inputs);
+  return convertOp.getResult(0);
+}
+
 class ConvertForOpTypes : public OpConversionPattern<ForOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -22,51 +54,74 @@
   matchAndRewrite(ForOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type> newResultTypes;
-    for (auto type : op.getResultTypes()) {
-      Type newType = typeConverter->convertType(type);
-      if (!newType)
-        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
-      newResultTypes.push_back(newType);
+    SmallVector<unsigned> offsets;
+    offsets.push_back(0);
+    // Convert each result type; offsets[i] is where result i's types begin.
+    for (Type type : op.getResultTypes()) {
+      if (failed(typeConverter->convertTypes(type, newResultTypes)))
+        return rewriter.notifyMatchFailure(op, "could not convert result");
+      offsets.push_back(newResultTypes.size());
     }
 
-    // Clone the op without the regions and inline the regions from the old op.
+    // Create a new op without regions and inline the regions from the old
+    // op.
     //
     // This is a little bit tricky. We have two concerns here:
     //
     // 1. We cannot update the op in place because the dialect conversion
     // framework does not track type changes for ops updated in place, so it
     // won't insert appropriate materializations on the changed result types.
-    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to
-    // clone the op.
+    // PR47938 tracks this issue, but it seems hard to fix. Instead, we need
+    // to create a new op.
     //
-    // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being
-    // inefficient to recursively clone the regions, there is a correctness
-    // issue: if we clone with the regions, then the dialect conversion
-    // framework thinks that we just inserted all the cloned child ops. But what
-    // we want is to "take" the child regions and let the dialect conversion
-    // framework continue recursively into ops inside those regions (which are
-    // already in its worklist; inlining them into the new op's regions doesn't
-    // remove the child ops from the worklist).
-    ForOp newOp = cast<ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
-    // Take the region from the old op and put it in the new op.
+    // 2. We cannot simply call `op.clone()` to get the cloned op. Besides
+    // being inefficient to recursively clone the regions, there is a
+    // correctness issue: if we clone with the regions, then the dialect
+    // conversion framework thinks that we just inserted all the cloned child
+    // ops. But what we want is to "take" the child regions and let the
+    // dialect conversion framework continue recursively into ops inside those
+    // regions (which are already in its worklist; inlining them into the new
+    // op's regions doesn't remove the child ops from the worklist).
+    //
+    // convertRegionTypes already takes care of 1:N conversion.
+    if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
+      return rewriter.notifyMatchFailure(op, "could not convert body types");
+
+    // Unpack the (possibly 1:N converted) iteration arguments.
+    SmallVector<Value> flatArgs;
+    for (Value arg : adaptor.getInitArgs())
+      unpackUnrealizedConversionCast(arg, flatArgs);
+
+    // We cannot use cloneWithoutRegions either: the number of results after
+    // conversion may differ from the original op.
+    ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
+                                         adaptor.getUpperBound(),
+                                         adaptor.getStep(), flatArgs);
+
+    // Carry over the original op's attributes; the builder call above does
+    // not take them.
+    newOp->setAttrs(op->getAttrs());
+    // Drop the default body block the builder created. Erase it through the
+    // rewriter so that the conversion driver tracks the erasure.
+    rewriter.eraseBlock(newOp.getBody(0));
+    // Inline the type converted region from the original operation.
     rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
                                 newOp.getLoopBody().end());
 
-    // Now, update all the types.
-
-    // Convert the type of the entry block of the ForOp's body.
-    if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
-                                           *getTypeConverter()))) {
-      return rewriter.notifyMatchFailure(op, "could not convert body types");
+    // Pack the converted results back into values of the original types.
+    SmallVector<Value> packedRets;
+    for (unsigned i = 1, e = offsets.size(); i < e; ++i) {
+      unsigned start = offsets[i - 1], end = offsets[i];
+      unsigned len = end - start;
+
+      ValueRange mappedValue = newOp.getResults().slice(start, len);
+      Type origType = op.getResultTypes()[i - 1];
+      // Repack into the original type; a cast is created only when len != 1.
+      packedRets.push_back(packUnrealizedConversionCast(rewriter, op.getLoc(),
+                                                        mappedValue, origType));
     }
-    // Change the clone to use the updated operands. We could have cloned with
-    // a BlockAndValueMapping, but this seems a bit more direct.
-    newOp->setOperands(adaptor.getOperands());
-    // Update the result types to the new converted types.
-    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
-      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, packedRets);
 
-    rewriter.replaceOp(op, newOp.getResults());
     return success();
   }
 };
@@ -81,12 +136,12 @@
                   ConversionPatternRewriter &rewriter) const override {
     // TODO: Generalize this to any type conversion, not just 1:1.
     //
-    // We need to implement something more sophisticated here that tracks which
-    // types convert to which other types and does the appropriate
+    // We need to implement something more sophisticated here that tracks
+    // which types convert to which other types and does the appropriate
     // materialization logic.
     // For example, it's possible that one result type converts to 0 types and
-    // another to 2 types, so newResultTypes would at least be the right size to
-    // not crash in the llvm::zip call below, but then we would set the the
+    // another to 2 types, so newResultTypes would at least be the right size
+    // to not crash in the llvm::zip call below, but then we would set the
     // wrong type on the SSA values! These edge cases are also why we cannot
     // safely use the TypeConverter::convertTypes helper here.
     SmallVector<Type> newResultTypes;
@@ -125,7 +180,17 @@
   LogicalResult
   matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
+    SmallVector<Value> unpackedYield;
+    for (Value operand : adaptor.getOperands()) {
+      if (!typeConverter->isLegal(operand.getType())) {
+        // The type from the adaptor is illegal, probably a 1 : N conversion.
+        unpackUnrealizedConversionCast(operand, unpackedYield);
+      } else {
+        // 1 : 1 type conversion.
+        unpackedYield.push_back(operand);
+      }
+    }
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
     return success();
   }
 };