diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -15,58 +15,102 @@ using namespace mlir::scf; namespace { + +// Appends to `unpacked` the inputs of `v`'s defining unrealized_conversion_cast +// when it has != 1 inputs (1:N conversion), and `v` itself otherwise. +static void unpackUnrealizedConversionCast(Value v, + SmallVectorImpl &unpacked) { + if (auto cast = + dyn_cast_or_null(v.getDefiningOp())) { + if (cast.getInputs().size() != 1) { + // 1 : N type conversion. + unpacked.append(cast.getInputs().begin(), cast.getInputs().end()); + return; + } + } + // 1 : 1 type conversion. + unpacked.push_back(v); +} + class ConvertForOpTypes : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ForOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector newResultTypes; - for (auto type : op.getResultTypes()) { - Type newType = typeConverter->convertType(type); - if (!newType) - return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); - newResultTypes.push_back(newType); + SmallVector newResultTypes; + SmallVector offsets; + offsets.push_back(0); + // Convert each result type; offsets[i]..offsets[i+1] spans result i. + for (Type type : op.getResultTypes()) { + if (failed(typeConverter->convertTypes(type, newResultTypes))) + return rewriter.notifyMatchFailure(op, "could not convert result"); + offsets.push_back(newResultTypes.size()); } - // Clone the op without the regions and inline the regions from the old op. + // Create an empty new op and inline the regions from the old op. // // This is a little bit tricky. We have two concerns here: // // 1.
We cannot update the op in place because the dialect conversion // framework does not track type changes for ops updated in place, so it // won't insert appropriate materializations on the changed result types. - // PR47938 tracks this issue, but it seems hard to fix. Instead, we need to - // clone the op. + // PR47938 tracks this issue, but it seems hard to fix. Instead, we need + // to create a new op. // - // 2. We cannot simply call `op.clone()` to get the cloned op. Besides being - // inefficient to recursively clone the regions, there is a correctness - // issue: if we clone with the regions, then the dialect conversion - // framework thinks that we just inserted all the cloned child ops. But what - // we want is to "take" the child regions and let the dialect conversion - // framework continue recursively into ops inside those regions (which are - // already in its worklist; inlining them into the new op's regions doesn't - // remove the child ops from the worklist). - ForOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); - // Take the region from the old op and put it in the new op. + // 2. We need to reuse the original region instead of cloning it, otherwise + // the dialect conversion framework thinks that we just inserted all the + // cloned child ops. But what we want is to "take" the child regions and let + // the dialect conversion framework continue recursively into ops inside + // those regions (which are already in its worklist; inlining them into the + // new op's regions doesn't remove the child ops from the worklist). + + // Convert the loop body's entry block argument types (handles 1:N). + if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter))) + return failure(); + + // Flatten the converted init args (1:N values become N operands). + SmallVector flatArgs; + for (Value arg : adaptor.getInitArgs()) + unpackUnrealizedConversionCast(arg, flatArgs); + + // We cannot clone the op because the number of results after conversion + // might be different.
+ ForOp newOp = rewriter.create(op.getLoc(), adaptor.getLowerBound(), + adaptor.getUpperBound(), + adaptor.getStep(), flatArgs); + + // Preserve all attributes of the original op. + newOp->setAttrs(op->getAttrs()); + + // Drop the default body block; the original region is inlined below. + rewriter.eraseBlock(newOp.getBody(0)); + // Inline the type converted region from the original operation. rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), newOp.getLoopBody().end()); - // Now, update all the types. - - // Convert the type of the entry block of the ForOp's body. - if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), - *getTypeConverter()))) { - return rewriter.notifyMatchFailure(op, "could not convert body types"); + // Repack per original result (assumes flatArgs layout matches offsets). + SmallVector packedRets; + for (unsigned i = 1, e = offsets.size(); i < e; i++) { + unsigned start = offsets[i - 1], end = offsets[i]; + unsigned len = end - start; + ValueRange mappedValue = newOp.getResults().slice(start, len); + if (len != 1) { + // 1 : N type conversion: rebuild one value of the original type. + Type origType = op.getResultTypes()[i - 1]; + Value mat = typeConverter->materializeSourceConversion( + rewriter, op.getLoc(), origType, mappedValue); + if (!mat) + return rewriter.notifyMatchFailure( + op, "Failed to materialize 1:N type conversion"); + packedRets.push_back(mat); + } else { + // 1 : 1 type conversion. + packedRets.push_back(mappedValue.front()); + } } - // Change the clone to use the updated operands. We could have cloned with - // a BlockAndValueMapping, but this seems a bit more direct. - newOp->setOperands(adaptor.getOperands()); - // Update the result types to the new converted types.
- for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) - std::get<0>(t).setType(std::get<1>(t)); - rewriter.replaceOp(op, newOp.getResults()); + rewriter.replaceOp(op, packedRets); return success(); } }; @@ -81,12 +125,12 @@ ConversionPatternRewriter &rewriter) const override { // TODO: Generalize this to any type conversion, not just 1:1. // - // We need to implement something more sophisticated here that tracks which - // types convert to which other types and does the appropriate + // We need to implement something more sophisticated here that tracks + // which types convert to which other types and does the appropriate // materialization logic. // For example, it's possible that one result type converts to 0 types and - // another to 2 types, so newResultTypes would at least be the right size to - // not crash in the llvm::zip call below, but then we would set the the + // another to 2 types, so newResultTypes would at least be the right size + // to not crash in the llvm::zip call below, but then we would set the // wrong type on the SSA values! These edge cases are also why we cannot // safely use the TypeConverter::convertTypes helper here.
SmallVector newResultTypes; @@ -125,7 +169,11 @@ LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + SmallVector unpackedYield; + for (Value operand : adaptor.getOperands()) + unpackUnrealizedConversionCast(operand, unpackedYield); + // Yield the flattened operands (1:N converted values are unpacked). + rewriter.replaceOpWithNewOp(op, unpackedYield); return success(); } }; diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -sparse-tensor-codegen -cse | FileCheck %s +// Tests that scf.for/scf.yield carrying a sparse tensor are converted 1:N. +#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> +// CHECK-LABEL: func @for( +// CHECK-SAME: %[[DIM_SIZE:.*0]]: memref<1xindex>, +// CHECK-SAME: %[[MEM_SIZE:.*1]]: memref<3xindex>, +// CHECK-SAME: %[[POINTER:.*2]]: memref, +// CHECK-SAME: %[[INDICES:.*3]]: memref, +// CHECK-SAME: %[[VALUE:.*4]]: memref, +// CHECK-SAME: %[[TMP_arg5:.*5]]: index, +// CHECK-SAME: %[[TMP_arg6:.*6]]: index, +// CHECK-SAME: %[[TMP_arg7:.*7]]: index +// CHECK: %[[TMP_0:.*]]:5 = scf.for %[[TMP_arg8:.*]] = %[[TMP_arg5]] to %[[TMP_arg6]] step %[[TMP_arg7]] iter_args( +// CHECK-SAME: %[[TMP_arg9:.*]] = %[[DIM_SIZE]], +// CHECK-SAME: %[[TMP_arg10:.*]] = %[[MEM_SIZE]], +// CHECK-SAME: %[[TMP_arg11:.*]] = %[[POINTER]], +// CHECK-SAME: %[[TMP_arg12:.*]] = %[[INDICES]], +// CHECK-SAME: %[[TMP_arg13:.*]] = %[[VALUE]]) +// CHECK: scf.yield %[[TMP_arg9]], %[[TMP_arg10]], %[[TMP_arg11]], %[[TMP_arg12]], %[[TMP_arg13]] : memref<1xindex>, memref<3xindex>, memref, memref, memref +// CHECK: } +// CHECK: return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2, %[[TMP_0]]#3, %[[TMP_0]]#4 : memref<1xindex>, memref<3xindex>, memref, memref, memref +func.func @for(%in: tensor<1024xf32, #SparseVector>, + %lb: index, %ub: index, %step: index) -> tensor<1024xf32,
#SparseVector> { + %1 = scf.for %i = %lb to %ub step %step iter_args(%vin = %in) + -> tensor<1024xf32, #SparseVector> { + scf.yield %vin : tensor<1024xf32, #SparseVector> + } + return %1 : tensor<1024xf32, #SparseVector> +}