diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -291,18 +291,28 @@ ConversionPatternRewriter &rewriter) const override { ValueRange operands = adaptor.getOperands(); - // If the region is return values, store each value into the associated + Operation *parent = terminatorOp->getParentOp(); + + // TODO: Implement conversion for the remaining `scf` ops. + if (parent->getDialect()->getNamespace() == + scf::SCFDialect::getDialectNamespace() && + !isa(parent)) + return rewriter.notifyMatchFailure( + terminatorOp, + llvm::formatv("conversion not supported for parent op: '{0}'", + parent->getName())); + + // If the region returns values, store each value into the associated // VariableOp created during lowering of the parent region. if (!operands.empty()) { - auto &allocas = - scfToSPIRVContext->outputVars[terminatorOp->getParentOp()]; + auto &allocas = scfToSPIRVContext->outputVars[parent]; if (allocas.size() != operands.size()) return failure(); auto loc = terminatorOp.getLoc(); for (unsigned i = 0, e = operands.size(); i < e; i++) rewriter.create(loc, allocas[i], operands[i]); - if (isa(terminatorOp->getParentOp())) { + if (isa(parent)) { // For loops we also need to update the branch jumping back to the // header. auto br = cast( diff --git a/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SCFToSPIRV/unsupported.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s + +// `scf.parallel` conversion is not supported yet. +// Make sure that we do not accidentally invalidate this function by removing +// `scf.yield`. 
+// CHECK-LABEL: func.func @func +// CHECK: scf.parallel +// CHECK-NEXT: spirv.Constant +// CHECK-NEXT: memref.store +// CHECK-NEXT: scf.yield +// CHECK: spirv.Return +func.func @func(%arg0: i64) { + %0 = arith.index_cast %arg0 : i64 to index + %alloc = memref.alloc() : memref<16xf32> + scf.parallel (%arg1) = (%0) to (%0) step (%0) { + %cst = arith.constant 1.000000e+00 : f32 + memref.store %cst, %alloc[%arg1] : memref<16xf32> + scf.yield + } + return +}