diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -91,13 +91,13 @@ static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, SPIRVTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, - ScfToSPIRVContextImpl *scfToSPIRVContext) { + ScfToSPIRVContextImpl *scfToSPIRVContext, + ArrayRef returnTypes) { Location loc = scfOp.getLoc(); auto &allocas = scfToSPIRVContext->outputVars[newOp]; SmallVector resultValue; - for (Value result : scfOp.results()) { - auto convertedType = typeConverter.convertType(result.getType()); + for (Type convertedType : returnTypes) { auto pointerType = spirv::PointerType::get(convertedType, spirv::StorageClass::Function); rewriter.setInsertionPoint(newOp); @@ -185,8 +185,15 @@ loc, newIndVar.getType(), newIndVar, forOperands.step()); rewriter.create(loc, header, updatedIndVar); + // Infer the return types from the init operands. Vector type may get + // converted to CooperativeMatrix or to Vector type, to avoid having complex + // extra logic to figure out the right type we just infer it from the Init + // operands. + SmallVector initTypes; + for (auto arg : forOperands.initArgs()) + initTypes.push_back(arg.getType()); replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter, - scfToSPIRVContext); + scfToSPIRVContext, initTypes); return success(); } @@ -238,8 +245,13 @@ thenBlock, ArrayRef(), elseBlock, ArrayRef()); + SmallVector returnTypes; + for (auto result : ifOp.results()) { + auto convertedType = typeConverter.convertType(result.getType()); + returnTypes.push_back(convertedType); + } replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter, - scfToSPIRVContext); + scfToSPIRVContext, returnTypes); return success(); }