diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
@@ -68,9 +68,12 @@
 
     spirv::VariableOp varOp;
     if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
-      varOp = rewriter.create<spirv::VariableOp>(
-          loc, varType, spirv::StorageClass::Function,
-          /*initializer=*/adaptor.getTensor());
+      // Initialize the local variable with an explicit spv.Store rather than
+      // through the variable's initializer operand.
+      varOp = rewriter.create<spirv::VariableOp>(loc, varType,
+                                                 spirv::StorageClass::Function,
+                                                 /*initializer=*/nullptr);
+      rewriter.create<spirv::StoreOp>(loc, varOp, adaptor.getTensor());
     } else {
       // Need to store the value to the local variable. It's questionable
       // whether we want to support such case though.
diff --git a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
--- a/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir
@@ -9,7 +9,8 @@
 func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
   // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]>
   %cst = arith.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
-  // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr<!spv.array<12 x i32>, Function>
+  // CHECK: %[[VAR:.+]] = spv.Variable : !spv.ptr<!spv.array<12 x i32>, Function>
+  // CHECK: spv.Store "Function" %[[VAR]], %[[CST]] : !spv.array<12 x i32>
   // CHECK: %[[C0:.+]] = spv.Constant 0 : i32
   // CHECK: %[[C6:.+]] = spv.Constant 6 : i32
   // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32